Open sashaDoubov opened 1 year ago
Maybe the function you're looking for is block_diag_butterfly_project_einsum_rank. (You can see in our tests that the projection recovers the original factors.) https://github.com/HazyResearch/fly/blob/cd624cffeffa7d1579336d26a776405bf0867f36/tests/ops/test_blockdiag_butterfly_einsum.py#L112
Thanks! Just to make sure: the forward function called in monarch_linear.py, https://github.com/HazyResearch/fly/blob/cd624cffeffa7d1579336d26a776405bf0867f36/src/models/layers/blockdiag_butterfly_multiply.py#L63, is then equivalent to https://github.com/HazyResearch/fly/blob/cd624cffeffa7d1579336d26a776405bf0867f36/src/ops/blockdiag_butterfly_einsum.py#L89, and we can just use the block_diag_butterfly_project_einsum_rank function for the projection step? I compared the two forward functions on a number of inputs and they seemed equivalent, but I just wanted to double-check.
Another related question: I'm seeing relatively high projection error for arbitrary weight matrices. I.e., if I generate a standard normal matrix M with dimensions 1024 x 4096, project it into two Monarch matrices with the function you suggested, and then compute the overall projected matrix \tilde{M}, I get a maximum element-wise difference of ~4. Is this expected? I'm finding that dense -> sparse fine-tuning is not performing well due to this projection error, so I'm wondering whether I'm using the suggested function correctly.
I've shown this in the colab here: https://colab.research.google.com/drive/18uQy0nWP-oH0bXcViwipzxsA-5MpfMpk?usp=sharing
Hi, I had a question regarding the projection step from dense -> Monarch butterfly matrices in the general case, where the matrix is rectangular or nblocks != sqrt(m). I'm having trouble finding the code for this case and extending the existing implementations to cover it.
I found multiple projection files/functions: blockdiag_butterfly_projection.py and blockdiag_butterfly_einsum.py.
As an example, if I use code roughly as follows:
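(The original snippet isn't reproduced here; the following is a hypothetical stand-in consistent with the shapes described, i.e. nblocks = 2 with 4x4 blocks. The stride permutation P is an assumed convention.)

```python
import torch

# Hypothetical stand-in for the missing snippet: two block-diagonal factors
# with nblocks = 2 and 4x4 blocks, multiplied out to the dense matrix they
# represent. The stride permutation P is one assumed convention.
w1 = torch.randn(2, 4, 4)
w2 = torch.randn(2, 4, 4)
idx = torch.arange(8).reshape(2, 4).t().reshape(-1)  # (k p) -> (p k) reindexing
P = torch.eye(8)[idx]                                # permutation matrix
M = P.t() @ torch.block_diag(*w2) @ P @ torch.block_diag(*w1)
print(M.shape)  # torch.Size([8, 8])
```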
The resulting shape of the output will be 8x8, which is essentially the full matrix used for the transformation. However, if I use the projection function from blockdiag_butterfly_projection.py (which is meant for square matrices) to try to recover the butterfly factors from this matrix, I run into the issue that it expects the matrix to decompose as follows:
M_permuted_batched = rearrange(M, '(p k) (r s) -> k r p s', k=sizes[1], r=sizes[0])
while in our case r = 4 and s = 4, which is incompatible with the matrix dimensions. Meanwhile, the einsum functions in blockdiag_butterfly_einsum.py gave different results from the original
blockdiag_butterfly_multiply
(comparing the forward multiplication step, not the projection step; see this colab). In the paper, I did see the original derivation for Algorithm 1, but I was unclear on how to actually perform the decomposition step when we can't decompose the tensor into an
m x m x m x m
shape.