samuela / git-re-basin

Code release for "Git Re-Basin: Merging Models modulo Permutation Symmetries"
https://arxiv.org/abs/2209.04836
MIT License

Cost Matrix Computation in Weight Matching #4

Open frallebini opened 2 years ago

frallebini commented 2 years ago

Hi, I read the paper and I am having a really hard time reconciling the formula

$$\underset{\pi}{\arg\max} \; \sum_{\ell} \left\langle W_\ell^A,\; P_\ell\, W_\ell^B\, P_{\ell-1}^\top \right\rangle_F$$

with the actual computation of the cost matrix for the LAP in weight_matching.py, namely

A = jnp.zeros((n, n))
for wk, axis in ps.perm_to_axes[p]:
  w_a = params_a[wk]
  w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
  w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
  w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
  A += w_a @ w_b.T

Are you following a different mathematical derivation or am I missing something?

samuela commented 2 years ago

Hi @frallebini! The writeup in the paper is for the special case of an MLP with no bias terms -- the version in the code is just more general. The connection here is that there's a sum over all weight arrays that interact with that P_\ell. Then for each one, we need to apply its relevant permutations on all the other axes, take the Frobenius inner product with the reference model, and add all those terms together. So A represents that sum, each for-loop iteration adds a single term to the sum, get_permuted_param applies the other (non-P_\ell) permutations to w_b, and the moveaxis-reshape-matmul corresponds to the Frobenius inner product with w_a.

frallebini commented 2 years ago

Thanks @samuela, I understand that the code is a generalization of the MLP with no bias case, but still:

  1. If the moveaxis-reshape-@ operation corresponded to the Frobenius inner product with w_a, wouldn't A be a scalar?
  2. How does get_permuted_param "skip" the non-P_\ell permutations? Doesn't the except_axis argument mean that, for example, if I want to permute rows, then I have to apply the permutation vector perm[p] along the column dimension?
samuela commented 2 years ago

If the moveaxis-reshape-@ operation corresponded to the Frobenius inner product with w_a, wouldn't A be a scalar?

Ack, you're right! I messed up: it's not actually a Frobenius inner product, just a regular matrix product. The moveaxis-reshape combo is necessary to flatten dimensions that we don't care about in the case of non-2d weight arrays.

How does get_permuted_param "skip" the non-P_\ell permutations? Doesn't the except_axis argument mean that, for example, if I want to permute rows, then I have to apply the permutation vector perm[p] along the column dimension?

Yup, that's exactly what except_axis is doing. But I think you may have it backwards -- except_axis is excepting the P_\ell axis but applying all other fixed P's to all the other axes.
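
Roughly, the idea is something like the following sketch (simplified, and not necessarily the exact code in weight_matching.py; axes_to_perm here is assumed to map each weight array to the name of the permutation acting on each of its axes, or None):

import jax.numpy as jnp

def get_permuted_param_sketch(axes_to_perm, perm, k, params, except_axis=None):
  # Apply every currently fixed permutation to params[k], skipping the axis
  # whose permutation is being solved for by the LAP.
  w = params[k]
  for axis, p in enumerate(axes_to_perm[k]):
    if axis == except_axis:
      continue                              # this axis carries P_\ell itself
    if p is not None:
      w = jnp.take(w, perm[p], axis=axis)   # permute this axis by the fixed perm[p]
  return w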

frallebini commented 2 years ago

Ok, but let us consider the MLP-with-no-bias case. The way the paper models weight matching as an LAP is

$$\underset{P_\ell}{\arg\max} \; \left\langle W_\ell^A,\; P_\ell W_\ell^B P_{\ell-1}^\top \right\rangle_F + \left\langle W_{\ell+1}^A,\; P_{\ell+1} W_{\ell+1}^B P_\ell^\top \right\rangle_F = \underset{P_\ell}{\arg\max} \; \left\langle P_\ell,\; W_\ell^A P_{\ell-1} (W_\ell^B)^\top + (W_{\ell+1}^A)^\top P_{\ell+1} W_{\ell+1}^B \right\rangle_F$$

In other words, it computes A as

$$A = W_\ell^A\, P_{\ell-1}\, (W_\ell^B)^\top + (W_{\ell+1}^A)^\top\, P_{\ell+1}\, W_{\ell+1}^B \qquad (1)$$

What the code does instead, if I understood correctly, is to compute A by

  1. Permuting w_b disregarding P_\ell
  2. Transposing it
  3. Multiplying w_a by it

In other words

[image: frallebini's derivation of A as read off the code, labelled (2)]

I don't think (1) and (2) are the same thing though.

samuela commented 2 years ago

Hmm, I think the error here is in the first line of (2): the shapes don't line up, since $W_\ell^A$ has shape (n, *) and $W_{\ell+1}^A$ has shape (*, n), so adding those things together will result in a shape error if your layers have different widths.

I think tracing out the code for the MLP-without-bias-terms case is a good idea. In that case we run through the for wk, axis in ps.perm_to_axes[p]: loop two times: once for $W_\ell$ and once for $W_{\ell+1}$.
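
For concreteness, here is a minimal numerical sketch of that trace (random stand-in matrices, kernels assumed to have shape [out_dim, in_dim]; not code from the repo), checking that the two loop iterations add up to the expression in (1):

import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
n_prev, n, n_next = 3, 4, 5                          # widths of layers l-1, l, l+1
W_l_a  = jnp.asarray(rng.normal(size=(n, n_prev)))   # W_l^A
W_l_b  = jnp.asarray(rng.normal(size=(n, n_prev)))   # W_l^B
W_l1_a = jnp.asarray(rng.normal(size=(n_next, n)))   # W_{l+1}^A
W_l1_b = jnp.asarray(rng.normal(size=(n_next, n)))   # W_{l+1}^B
P_prev = jnp.eye(n_prev)[rng.permutation(n_prev)]    # fixed P_{l-1}
P_next = jnp.eye(n_next)[rng.permutation(n_next)]    # fixed P_{l+1}

# Iteration 1: wk = W_l, axis = 0. get_permuted_param applies P_{l-1} to the
# column axis only; axis 0 is already in front, so moveaxis/reshape change nothing.
w_b = W_l_b @ P_prev.T
A = W_l_a @ w_b.T                                    # = W_l^A P_{l-1} (W_l^B)^T

# Iteration 2: wk = W_{l+1}, axis = 1. get_permuted_param applies P_{l+1} to
# axis 0; moveaxis(., 1, 0) then transposes both w_a and w_b.
w_b = (P_next @ W_l1_b).T
A += W_l1_a.T @ w_b.T                                # = (W_{l+1}^A)^T P_{l+1} W_{l+1}^B

# The same cost matrix written directly from (1):
A_paper = W_l_a @ P_prev @ W_l_b.T + W_l1_a.T @ P_next @ W_l1_b
assert jnp.allclose(A, A_paper)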

frallebini commented 2 years ago

Ok, the role of moveaxis is clear, and the computation matches the formula in the paper for an MLP with no biases.

On the other hand, extending the reasoning to the presence of biases, the reshape((n, -1)) seems to be there so that a bias vector of shape (n,) becomes an (n, 1) column, and w_a @ w_b.T then adds the outer product of the two bias vectors to A, as in the sketch below.

Right?
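
In code terms, what I have in mind is something like this (a tiny made-up example, not from the repo):

import jax.numpy as jnp

n = 4
b_a = jnp.arange(1.0, n + 1)                     # bias of model A, shape (4,)
b_b = jnp.arange(2.0, n + 2)                     # bias of model B, shape (4,)

# moveaxis is a no-op on a 1-d array; reshape((n, -1)) gives an (n, 1) column.
w_a = jnp.moveaxis(b_a, 0, 0).reshape((n, -1))   # (4, 1)
w_b = jnp.moveaxis(b_b, 0, 0).reshape((n, -1))   # (4, 1)

# The contribution to A is the outer product b_A b_B^T, i.e. the
# <b_A, P_\ell b_B> term of the objective.
assert jnp.allclose(w_a @ w_b.T, jnp.outer(b_a, b_b))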

samuela commented 2 years ago

That's correct! In addition, it's necessary when dealing with weight arrays of higher rank as well, e.g. in a convolutional layer where the weights have shape (w, h, channel_in, channel_out).
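
For instance, a minimal sketch (assuming the output-channel permutation acts on axis 3 of such a kernel):

import jax.numpy as jnp

w, h, c_in, c_out = 3, 3, 8, 16
kernel_a = jnp.ones((w, h, c_in, c_out))
kernel_b = jnp.ones((w, h, c_in, c_out))

# moveaxis brings the permuted axis to the front; reshape((n, -1)) flattens the
# remaining w * h * channel_in dimensions so the matmul still yields (n, n).
n = c_out
w_a = jnp.moveaxis(kernel_a, 3, 0).reshape((n, -1))   # (16, 72)
w_b = jnp.moveaxis(kernel_b, 3, 0).reshape((n, -1))   # (16, 72)
print((w_a @ w_b.T).shape)                            # (16, 16)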

LeCongThuong commented 1 year ago

Hi, I read the code and really did not understand the following snippet. Since it relates to the weight matching algorithm, I am posting it here. In line 199 of weight_matching.py:

perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

According to the above line, if W_\ell has shape [m, n] (m is the output feature dim, n the input feature dim) in a Dense layer, then the shape of the permutation matrix P_\ell will be [n, n]. But when I read the paper, I think it should be [m, m].

Sorry for the silly question, but might you explain? @samuela @frallebini

Thank you!

samuela commented 1 year ago

Hi @LeCongThuong, ps.perm_to_axes is a dict of the form PermutationId => [(ParamId, Axis), ...] where, in this case, PermutationIds are strings, ParamIds are also strings, and Axis values are integers. So, for example, in an MLP (without biases, and assuming that weights have shape [out_dim, in_dim]) this dict would look something like

{ "P_5": [("Dense_5/kernel", 0), ("Dense_6/kernel", 1)], ... }

Therefore, axes[0][0] will be something like "Dense_0/kernel" and axes[0][1] will be 0. HTH!
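
So the perm_sizes line just reads off the length of the permuted axis from the first (ParamId, Axis) entry. A small illustration with made-up shapes (following the [out_dim, in_dim] assumption above):

import jax.numpy as jnp

params_a = {
    "Dense_5/kernel": jnp.zeros((512, 256)),   # [out_dim, in_dim]
    "Dense_6/kernel": jnp.zeros((10, 512)),
}
perm_to_axes = {"P_5": [("Dense_5/kernel", 0), ("Dense_6/kernel", 1)]}

perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]]
              for p, axes in perm_to_axes.items()}
print(perm_sizes)                              # {'P_5': 512}, so P_5 is 512 x 512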

LeCongThuong commented 1 year ago

Thank you so much for replying @samuela!

I tried to understand ps.perm_to_axes and the meaning of Axis. From your comment, I gather that Axis tells us to permute W_b along the axes other than "Axis". Following your example above, I think it should be

{ "P_5": [("Dense_5/kernel", 1), ("Dense_6/kernel", 0)], ... }

From that, axes[0][1] will be 1, and thus the shape of P_\ell will be [n, n].

Thank you again for replying to my question.