justheuristic opened this issue 2 years ago
Hi @justheuristic, this is a very insightful observation! I believe the tensor perspective is one very fruitful way to view butterfly / Monarch. This perspective is the basis of our Monarch projection algorithm in the paper. [As an aside, this tensor perspective of the FFT (which butterfly / Monarch generalize) can be traced back to viewing the FFT as message passing on a junction tree. Here's a reference.]
Your implementation is quite close to ours. Monarch multiply (both square and rectangular) can be implemented in 2-3 lines of PyTorch with the help of torch.einsum. The slightly faster version materializes one additional activation tensor, as you observed.
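For illustration, a rough sketch of the square case could look like this (a paraphrase rather than the repo's exact BlockdiagButterflyMultiply; the helper name and block layout are placeholders):

```python
import torch

def monarch_multiply(x, blk1, blk2):
    """Sketch of a square Monarch multiply for n = m * m units.

    x:    (..., n)
    blk1: (m, m, m)  block j mixes the m entries sharing grid "column" j
    blk2: (m, m, m)  block i mixes the m entries sharing grid "row" i
    """
    m = blk1.shape[0]
    x = x.reshape(*x.shape[:-1], m, m)             # view the n units as an m-by-m grid
    x = torch.einsum('jki,...ij->...kj', blk1, x)  # first block-diagonal matmul
    x = torch.einsum('ikj,...ij->...ik', blk2, x)  # second block-diagonal matmul
    return x.reshape(*x.shape[:-2], m * m)

x = torch.randn(8, 1024)                           # batch of 8, n = 1024, m = 32
out = monarch_multiply(x, torch.randn(32, 32, 32), torch.randn(32, 32, 32))
```

The two einsum calls are the two block-diagonal batch matmuls; the permutations are absorbed into the reshape and the einsum index order.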
Thank you kindly for the response (and references)! I somehow did not make the connection to the message-passing view of the FFT.
Also, thank you for sharing the implementation. I'm currently running experiments in a memory-constrained setup, and the re-materialization trick from BlockdiagButterflyMultiply was way more efficient than naive grad checkpointing.
Hi @tridao. Yes, it's me again. First, thanks for the Monarch paper, it was quite a read ;)
I've been looking through the repo, but could not find the official Monarch implementation yet. So I've built an unofficial one :) https://gist.github.com/justheuristic/499ec116f1f353dfd3314de87f310f80 (warning: pure torch, please don't use for speed evaluation)
And yes, it's also a hypercube. Lemme explain :)
Imagine a small monarch layer with 4 input units and 4 output units.
Here's what the paper says should happen:
Consider a matrix with N=4 input units, and hence M=2. The permutation layer for 4 units is defined as `x_new = [x[0], x[2], x[1], x[3]]`.
Hence, the first batch matrix multiplication has two blocks: `[x[0], x[2]]` and `[x[1], x[3]]`. The second matrix multiplication is unpermuted; its blocks are `[x[0], x[1]]` and `[x[2], x[3]]`.
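Here is a tiny numerical sketch of that N=4 case (variable names like `blk1`, `blk2`, `perm` are mine, purely for illustration):

```python
import torch

x = torch.randn(4)
perm = torch.tensor([0, 2, 1, 3])       # the permutation layer: [x[0], x[2], x[1], x[3]]
blk1 = torch.randn(2, 2, 2)             # two 2x2 blocks for the first (permuted) bmm
blk2 = torch.randn(2, 2, 2)             # two 2x2 blocks for the second (unpermuted) bmm

# first batch matmul over the permuted blocks [x[0], x[2]] and [x[1], x[3]]
h = torch.bmm(blk1, x[perm].reshape(2, 2, 1)).reshape(4)

# undo the permutation (for N=4 it is its own inverse), then the second batch matmul
# over the unpermuted blocks, i.e. the adjacent pairs [0, 1] and [2, 3]
y = torch.bmm(blk2, h[perm].reshape(2, 2, 1)).reshape(4)
```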
Now let's naively reshape this tensor into a 2x2 square:
As you can see, the two batch matrix products go over the square's columns (L matrix) and rows (R matrix).
This intuition holds for any valid N, M: for instance, N=1024, M=32 results in a 32-by-32 square lattice, and the column-to-row order stays the same.
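A quick sanity check of the same claim for N=1024, M=32 (again a sketch; the permutation and block layout are my assumptions, not the repo's exact code):

```python
import torch

m = 32
n = m * m                               # N = 1024
x = torch.randn(n)
blk1 = torch.randn(m, m, m)             # m blocks of size m x m for the first bmm
blk2 = torch.randn(m, m, m)             # m blocks of size m x m for the second bmm

# permutation generalizing [0, 2, 1, 3]: read the m-by-m grid column by column
perm = torch.arange(n).reshape(m, m).t().reshape(n)

# permute -> block matmul -> un-permute -> block matmul (perm is its own inverse)
h = torch.bmm(blk1, x[perm].reshape(m, m, 1)).reshape(n)
y = torch.bmm(blk2, h[perm].reshape(m, m, 1)).reshape(n)

# the same computation in the grid view: mix the columns first, then the rows
X = x.reshape(m, m)
X = torch.einsum('jki,ij->kj', blk1, X)  # column j is multiplied by block blk1[j]
X = torch.einsum('ikj,ij->ik', blk2, X)  # row i is multiplied by block blk2[i]
assert torch.allclose(y, X.reshape(n), atol=1e-4)
```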
This leads to a few obvious considerations:
On adding more dimensions: consider a GPT-3 layer with 12288 units. We can view this as a 3d lattice of shape [16, 24, 32], since 16 × 24 × 32 = 12288.
Using the code above, you can define this specific grid as follows:
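A rough sketch of what such a 3-factor multiply could look like (the helper name, weight shapes, and einsum strings here are my own assumptions, not the gist's actual API):

```python
import torch

def monarch3d_multiply(x, w1, w2, w3):
    """Hypothetical 3-factor multiply for a 12288-unit layer viewed as a 16x24x32 lattice.

    x:  (..., 16 * 24 * 32)
    w1: (24, 32, 16, 16)  one 16x16 block per (b, c) position, mixes along the first axis
    w2: (16, 32, 24, 24)  one 24x24 block per (a, c) position, mixes along the second axis
    w3: (16, 24, 32, 32)  one 32x32 block per (a, b) position, mixes along the third axis
    """
    a, b, c = w1.shape[2], w1.shape[0], w1.shape[1]
    x = x.reshape(*x.shape[:-1], a, b, c)
    x = torch.einsum('bcka,...abc->...kbc', w1, x)   # mix along the first lattice axis
    x = torch.einsum('ackb,...abc->...akc', w2, x)   # mix along the second lattice axis
    x = torch.einsum('abkc,...abc->...abk', w3, x)   # mix along the third lattice axis
    return x.reshape(*x.shape[:-3], a * b * c)

x = torch.randn(2, 16 * 24 * 32)
w1 = torch.randn(24, 32, 16, 16)
w2 = torch.randn(16, 32, 24, 24)
w3 = torch.randn(16, 24, 32, 32)
out = monarch3d_multiply(x, w1, w2, w3)              # -> (2, 12288)
```

Each factor mixes the lattice along one axis, with a separate block for every position along the other two axes.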
This, in turn, raises several questions:
p.s. the perspective in this question is not my own; we stumbled onto it in discussions with @ostroumova-la, @TimDettmers, @KhrulkovV