Descriptive Proof of Factorization with Factorization Machines
The Paper
Steffen Rendle introduced Factorization Machines back in 2010, but I have only recently come across them.
There is only a very short proof provided of Lemma 3.1 (page 3): “The model equation of a factorization machine can be computed in linear time \(O(k * n)\).”
I decided to derive the proof myself and found some things interesting along the way. I wrote down an earful hoping that someone else finds it helpful!
The Proof
Let \(f\) be a function with commutative operator \(\circ\) (such as multiplication, dot product, Hadamard product, or simultaneously diagonalizable matrix multiplication ^{1}) that follows the decomposition:
\[f(i, j) = u(i) \circ s(j) = u(j) \circ s(i) = f(j, i)\]
There are many cases in computer science where we want to compute something in a pairwise fashion between all \(u(x)\) and \(s(x)\). One such case is a linear combination of pairwise effects for the function \(f(i, j)\). The naive approach applies \(\circ\) to every pair of elements and sums the results:
\[\sum\limits_{i=1}^n \sum\limits_{j=1}^n f(i, j)\]
A standard calculation requires \(n^2\) operations. However, if \(\circ\) is also distributive over summation, we can expand and factorize this computation in \(O(n)\) operations. Let’s begin!
\[\begin{aligned} \sum\limits_{i=1}^n \sum\limits_{j=1}^n f(i, j) &= u(1) \circ \left(\sum\limits_{i=1}^n s(i) \right) + u(2) \circ \left(\sum\limits_{i=1}^n s(i) \right) + \ldots + u(n) \circ \left(\sum\limits_{i=1}^n s(i) \right) \\ &= \left(u(1) + u(2) + \ldots + u(n) \right) \circ \left(\sum\limits_{i=1}^n s(i) \right) \\ &= \left(\sum\limits_{i=1}^n u(i) \right) \circ \left(\sum\limits_{i=1}^n s(i) \right) \end{aligned}\]
This is really, really awesome. We have gone from \(O(n^2)\) \(+\) operations and \(O(n^2)\) \(\circ\) operations to \(O(n)\) \(+\) operations and only one lonely \(\circ\) operation. This removes an unnecessary loop over \(j\), illustrating that this problem is actually less complex than it originally appeared ^{2}.
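As a quick sanity check of the factorization above, here is a minimal Python sketch. It assumes \(\circ\) is ordinary multiplication and \(+\) is ordinary addition; the function names `naive_pairwise_sum` and `factored_pairwise_sum` are made up for illustration.

```python
import operator


def naive_pairwise_sum(u, s, circ=operator.mul):
    """Sum f(i, j) = u[i] o s[j] over all pairs: n^2 applications of o."""
    total = 0
    for ui in u:
        for sj in s:
            total += circ(ui, sj)
    return total


def factored_pairwise_sum(u, s, circ=operator.mul):
    """Same value with one pass over each list and a single application of o."""
    return circ(sum(u), sum(s))


u = [1, 2, 3]
s = [4, 5, 6]
print(naive_pairwise_sum(u, s))     # 90
print(factored_pairwise_sum(u, s))  # 90
```

Both calls agree because \((1 + 2 + 3) * (4 + 5 + 6) = 6 * 15 = 90\). The second version never touches a pair of indices.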
Is this intuitive? I think so. Remember in elementary school, where it was way more fun to compute something like \((1 + 2) * (3 + 4)\) than \(1 * 3 + 1 * 4 + 2 * 3 + 2 * 4\)? This is a similar thing, where you win by factorizing. In this case, we can cheat further because it becomes clear when you write out the expansion that \(j\) is not a necessary variable. The pairwise interactions factor out thanks to the combination of the commutative and distributive properties of \(\circ\), in such a beautiful way that we don’t care what the value of \(j\) is. Amazing!
For a bit of a brain teaser, let \(+ = \min\) and \(\circ = \max\). This works because \(\max\) distributes over \(\min\): \(\max(a, \min(b, c)) = \min(\max(a, b), \max(a, c))\). It would be like having each \(u(i)\), \(s(j)\) pair joust for the highest value, and then taking the smallest of these values (perhaps a quasi-winner-take-all?): \(\min_{i, j}(\max(u(i), s(j)))\). By our factorization proof, this is equivalent to computing \(\max(\min_i(u(i)), \min_i(s(i)))\). I find this somewhat intuitive; each \(\max(u(i), s(j))\) simply depends on either the row value \(u(i)\) or the column value \(s(j)\), not on the combination.
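The min/max version can be checked the same way. This is only a sketch on a small hand-picked example (the lists `u` and `s` and the function names are arbitrary), not a proof.

```python
from itertools import product


def naive_min_of_max(u, s):
    """Take max(u[i], s[j]) for every pair, then the min over all pairs."""
    return min(max(ui, sj) for ui, sj in product(u, s))


def factored_min_of_max(u, s):
    """Factored form: "sum" each list with min, then combine once with max."""
    return max(min(u), min(s))


u = [7, 3, 9]
s = [5, 8, 4]
print(naive_min_of_max(u, s))     # 4
print(factored_min_of_max(u, s))  # 4
```

Here the smallest pairwise maximum comes from the pair \((3, 4)\), which is exactly the pair of per-list minimums.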
What if we don’t care about self similarity \(f(i, i)\)? In that case there is a smart factorization that still avoids the pairwise loop while leaving the \(f(i, i)\) terms out.
Due to the commutative property of \(\circ\), the operations are symmetric. The sum can be thought of as summing over all entries for a matrix of the results \(f(i, j)\):
\[M = \begin{pmatrix} f(1,1) & f(2, 1) & \ldots & f(n, 1)\\ f(1, 2) & f(2, 2) & \ldots & f(n, 2)\\ \vdots & \vdots & \ddots & \vdots\\ f(1, n) & f(2, n) &\ldots & f(n, n) \end{pmatrix}\]
where
\[\sum M = \sum\limits_{i=1}^n \sum\limits_{j=1}^n f(i, j)\]
Because this matrix is symmetric, it is useful to separate it into strictly upper triangular, strictly lower triangular, and diagonal matrices:
\(M = U + L + D\) \(U= \begin{pmatrix} 0 & f(2, 1) & \ldots & f(n, 1)\\ 0 & 0 & \ldots & f(n, 2)\\ \vdots & \vdots & \ddots & \vdots\\ 0 & 0 &\ldots & 0 \end{pmatrix} , L= \begin{pmatrix} 0 & 0 & \ldots & 0\\ f(1, 2) & 0 & \ldots & 0\\ \vdots & \vdots & \ddots & \vdots\\ f(1, n) & f(2, n) &\ldots & 0 \end{pmatrix}\) \(D= \begin{pmatrix} f(1,1) & 0 & \ldots & 0\\ 0 & f(2, 2) & \ldots & 0\\ \vdots & \vdots & \ddots & \vdots\\ 0 & 0 &\ldots & f(n, n) \end{pmatrix}\)
Due to the commutative nature of \(\circ\): \(\sum U = \sum L\)
We can then decompose \(\sum M = 2 \sum U + \sum D\)
Or rearranged: \(\sum U = \frac{\sum M - \sum D}{2}\)
Therefore, if all you care about is just the upper triangular: \(\sum\limits_{i=1}^n \sum\limits_{j > i}^n f(i, j) = \sum U = \frac{\sum M - \sum D}{2}\)
Let’s go back to a more verbose, lower level representation, expanding the sums in \(\frac{\sum M - \sum D}{2}\) as follows: \(\frac{1}{2}\sum\limits_{i=1}^n \sum\limits_{j=1}^n f(i, j) - \frac{1}{2} \sum\limits_{i=1}^n f(i, i)\)
We can simplify the values of the diagonal in \(D\): \(\sum\limits_{i=1}^n f(i, i) = \sum\limits_{i=1}^n u(i) \circ s(i)\)
Putting it all together, we can compute the upper triangular pairwise interactions as follows: \(\sum\limits_{i=1}^n \sum\limits_{j > i}^n f(i, j) = \frac{1}{2}\left(\sum\limits_{i=1}^n u(i) \right) \circ \left(\sum\limits_{i=1}^n s(i) \right) - \frac{1}{2}\sum\limits_{i=1}^n u(i) \circ s(i)\)
As an example, when \(\circ\) is \(*\) and \(s(x) = u(x)\), this looks like the following: \(\sum\limits_{i=1}^n \sum\limits_{j > i}^n f(i, j) = \frac{1}{2}\left(\sum\limits_{i=1}^n u(i) \right)^2 - \frac{1}{2}\sum\limits_{i=1}^n u(i)^2\)
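The multiplication special case (half the squared sum minus half the sum of squares) is easy to try out. A minimal sketch, with illustrative function names and an arbitrary input list:

```python
def naive_upper_sum(u):
    """Sum u[i] * u[j] over all pairs with j > i: O(n^2) multiplications."""
    n = len(u)
    return sum(u[i] * u[j] for i in range(n) for j in range(i + 1, n))


def factored_upper_sum(u):
    """(sum M - sum D) / 2, computed in a single pass over u."""
    total = sum(u)                        # sum M is total squared
    diagonal = sum(x * x for x in u)      # sum D
    return (total * total - diagonal) / 2


u = [1, 2, 3, 4]
print(naive_upper_sum(u))     # 35
print(factored_upper_sum(u))  # 35.0
```

For this input, \(\sum M = 10^2 = 100\) and \(\sum D = 1 + 4 + 9 + 16 = 30\), so \(\sum U = (100 - 30) / 2 = 35\).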
There is something very beautiful about this transformation. It says that summing the upper triangular elements computed by \(f(i, j)\) can be decomposed into a one dimensional function that only iterates over the number of components \(n\).
Why is this useful?
Basically factorization reduces redundant computation.
When \(\circ\) performs some really expensive operation like a dot product upon all pairwise combinations of elements, computing \(\sum U\) can be simplified to \(O(n * k)\), where \(k\) is the length of the vector.
This is exactly the case used in Factorization Machines, where each of the \(n\) input features gets a latent vector of length \(k\), with \(k \ll n\). When modeling real world data, \(k\) probably has some sublinear dependency on \(n\), perhaps \(k \approx \log_2(n)\). That’s much better than the naive \(O(k * n^2)\)!
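To make the Factorization Machine case concrete, here is a sketch of just the pairwise interaction term from Rendle's model equation, \(\sum_{i} \sum_{j > i} \langle v_i, v_j \rangle x_i x_j\). It takes \(u(i) = s(i) = v_i x_i\) and \(\circ\) as the dot product. The function names, the plain-list representation of the latent matrix `V` (one row of length \(k\) per feature), and the toy numbers are assumptions for illustration; the bias and linear terms of the full model are left out.

```python
def fm_pairwise_naive(x, V):
    """Pairwise term with an explicit loop over pairs: O(k * n^2)."""
    n = len(x)
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            dot = sum(a * b for a, b in zip(V[i], V[j]))
            total += dot * x[i] * x[j]
    return total


def fm_pairwise_linear(x, V):
    """Same term via (sum M - sum D) / 2, one latent dimension at a time: O(k * n)."""
    n, k = len(x), len(V[0])
    total = 0.0
    for f in range(k):
        summed = sum(V[i][f] * x[i] for i in range(n))           # sum of u(i)
        squares = sum((V[i][f] * x[i]) ** 2 for i in range(n))   # sum of u(i)^2
        total += summed * summed - squares
    return 0.5 * total


x = [1.0, 2.0, 0.0, 3.0]                               # n = 4 features
V = [[1.0, 2.0], [0.0, 1.0], [3.0, 1.0], [2.0, 0.0]]   # k = 2 latent dimensions
print(fm_pairwise_naive(x, V))
print(fm_pairwise_linear(x, V))
```

Both functions should print the same value. Because the dot product is a sum over the \(k\) latent dimensions, the scalar identity can be applied to each dimension separately, which is where the \(O(k * n)\) cost comes from.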
In addition, this simplification is used as a tool to create other proofs, like showing that Spearman’s \(\rho\) coefficient is a particular case of the general correlation coefficient.

^{1} If anyone knows real world cases for simultaneously diagonalizable matrices, that would be pretty neat.

^{2} Philosophically, I think the Kolmogorov complexity of algorithms is related to the depth of conditional statements. I don’t have a strong proof of this, but it feels like each conditional is another layer in a state diagram for a program, with the number of states in that diagram being some measure of Kolmogorov complexity. Here we have removed a layer of conditional statements that would occur with the loop operation, meaning this has lower Kolmogorov complexity.