HazyResearch / fly

Apache License 2.0

Monarch projection step for rectangular blocks and for nblocks != sqrt(input dimension) #10

Open sashaDoubov opened 1 year ago

sashaDoubov commented 1 year ago

Hi, I have a question about the projection step from dense to Monarch butterfly matrices in the general case, i.e. rectangular blocks or nblocks != sqrt(m). I'm having trouble finding the code for this case and extending the existing implementations to handle it.

I found projection code in two files: blockdiag_butterfly_projection.py and blockdiag_butterfly_einsum.py.

As an example, if I use code roughly as follows:

```python
import torch
from src.models.layers.blockdiag_butterfly_multiply import blockdiag_butterfly_multiply

x = torch.eye(8)
nblocks = 2
bfly_dim = 4

# Each factor is a stack of `nblocks` dense blocks of size bfly_dim x bfly_dim
w1_bfly = torch.randn(nblocks, bfly_dim, bfly_dim)
w2_bfly = torch.randn(nblocks, bfly_dim, bfly_dim)

bfly_matrix = blockdiag_butterfly_multiply(x, w1_bfly, w2_bfly)
print(bfly_matrix.shape)  # torch.Size([8, 8])
```

The resulting output is 8x8; since x is the identity, this is the full dense matrix realized by the two butterfly factors.

However, if I use the projection function from blockdiag_butterfly_projection.py (which is meant for square matrices) to try to recover the butterfly factors from this matrix, I run into a problem: it expects the matrix to decompose as `M_permuted_batched = rearrange(M, '(p k) (r s) -> k r p s', k=sizes[1], r=sizes[0])`, whereas in our case r = 4 and s = 4, which is incompatible with the matrix dimensions (r * s = 16 != 8).
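To make the mismatch concrete, here is a minimal sketch (I pass s explicitly, which the projection code derives internally, just to surface the required shape):

```python
import torch
from einops import rearrange

M = torch.randn(8, 8)
# Recovering 4x4 blocks would need r = s = 4, i.e. r * s = 16 columns,
# but M only has 8 columns, so the rearrange cannot apply.
try:
    rearrange(M, '(p k) (r s) -> k r p s', k=4, r=4, s=4)
except Exception as e:
    print(e)
```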

Meanwhile, the einsum functions in blockdiag_butterfly_einsum.py gave different results from the original blockdiag_butterfly_multiply (comparing the forward multiplication step, not the projection step). (see this colab)

In the paper I did see the original derivation for Algorithm 1 [image: Algorithm 1 from the paper], but I was unclear on how to actually perform the decomposition step when the tensor can't be reshaped into an m x m x m x m shape.
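For concreteness, my understanding of the square case (the exact permutation conventions live in blockdiag_butterfly_projection.py, so treat the index placement as a sketch): with $n = m^2$, the rearrange above views $M$ as an $m \times m \times m \times m$ tensor and Algorithm 1 takes a rank-1 SVD per $(k, r)$ slice,

$$A_{k,r,p,s} = M_{pm+k,\; rm+s}, \qquad A_{k,r,:,:} \approx \sigma_{kr}\, u_{kr} v_{kr}^\top,$$

with the $v_{kr}$ assembling the blocks of w1_bfly and the $u_{kr}$ those of w2_bfly (up to permutations). When $n$ is not a perfect square, or the blocks are rectangular, this reshape simply doesn't exist, hence the question.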

tridao commented 1 year ago

Maybe the function you're looking for is block_diag_butterfly_project_einsum_rank (you can see in our tests here that the projection recovers the original factors): https://github.com/HazyResearch/fly/blob/cd624cffeffa7d1579336d26a776405bf0867f36/tests/ops/test_blockdiag_butterfly_einsum.py#L112
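A minimal round-trip sketch of this suggestion; the argument list (M, nblocks1, nblocks2, rank) and the identity-based densification are illustrative assumptions, so check the linked test for the exact conventions:

```python
import torch
from src.ops.blockdiag_butterfly_einsum import blockdiag_butterfly_project_einsum_rank
from src.models.layers.blockdiag_butterfly_multiply import blockdiag_butterfly_multiply

M = torch.randn(8, 8)
# Assumed signature: (M, nblocks1, nblocks2, rank) -> (w1_bfly, w2_bfly)
w1_bfly, w2_bfly = blockdiag_butterfly_project_einsum_rank(M, 2, 2, 1)
# Densify the factors by pushing the identity through them; depending on the
# in/out convention, this reconstructs the projection of M or its transpose.
M_tilde = blockdiag_butterfly_multiply(torch.eye(8), w1_bfly, w2_bfly)
print((M - M_tilde).abs().max(), (M - M_tilde.T).abs().max())
```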

sashaDoubov commented 1 year ago

Thanks! Just to make sure: the forward function called in monarch_linear.py, https://github.com/HazyResearch/fly/blob/cd624cffeffa7d1579336d26a776405bf0867f36/src/models/layers/blockdiag_butterfly_multiply.py#L63, is then equivalent to https://github.com/HazyResearch/fly/blob/cd624cffeffa7d1579336d26a776405bf0867f36/src/ops/blockdiag_butterfly_einsum.py#L89,

and we can just use the block_diag_butterfly_project_einsum_rank function for the projection step? I compared the two forward functions on a number of inputs and they seemed equivalent, but I just wanted to double-check.
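For reference, my comparison looked roughly like this; monarch_multiply_reference is my own einsum re-implementation of the fused kernel (not a function from the repo), so the permutation conventions are an assumption:

```python
import torch
from einops import rearrange
from src.models.layers.blockdiag_butterfly_multiply import blockdiag_butterfly_multiply

def monarch_multiply_reference(x, w1_bfly, w2_bfly):
    # x: (batch, n); w1_bfly: (k, q, p) with k * p == n;
    # w2_bfly: (l, s, r) with l * r == k * q
    k, q, p = w1_bfly.shape
    l, s, r = w2_bfly.shape
    out1 = torch.einsum('kqp,bkp->bkq', w1_bfly, rearrange(x, 'b (k p) -> b k p', k=k))
    # Permutation between the two block-diagonal stages
    out1 = rearrange(rearrange(out1, 'b k q -> b (k q)'), 'b (r l) -> b l r', l=l)
    out2 = torch.einsum('lsr,blr->bls', w2_bfly, out1)
    return rearrange(out2, 'b l s -> b (s l)')

x = torch.randn(3, 8)
w1, w2 = torch.randn(2, 4, 4), torch.randn(2, 4, 4)
print(torch.allclose(blockdiag_butterfly_multiply(x, w1, w2),
                     monarch_multiply_reference(x, w1, w2), atol=1e-5))
```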

sashaDoubov commented 1 year ago

Another related question: I'm seeing relatively high projection error for arbitrary weight matrices.

I.e., if I generate a standard normal matrix M with dimensions 1024 x 4096, project it into two Monarch factors with the function you suggested, and then compute the overall projected matrix \tilde{M}, I get a max element-wise difference of ~4. Is this expected? I'm finding that dense -> sparse fine-tuning is not performing well due to this projection error.
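The measurement I'm doing looks roughly like this (the block counts, rank, and in/out convention here are illustrative; the colab below has the exact code):

```python
import torch
from src.ops.blockdiag_butterfly_einsum import blockdiag_butterfly_project_einsum_rank
from src.models.layers.blockdiag_butterfly_multiply import blockdiag_butterfly_multiply

torch.manual_seed(0)
M = torch.randn(1024, 4096)
# Assumed block counts and rank; adjust to match the actual setup.
w1_bfly, w2_bfly = blockdiag_butterfly_project_einsum_rank(M, 32, 32, 1)
# Densify via the identity, assuming the factors map the 4096-dim axis to
# the 1024-dim one; swap the eye/transpose if the convention is reversed.
M_tilde = blockdiag_butterfly_multiply(torch.eye(4096), w1_bfly, w2_bfly).T
diff = M - M_tilde
print(diff.abs().max())          # max element-wise difference, ~4 for me
print(diff.norm() / M.norm())    # relative Frobenius error
```

My current guess is that some error is unavoidable here, since a standard normal M is essentially full rank while a rank-1-per-block projection is heavily compressed, but a max difference of ~4 still seems large, which is why I suspect a usage error.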

I'm wondering whether I'm using the suggested function correctly.

I've shown this in the colab here: https://colab.research.google.com/drive/18uQy0nWP-oH0bXcViwipzxsA-5MpfMpk?usp=sharing