frallebini opened this issue 2 years ago
Hi @frallebini! The writeup in the paper is for the special case of an MLP with no bias terms -- the version in the code is just more general. The connection here is that there's a sum over all weight arrays that interact with that `P_\ell`. Then for each one, we need to apply its relevant permutations on all other axes, take the Frobenius inner product with the reference model, and add all those terms together. So `A` represents that sum, each for-loop iteration adds a single term to the sum, `get_permuted_param` applies the other (non-`P_\ell`) permutations to `w_b`, and the `moveaxis`-`reshape`-matmul corresponds to the Frobenius inner product with `w_a`.
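For concreteness, here is a minimal sketch of that accumulation, assuming the repo's `get_permuted_param` and a spec `ps` with `perm_to_axes` are in scope; the function wrapper and variable names are illustrative rather than copied from `weight_matching.py`:

```python
import jax.numpy as jnp

def cost_matrix_for_perm(ps, perm, p, n, params_a, params_b):
    """Accumulate the LAP cost matrix A for a single permutation P_l.

    ps.perm_to_axes[p] lists every weight array touched by P_l together with
    the axis on which P_l acts; get_permuted_param (from the repo) applies all
    the other, currently-fixed permutations to model B's array.
    """
    A = jnp.zeros((n, n))
    for wk, axis in ps.perm_to_axes[p]:
        w_a = params_a[wk]  # reference model's array, untouched
        # Model B's array with every permutation except P_l applied:
        w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
        # Bring the P_l axis to the front and flatten the rest, so every term
        # is an (n, ...) matrix whatever the rank of the original array:
        w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
        w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
        A += w_a @ w_b.T  # one term of the sum over interacting arrays
    return A
```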
Thanks @samuela, I understand that the code is a generalization of the MLP-with-no-bias case, but still:

1. If the `moveaxis`-`reshape`-`@` operation corresponded to the Frobenius inner product with `w_a`, wouldn't `A` be a scalar?
2. How does `get_permuted_param` "skip" the non-`P_\ell` permutations? Doesn't the `except_axis` argument mean that, for example, if I want to permute rows, then I have to apply the permutation vector `perm[p]` along the column dimension?

> If the `moveaxis`-`reshape`-`@` operation corresponded to the Frobenius inner product with `w_a`, wouldn't `A` be a scalar?
Ack, you're right! I messed up: it's not actually a Frobenius inner product, just a regular matrix product. The moveaxis-reshape combo is necessary to flatten dimensions that we don't care about in the case of non-2d weight arrays.
> How does `get_permuted_param` "skip" the non-`P_\ell` permutations? Doesn't the `except_axis` argument mean that, for example, if I want to permute rows, then I have to apply the permutation vector `perm[p]` along the column dimension?
Yup, that's exactly what `except_axis` is doing. But I think you may have it backwards -- `except_axis` is excepting the `P_\ell` axis but applying all other fixed P's to all the other axes.
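For what it's worth, here is a rough paraphrase, based on this thread rather than copied from the repository source, of what `get_permuted_param` does with `except_axis`, assuming the spec also carries a reverse mapping `axes_to_perm` (param name to one permutation id per axis, `None` where nothing acts):

```python
import jax.numpy as jnp

def get_permuted_param(ps, perm, k, params, except_axis=None):
    """Return params[k] with every relevant permutation applied,
    except the one acting on `except_axis`."""
    w = params[k]
    # ps.axes_to_perm[k][axis] is the permutation id acting on that axis,
    # or None if no permutation acts there (e.g. input/output layers).
    for axis, p in enumerate(ps.axes_to_perm[k]):
        if axis == except_axis or p is None:
            continue
        # perm[p] is a permutation index vector; reindex this axis by it.
        w = jnp.take(w, perm[p], axis=axis)
    return w
```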
Ok, but let us consider the MLP-with-no-bias case. The way the paper models weight matching as an LAP is

$$P_\ell = \arg\max_{P} \;\big\langle P,\; W_\ell^A P_{\ell-1} (W_\ell^B)^\top + (W_{\ell+1}^A)^\top P_{\ell+1} W_{\ell+1}^B \big\rangle_F$$

In other words, it computes `A` as

$$A = W_\ell^A P_{\ell-1} (W_\ell^B)^\top + (W_{\ell+1}^A)^\top P_{\ell+1} W_{\ell+1}^B \tag{1}$$

What the code does, instead -- if I understood correctly -- is computing `A` by

- permuting `w_b` disregarding `P_\ell`
- multiplying `w_a` by it

In other words

(2) *[equation not reproduced from the original image]*

I don't think (1) and (2) are the same thing though.
Hmm I think the error here is in the first line of (2): the shapes don't line up since $W_\ell^A$ has shape `(n, *)` and $W_{\ell+1}^A$ has shape `(*, n)`. So adding those things together will result in a shape error if your layers have different widths.
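As a quick concrete check (widths invented):

```python
import jax.numpy as jnp

# Made-up widths: layer l maps 784 -> 512, layer l+1 maps 512 -> 256.
W_l_A   = jnp.zeros((512, 784))   # shape (n, *) with n = 512
W_lp1_A = jnp.zeros((256, 512))   # shape (*, n)

# Summing these directly, as in the first line of (2), fails whenever the
# layer widths differ:
# W_l_A + W_lp1_A   # raises an incompatible-shapes broadcasting error
```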
I think tracing out the code for the MLP-without-bias-terms case is a good idea. In that case we run through the `for wk, axis in ps.perm_to_axes[p]:` loop two times: once for $W_\ell$ and once for $W_{\ell+1}$.

1. `axis=0` since $W_\ell$ has shape `(n, *)`. Then, `w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)` will give us $W_\ell^B P_{\ell-1}^T$. In other words, $W_\ell^B$ but with the other permutations -- $P_{\ell-1}$ in this case -- applied to the other axes. `jnp.moveaxis(w_a, axis, 0).reshape((n, -1))` will be a no-op since `axis = 0`. And `w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))` will also be a no-op. So, `w_a @ w_b.T` is $W_\ell^A (W_\ell^B P_{\ell-1}^T)^T$, which matches up with the first term in the sum.
2. `axis = 1` since $W_{\ell+1}$ has shape `(*, n)`. Then, `w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)` will give us $P_{\ell+1} W_{\ell+1}^B$. In other words, $W_{\ell+1}^B$ but with the other permutations -- $P_{\ell+1}$ in this case -- applied to the other axes. `jnp.moveaxis(w_a, axis, 0).reshape((n, -1))` will result in a transpose, aka $(W_{\ell+1}^A)^T$. And `w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))` will also result in a transpose, aka $(W_{\ell+1}^B)^T P_{\ell+1}^T$. So, `w_a @ w_b.T` matches up with the second term in the sum.
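The two iterations can be sanity-checked numerically on a tiny made-up example (sizes, seed, and the particular permutations below are arbitrary; `w_b` is constructed directly as $W_\ell^B P_{\ell-1}^T$ and $P_{\ell+1} W_{\ell+1}^B$ rather than via `get_permuted_param`):

```python
import jax.numpy as jnp
from jax import random

# Tiny MLP-without-bias setting: layer l maps 3 -> 4, layer l+1 maps 4 -> 5.
k1, k2, k3, k4 = random.split(random.PRNGKey(0), 4)
W_l_A,   W_l_B   = random.normal(k1, (4, 3)), random.normal(k2, (4, 3))
W_lp1_A, W_lp1_B = random.normal(k3, (5, 4)), random.normal(k4, (5, 4))

P_prev = jnp.eye(3)[jnp.array([2, 0, 1])]        # P_{l-1} as a permutation matrix
P_next = jnp.eye(5)[jnp.array([4, 3, 0, 1, 2])]  # P_{l+1}

# Iteration 1 (axis = 0): moveaxis/reshape are no-ops, w_b = W_l^B P_{l-1}^T.
w_a, w_b = W_l_A, W_l_B @ P_prev.T
term1 = w_a @ w_b.T
print(jnp.allclose(term1, W_l_A @ P_prev @ W_l_B.T))         # True: first term of (1)

# Iteration 2 (axis = 1): moveaxis/reshape transpose both arrays,
# with w_b = P_{l+1} W_{l+1}^B before the transpose.
w_a = jnp.moveaxis(W_lp1_A, 1, 0).reshape((4, -1))           # (W_{l+1}^A)^T
w_b = jnp.moveaxis(P_next @ W_lp1_B, 1, 0).reshape((4, -1))  # (W_{l+1}^B)^T P_{l+1}^T
term2 = w_a @ w_b.T
print(jnp.allclose(term2, W_lp1_A.T @ P_next @ W_lp1_B))     # True: second term of (1)
```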
Ok, the role of `moveaxis` is clear, and the computation matches the formula in the paper for an MLP with no biases.
On the other hand, the `reshape((n, -1))` (extending the reasoning to the presence of biases):

- does nothing to the weight matrices, since `n` is either the number of rows of $W_\ell$ or the number of columns of $W_{\ell+1}$, which however has already been transposed by the `moveaxis`;
- turns `(n,)` bias vectors into `(n, 1)` vectors so that `w_a @ w_b.T` is a `(n, n)` matrix which can be added to `A`.

Right?
That's correct! In addition, it's necessary when dealing with weight arrays of higher rank as well, e.g. in a convolutional layer where the weights have shape `(w, h, channel_in, channel_out)`.
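To illustrate (shapes invented), the same `moveaxis`-`reshape` pair flattens everything except the permuted axis, so a dense kernel, a bias vector, and a 4-d convolution kernel are all handled uniformly:

```python
import jax.numpy as jnp

n = 8  # size of the permutation acting on this output dimension

kernel = jnp.zeros((n, 16))       # dense kernel, P_l acts on axis 0
bias   = jnp.zeros((n,))          # bias vector, P_l acts on axis 0
conv   = jnp.zeros((3, 3, 4, n))  # conv kernel (w, h, channel_in, channel_out),
                                  # P_l acts on axis 3

print(jnp.moveaxis(kernel, 0, 0).reshape((n, -1)).shape)  # (8, 16) -- a no-op
print(jnp.moveaxis(bias,   0, 0).reshape((n, -1)).shape)  # (8, 1)  -- column vector
print(jnp.moveaxis(conv,   3, 0).reshape((n, -1)).shape)  # (8, 36) -- 3*3*4 flattened
```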
Hi, I read the code and there is a snippet I really did not understand. Since it relates to the weight matching algorithm, I'm posting it here. On line 199 of `weight_matching.py`:
perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}
According to the above line, if $W_\ell$ has shape `[m, n]` (`m` is the output feature dim, `n` is the input feature dim) in a Dense layer, then the shape of the permutation matrix $P_\ell$ will be `[n, n]`. But when I read the paper, I think it should be `[m, m]`.
Sorry for the silly question, but could you explain? @samuela @frallebini
Thank you!
Hi @LeCongThuong, `ps.perm_to_axes` is a dict of the form `PermutationId => [(ParamId, Axis), ...]` where in this case `PermutationId`s are strings, `ParamId`s are also strings, and `Axis`es are integers. So for example, in an MLP (without bias and assuming that weights have shape `[out_dim, in_dim]`), this dict would look something like

`{ "P_5": [("Dense_5/kernel", 0), ("Dense_6/kernel", 1)], ... }`

Therefore, `axes[0][0]` will be something like `"Dense_0/kernel"` and `axes[0][1]` will be `0`. HTH!
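A toy, hand-written example (invented layer names and sizes, with a plain dict standing in for the repo's `PermutationSpec`) of how the line quoted above computes `perm_sizes`:

```python
import jax.numpy as jnp

# Invented 3-layer MLP without biases, weights stored as [out_dim, in_dim]:
# Dense_0: 784 -> 512, Dense_1: 512 -> 256, Dense_2: 256 -> 10.
params_a = {
    "Dense_0/kernel": jnp.zeros((512, 784)),
    "Dense_1/kernel": jnp.zeros((256, 512)),
    "Dense_2/kernel": jnp.zeros((10, 256)),
}

# P_0 permutes the 512 hidden units: axis 0 of Dense_0's kernel (its rows)
# and axis 1 of Dense_1's kernel (its columns). Likewise for P_1.
perm_to_axes = {
    "P_0": [("Dense_0/kernel", 0), ("Dense_1/kernel", 1)],
    "P_1": [("Dense_1/kernel", 0), ("Dense_2/kernel", 1)],
}

perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in perm_to_axes.items()}
print(perm_sizes)  # {'P_0': 512, 'P_1': 256}
```

Under this mapping, the size picked up for `P_0` is the 512 output units of `Dense_0`, i.e. the `m` of that layer in the `[m, n]` notation above.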
Thank you so much for replying @samuela!
I tried to understand `ps.perm_to_axes` and what `Axis` means. From what I got from your comment, `Axis` tells us to permute `W_b` along the axes other than `Axis`. Following your above example, I think it should be

`{ "P_5": [("Dense_5/kernel", 1), ("Dense_6/kernel", 0)], ... }`

so that `axes[0][1]` will be `1`, and thus the shape of `P_l` will be `[n, n]`.
Thank you again for replying to my question.
Hi, I read the paper and I am having a really hard time reconciling the formula

*[formula image not reproduced here]*

with the actual computation of the cost matrix for the LAP in `weight_matching.py`, namely

*[code snippet image not reproduced here]*

Are you following a different mathematical derivation or am I missing something?
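For reference, the formula in question appears to be, as far as it can be reconstructed from the rest of this thread, the paper's weight-matching objective for an MLP without biases (with $P_0 = P_L = I$):

$$\arg\max_{\pi = \{P_1, \dots, P_{L-1}\}} \; \sum_{\ell=1}^{L} \big\langle W_\ell^A,\; P_\ell\, W_\ell^B\, P_{\ell-1}^\top \big\rangle_F$$

Holding every permutation except $P_\ell$ fixed reduces this to the LAP with exactly the cost matrix $A$ in (1) above, which is what the loop in `weight_matching.py` accumulates.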