pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

Optimal ordering with block mask #56

Open francois-rozet opened 3 days ago

francois-rozet commented 3 days ago

From my understanding, flex attention (using a block_mask) gets faster as the number of empty blocks grows. If the inputs (Q, K, V) do not represent sequences but graphs with local connectivity (e.g. pixels in an image), the ordering of the elements has a huge impact on the number of empty blocks.

It would be very useful to add helpers that find optimal, or simply better, orderings given a mask. For example, for images it is likely better to order the pixels patch by patch, with patches close to the attention window size, rather than in the standard row-by-row order.
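A small pure-Python sketch makes the effect of the ordering concrete. The helpers below (`window_mask`, `row_major_order`, `patch_order`, `count_nonempty_blocks`) are hypothetical and not part of attention-gym or flex attention: they brute-force count how many (query block, KV block) tiles of a square-neighborhood mask contain at least one unmasked entry. The 16 × 16 image, radius 2 and block size 4 are assumptions chosen to keep the loops small; flex attention's default block size is 128, so real patches would need to be correspondingly larger.

```python
# Hypothetical helpers for illustration only; not part of attention-gym
# or torch.nn.attention.flex_attention.

def window_mask(w, radius):
    """Square (Chebyshev) neighborhood mask on row-major pixel ids r * w + c."""
    def mask(q, kv):
        qr, qc = divmod(q, w)
        kr, kc = divmod(kv, w)
        return abs(qr - kr) <= radius and abs(qc - kc) <= radius
    return mask


def row_major_order(h, w):
    """Identity permutation: pixels visited row by row."""
    return list(range(h * w))


def patch_order(h, w, p):
    """Visit the image in p x p patches, row by row inside each patch."""
    order = []
    for pr in range(0, h, p):
        for pc in range(0, w, p):
            for r in range(pr, min(pr + p, h)):
                for c in range(pc, min(pc + p, w)):
                    order.append(r * w + c)
    return order


def count_nonempty_blocks(mask, perm, block):
    """Brute-force count of tiles with at least one unmasked entry.

    perm[i] is the original pixel id placed at position i of the reordered
    sequence. Returns (non_empty_tiles, total_tiles).
    """
    n = len(perm)
    num_blocks = (n + block - 1) // block
    nonempty = set()
    for i in range(n):
        for j in range(n):
            if mask(perm[i], perm[j]):
                nonempty.add((i // block, j // block))
    return len(nonempty), num_blocks * num_blocks


h = w = 16
mask = window_mask(w, radius=2)
print(count_nonempty_blocks(mask, row_major_order(h, w), block=4))   # (740, 4096)
print(count_nonempty_blocks(mask, patch_order(h, w, p=2), block=4))  # (484, 4096)
```

With a block size of 4, each 2 × 2 patch fills exactly one block, and in this toy setting the patch ordering needs about a third fewer non-empty tiles than row-major order.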

Note that this is related to the minimum degree algorithm.

Chillee commented 1 day ago

Yeah, this is a pretty fun idea :) I previously played around with something like this using a permute_transform, like so:

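def permute_mod(mod, permutation):
    # permutation[i] is the original index of the token stored at position i
    # (e.g. a 1-D integer tensor on the same device as Q, K, V)
    def new_mod(b, h, q, kv):
        q_idx = permutation[q]
        kv_idx = permutation[kv]
        # evaluate the original mask_mod on the original (unpermuted) indices
        return mod(b, h, q_idx, kv_idx)
    return new_mod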

This lets you transform any existing mask_mod into one that operates on a permuted input. Unfortunately, it does require a bunch of additional memory accesses (the permutation lookups), so it might not be worth it unless you get much more sparsity. But I had some good successes in certain cases with a Hilbert curve.
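The thread doesn't include the code behind the Hilbert-curve experiments, so the following is only a sketch under assumptions: the classic iterative Hilbert `d2xy` conversion (distance along the curve to grid coordinates), used to build a permutation for a square image whose side is a power of two. It follows the `permutation[i] = original index` convention that `permute_mod` relies on; the function names are hypothetical.

```python
def _rot(s, x, y, rx, ry):
    """Rotate/flip a quadrant so the sub-curve has the right orientation."""
    if ry == 0:
        if rx == 1:
            x = s - 1 - x
            y = s - 1 - y
        x, y = y, x
    return x, y


def hilbert_d2xy(side, d):
    """Map distance d along a Hilbert curve to (x, y) on a side x side grid.

    side must be a power of two (classic iterative d2xy routine).
    """
    x = y = 0
    s = 1
    t = d
    while s < side:
        rx = 1 & (t // 2)
        ry = 1 & (t ^ rx)
        x, y = _rot(s, x, y, rx, ry)
        x += s * rx
        y += s * ry
        t //= 4
        s *= 2
    return x, y


def hilbert_permutation(side):
    """perm[i] = row-major pixel id (y * side + x) visited i-th by the curve."""
    perm = []
    for d in range(side * side):
        x, y = hilbert_d2xy(side, d)
        perm.append(y * side + x)
    return perm


perm = hilbert_permutation(8)
print(len(perm), sorted(perm) == list(range(64)))  # 64 True
```

To use it with `permute_mod`, the list would presumably be converted to an integer tensor on the inputs' device (e.g. `torch.tensor(perm, device=...)`), and the tokens reordered with the same permutation.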

For 2D images, uwu (on Discord) suggested trying a Morton curve, which could be a good alternative since it's cheap to "compute" :)

francois-rozet commented 1 day ago

I think it is worth it if you can do the permutation once before a series of attention operations. That is pretty much the case in vision transformers with local windows.
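The amortization argument can be sketched in plain Python, with lists standing in for token tensors and hypothetical names throughout: permute once, run every layer in the permuted order (each would reuse a BlockMask built once for that permutation), and invert the permutation once at the end. On a PyTorch tensor, `gather` corresponds to advanced indexing along the sequence dimension, and the inverse permutation can be obtained with `torch.argsort(perm)`.

```python
def invert_permutation(perm):
    """inv[orig] = position of original token `orig` in the permuted sequence."""
    inv = [0] * len(perm)
    for pos, orig in enumerate(perm):
        inv[orig] = pos
    return inv


def gather(seq, index):
    """Reorder seq so that out[i] = seq[index[i]]."""
    return [seq[i] for i in index]


def run_layers(tokens, perm, layers):
    """Permute once, run every layer in permuted order, un-permute once."""
    x = gather(tokens, perm)
    for layer in layers:
        # each layer would reuse a BlockMask built once for this permutation
        x = layer(x)
    return gather(x, invert_permutation(perm))


def upper_all(xs):
    # stand-in for a layer that is equivariant to the token order
    return [t.upper() for t in xs]


tokens = ["a", "b", "c", "d"]
perm = [2, 0, 3, 1]
print(gather(tokens, perm))                 # ['c', 'a', 'd', 'b']
print(run_layers(tokens, perm, [upper_all]))  # ['A', 'B', 'C', 'D']
```

This only works if every layer in between is equivariant to the reordering, which holds for attention whose mask is permuted consistently with the tokens and for per-token layers such as MLPs.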

I also tried the Hilbert and Moore curves, but I haven't conducted a proper benchmark.

Chillee commented 1 day ago

The issue isn't necessarily that permuting the tokens itself is expensive, but rather that after the permutation you need to load the permutation index into the "inner loop" of the attention, which does offset some of the sparsity gains you can get.

This is why Morton curves were an interesting suggestion to me: I think they're fairly cheap to compute "within" the kernel itself.
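A sketch of why a Morton (Z-order) layout can be cheap in-kernel, assuming a square image with a power-of-two side: a token's (row, col) coordinates are recovered from its position by de-interleaving bits, so a NATTEN-style mask_mod needs only shifts, masks and comparisons instead of a load from a permutation table. `morton_window_mask_mod` and the other names are hypothetical. The bit tricks only use operators that integer tensors also support, but whether this compiles cleanly inside flex attention is an untested assumption; here it is only exercised on plain Python ints.

```python
def part1by1(v):
    """Spread the low 16 bits of v so that bit k moves to bit 2k."""
    v = v & 0x0000FFFF
    v = (v | (v << 8)) & 0x00FF00FF
    v = (v | (v << 4)) & 0x0F0F0F0F
    v = (v | (v << 2)) & 0x33333333
    v = (v | (v << 1)) & 0x55555555
    return v


def compact1by1(v):
    """Inverse of part1by1: gather the even bits of v into the low 16 bits."""
    v = v & 0x55555555
    v = (v | (v >> 1)) & 0x33333333
    v = (v | (v >> 2)) & 0x0F0F0F0F
    v = (v | (v >> 4)) & 0x00FF00FF
    v = (v | (v >> 8)) & 0x0000FFFF
    return v


def morton_encode(row, col):
    """Interleave bits: row bits go to odd positions, col bits to even ones."""
    return (part1by1(row) << 1) | part1by1(col)


def morton_decode(z):
    """Recover (row, col) from a Morton index with shifts and masks only."""
    return compact1by1(z >> 1), compact1by1(z)


def morton_window_mask_mod(radius):
    """Hypothetical NATTEN-style mask_mod for tokens stored in Morton order."""
    def mask_mod(b, h, q_idx, kv_idx):
        # coordinates are computed, not loaded from a permutation table
        q_row, q_col = morton_decode(q_idx)
        kv_row, kv_col = morton_decode(kv_idx)
        return (abs(q_row - kv_row) <= radius) & (abs(q_col - kv_col) <= radius)
    return mask_mod


print([morton_decode(z) for z in range(4)])  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```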

francois-rozet commented 1 day ago

I don't think you need to load the permutation index if you compute the BlockMask once.

Basically, my idea was to find a permutation that minimizes the number of (non-empty) blocks in a BlockMask. Then you can reuse the same block mask again and again.

Chillee commented 1 day ago

If you can guarantee that all of your non-empty blocks are "full" (i.e. not masked at all), then you don't need to load the permutation index for those blocks.

However, for partially masked blocks, you still need to load the permutation index to compute the mask for those blocks. For example, here is NATTEN with a Hilbert curve:

[Image: NATTEN block mask under a Hilbert-curve ordering]
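To illustrate the full/partial distinction, the hypothetical `classify_blocks` below labels each tile as empty, partial or full by brute force. Per the comment above, only partial tiles need the mask_mod evaluated element-wise, which is where a permute_mod-wrapped mask would read the permutation. The 16 × 16 image, radius-2 window and block size 4 are illustrative assumptions.

```python
from collections import Counter


def classify_blocks(mask, perm, block):
    """Label every (query block, kv block) tile as 'empty', 'partial' or 'full'.

    perm[i] is the original token id at position i. Only 'partial' tiles need
    the mask evaluated element-wise; 'full' tiles can skip it.
    """
    n = len(perm)
    num_blocks = (n + block - 1) // block
    kinds = Counter()
    for qb in range(num_blocks):
        q_pos = range(qb * block, min((qb + 1) * block, n))
        for kb in range(num_blocks):
            kv_pos = range(kb * block, min((kb + 1) * block, n))
            hits = sum(bool(mask(perm[i], perm[j])) for i in q_pos for j in kv_pos)
            total = len(q_pos) * len(kv_pos)
            kinds["empty" if hits == 0 else "full" if hits == total else "partial"] += 1
    return kinds


W, R = 16, 2  # image side and window radius (illustrative)


def window(q, kv):
    (qr, qc), (kr, kc) = divmod(q, W), divmod(kv, W)
    return abs(qr - kr) <= R and abs(qc - kc) <= R


row_major = list(range(W * W))
patches = [(pr + dr) * W + pc + dc
           for pr in range(0, W, 2) for pc in range(0, W, 2)
           for dr in range(2) for dc in range(2)]

rm = classify_blocks(window, row_major, 4)
pt = classify_blocks(window, patches, 4)
print(rm["full"], rm["partial"])  # 0 740
print(pt["full"], pt["partial"])  # 64 420
```

In this toy setting, row-major order yields no full tiles at all, while the 2 × 2 patch order turns every diagonal tile into a full one, so fewer tiles would need the permutation.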

francois-rozet commented 1 day ago

I don't see where the permutation indices appear anymore after the BlockMask has been created with the permuted mask_mod. If both the sequence and the block mask are permuted, there is no permutation happening anymore. There is indexing with respect to column indices in the block mask, but no "permutation".

Chillee commented 1 day ago

Yes, that's what I mean. You must load from your column indices (which represent a permutation) in your inner loop.

francois-rozet commented 1 day ago

I think we are talking about the same thing in different terms, but I don't see how the column indices represent a permutation. They are ordered (which allows faster access than random indexing) and target a subset of all blocks. I agree that the subset is determined by the original permutation, but the indexing operation does not involve a permutation.
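A toy model of the column indices may help pin down what both sides are describing. It is loosely modeled on the `kv_num_blocks` / `kv_indices` fields of flex attention's `BlockMask`, whose real layout (padded tensors, a separate list for full blocks) differs; the helper names are hypothetical. The index lists are built once and reused, the inner loop reads one column index per visited tile, and within each query block those indices are sorted.

```python
def build_kv_indices(nonempty_tiles, num_q_blocks):
    """Per query block, the sorted kv-block ("column") indices of non-empty tiles."""
    kv_indices = [[] for _ in range(num_q_blocks)]
    for qb, kb in nonempty_tiles:
        kv_indices[qb].append(kb)
    for row in kv_indices:
        row.sort()
    kv_num_blocks = [len(row) for row in kv_indices]
    return kv_num_blocks, kv_indices


def visit_tiles(kv_num_blocks, kv_indices):
    """Toy inner loop: each iteration reads one column index, then processes that tile."""
    visited = []
    for qb, count in enumerate(kv_num_blocks):
        for t in range(count):
            kb = kv_indices[qb][t]  # the per-iteration index load under discussion
            visited.append((qb, kb))
    return visited


counts, indices = build_kv_indices({(0, 1), (0, 0), (1, 1), (2, 0)}, num_q_blocks=3)
print(counts, indices)  # [2, 1, 1] [[0, 1], [1], [0]]
```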

Anyway, your NATTEN + Hilbert curve seems much more efficient than NATTEN alone! Do you still have the code to generate the permutation? I used a random Python library previously.