HazyResearch / fly


Monarch in pure pytorch. [Yes, it's also a hypercube] #8

Open justheuristic opened 2 years ago

justheuristic commented 2 years ago

Hi @tridao. Yes, it's me again. First, thanks for the Monarch paper, it was quite a read ;)

I've been looking through the repo but could not find the official Monarch implementation yet, so I've built an unofficial one :) https://gist.github.com/justheuristic/499ec116f1f353dfd3314de87f310f80 (warning: pure torch, please don't use it for speed evaluation)

And yes, it's also a hypercube. Let me explain :)

Imagine a small monarch layer with 4 input units and 4 output units.

Here's what the paper says should happen: [image: figure from the paper]

Consider a matrix with N=4 input units and hence M=2. The permutation layer for 4 units is defined as `x_new = [x[0], x[2], x[1], x[3]]`.

Hence, the first batch matrix multiplication has two blocks: `[x[0], x[2]]` and `[x[1], x[3]]`. The second matrix multiplication acts on the unpermuted order, so its blocks are `[x[0], x[1]]` and `[x[2], x[3]]`.

Now let's naively reshape this tensor into a 2x2 square:

x[0] --- x[1]
  |        |
x[2] --- x[3]

As you can see, the two batch matrix products go over the square's columns (L matrix) and rows (R matrix).

This intuition holds for any valid N, M: for instance, N=1024, M=32 results in a 32-by-32 square lattice, and the column-to-row order stays the same.
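To make the column-then-row picture concrete, here is a minimal pure-torch sketch of the square case (the function and weight names `monarch_square_sketch`, `w_col`, `w_row` are mine, not from the gist or the repo): the first einsum applies one dense block per column of the m-by-m grid, the second applies one dense block per row.

```python
import torch

def monarch_square_sketch(x, w_col, w_row):
    """Hypothetical sketch of the grid view of a square Monarch multiply.
    x: (batch, n) with n = m * m; w_col, w_row: (m, m, m), i.e. m blocks of size m x m."""
    batch, n = x.shape
    m = w_col.shape[0]
    grid = x.reshape(batch, m, m)                   # grid[b, i, j] = x[b, i * m + j]
    # First block-diagonal multiply: one block per column j, mixing the row index i.
    out = torch.einsum('jki,bij->bkj', w_col, grid)
    # Second block-diagonal multiply: one block per row k, mixing the column index j.
    out = torch.einsum('klj,bkj->bkl', w_row, out)
    return out.reshape(batch, n)

# Toy example from above: m = 2, n = 4.
x = torch.randn(1, 4)
w_col, w_row = torch.randn(2, 2, 2), torch.randn(2, 2, 2)
print(monarch_square_sketch(x, w_col, w_row).shape)  # torch.Size([1, 4])
```

For m = 2, the first einsum's blocks act on the column positions {x[0], x[2]} and {x[1], x[3]}, and the second on the row positions {x[0], x[1]} and {x[2], x[3]}, matching the description above.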

This leads to a few obvious considerations:

On adding more dimensions: consider a GPT-3 layer with 12288 units. We can view it as a 3D lattice of shape [16, 24, 32], since 16 × 24 × 32 = 12288.

Using the code above, you can define this specific grid as follows: [image: code snippet from the gist]
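Since the image doesn't survive in text form, here is one way such a 3-factor layer could look with plain torch.einsum. This is my own sketch of the idea (each fiber along the mixed axis gets its own dense block), not necessarily how the gist parameterizes it, and `monarch_3d_sketch` is a hypothetical name:

```python
import torch

def monarch_3d_sketch(x, w1, w2, w3, dims=(16, 24, 32)):
    """Hypothetical 3-factor generalization: view the 12288 features as a
    [16, 24, 32] lattice and apply one block-diagonal factor per axis.
    w1: (24, 32, 16, 16), w2: (16, 32, 24, 24), w3: (16, 24, 32, 32) --
    one dense block per fiber of the corresponding axis."""
    a, b, c = dims
    batch = x.shape[0]
    t = x.reshape(batch, a, b, c)
    # Mix along the first axis: one block per (second, third) position.
    t = torch.einsum('bcij,njbc->nibc', w1, t)
    # Mix along the second axis: one block per (first, third) position.
    t = torch.einsum('acij,najc->naic', w2, t)
    # Mix along the third axis: one block per (first, second) position.
    t = torch.einsum('abij,nabj->nabi', w3, t)
    return t.reshape(batch, a * b * c)

x = torch.randn(2, 12288)
w1, w2, w3 = torch.randn(24, 32, 16, 16), torch.randn(16, 32, 24, 24), torch.randn(16, 24, 32, 32)
print(monarch_3d_sketch(x, w1, w2, w3).shape)  # torch.Size([2, 12288])
```

With this particular parameterization the three factors hold 12288 × (16 + 24 + 32) ≈ 0.9M parameters, versus ≈ 151M for a dense 12288 × 12288 matrix.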

This, in turn, raises several questions:

  1. On the memory requirements of Monarch: when done naively, Monarch requires storing one additional tensor of activations for backprop per additional dimension -- or recomputing them via gradient checkpointing. Is there a more efficient strategy for backprop through Monarch?
  2. On relation to tensor decompositions: when viewed from this angle, Monarch sounds vaguely related (though not equivalent) to some popular tensor decompositions, such as TensorTrain or TensorRing. Is Monarch universally better or are there special cases where I should use either one?

P.S. The perspective in this question is not my own; we stumbled onto it in discussions with @ostroumova-la, @TimDettmers, @KhrulkovV.

tridao commented 2 years ago

Hi @justheuristic, this is a very insightful observation! I believe the tensor perspective is one very fruitful way to view butterfly / Monarch. This perspective is the basis of our Monarch projection algorithm in the paper. [As an aside, this tensor perspective of the FFT (which butterfly / Monarch generalize) can be traced back to viewing the FFT as message passing on a junction tree. Here's a reference.]

Your implementation is quite close to ours. The Monarch multiply (both square and rectangular) can be implemented in 2-3 lines of PyTorch with the help of torch.einsum. The slightly faster version materializes one additional array of activations, as you observed.
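For reference, here is roughly how I read that, as a hedged sketch of the rectangular case in pure PyTorch (names and shape conventions are my guesses, not the repo's exact code): two einsums with a transpose-as-permutation in between, where `out1` is the one additional activation that gets materialized.

```python
import torch

def monarch_rect_sketch(x, w1, w2):
    """Hypothetical rectangular Monarch multiply. w1: (k, q, p) -- k blocks of
    shape (q, p) with n = k * p; w2: (l, s, r) -- l blocks of shape (s, r)
    with l * r = k * q. Output has s * l features."""
    batch, n = x.shape
    k, q, p = w1.shape
    l, s, r = w2.shape
    out1 = torch.einsum('kqp,bkp->bkq', w1, x.reshape(batch, k, p))
    # The permutation between the two block-diagonal factors is a transpose of
    # the intermediate grid; out1 is the extra activation saved for backward.
    out1 = out1.transpose(1, 2).reshape(batch, l, r)
    out2 = torch.einsum('lsr,blr->bls', w2, out1)
    return out2.transpose(1, 2).reshape(batch, s * l)
```

Setting k = q = p = l = s = r = m gives the square case with n = m² features, up to which axis of the grid you call a row versus a column.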

  1. Adding more dimensions: that's a great idea! I particularly like the einsum way to implement Monarch, and you can generalize that to more dimensions easily.
  2. Memory: we typically just store one additional array of activations. One can certainly use gradient checkpointing to trade memory for compute, as you have done (see the sketch after this list). Recomputation is quite fast since we're multiplying by a block-diagonal matrix.
  3. Relation to tensor decompositions: a Monarch matrix can be viewed as a tensor, as you observed. It was designed for hardware efficiency (hence the use of block-diagonal matrices) and to capture many fast transforms (e.g., the Fourier transform). Some applications do well with the "Fourier" inductive bias (e.g., MRI reconstruction and PDE solving in the paper), and there Monarch would be better suited than other tensor methods.
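Regarding point 2, here is a minimal sketch of that memory/compute trade-off with plain torch.utils.checkpoint, reusing the hypothetical `monarch_rect_sketch` from the snippet above (the repo's BlockdiagButterflyMultiply uses its own custom autograd function, which this does not reproduce):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Assumes monarch_rect_sketch from the earlier sketch is in scope.
x = torch.randn(8, 1024, requires_grad=True)
w1 = torch.randn(32, 32, 32, requires_grad=True)  # k = q = p = 32, so n = 1024
w2 = torch.randn(32, 32, 32, requires_grad=True)  # l = s = r = 32

# Option 1: store the intermediate activation (more memory, cheaper backward).
y = monarch_rect_sketch(x, w1, w2)

# Option 2: recompute it during backward (less memory, a bit more compute).
y_ckpt = checkpoint(monarch_rect_sketch, x, w1, w2, use_reentrant=False)
y_ckpt.sum().backward()
```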
justheuristic commented 2 years ago

Thank you kindly for the response (and references)! I somehow did not make the connection to the message-passing view of the FFT.

Also, thank you for sharing the implementation. I'm currently running experiments in a memory-constrained setup, and the re-materialization trick from BlockdiagButterflyMultiply was way more efficient than naive grad checkpointing.