pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

How to reason about efficiency of different score/mask mod functions #63

Open alex-hh opened 1 month ago

alex-hh commented 1 month ago

Hi,

The fact that it's possible to create arbitrary score mod / mask mod patterns is really powerful!

I'm wondering if there is any way to reason about the efficiency of different masking patterns (if this is a relevant consideration)?

For example, is a 'full' score_mod, e.g. one returning score + bias[b, h, i, j] where bias is an explicitly materialised attention bias tensor, going to yield any efficiency gains over manually adding the bias to the attention logits? And what are the relative efficiencies of, e.g., structured versus random sparsity patterns in mask_mod?

Thanks
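The two options in the question compute the same attention; only where the bias lives and how it is read differs. The pure-Python sketch below is a reference model, not the fused kernel. `flex_style_attention`, `attention_with_materialized_bias` and the tiny tensors are hypothetical. It assumes FlexAttention's `score_mod(score, b, h, q_idx, kv_idx)` convention, with the score already scaled by 1/sqrt(head_dim), and checks that a `score_mod` adding `bias[b][h][q_idx][kv_idx]` matches adding a materialised bias to the logits.

```python
import math


def softmax(xs):
    # Numerically stable softmax over a list of floats
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def weighted_sum(probs, v):
    # Combine value vectors v[j] using attention weights probs[j]
    return [sum(p * vj[d] for p, vj in zip(probs, v)) for d in range(len(v[0]))]


def flex_style_attention(q, k, v, score_mod, b=0, h=0):
    # Reference model for one (batch, head): score_mod is applied to each scaled logit
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_idx, qi in enumerate(q):
        scores = [score_mod(scale * dot(qi, kj), b, h, q_idx, kv_idx)
                  for kv_idx, kj in enumerate(k)]
        out.append(weighted_sum(softmax(scores), v))
    return out


def attention_with_materialized_bias(q, k, v, bias_bh):
    # "Manual" path: build the full S x S logit matrix, add the explicit bias, then softmax
    scale = 1.0 / math.sqrt(len(q[0]))
    logits = [[scale * dot(qi, kj) for kj in k] for qi in q]
    biased = [[logits[i][j] + bias_bh[i][j] for j in range(len(k))] for i in range(len(q))]
    return [weighted_sum(softmax(row), v) for row in biased]


# Tiny hypothetical example: B = H = 1, S = 3, head_dim = 2
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[0.5, 1.0], [1.0, -1.0], [0.0, 2.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
bias = [[[[0.1, -0.2, 0.3], [0.0, 0.5, -0.5], [1.0, 0.0, -1.0]]]]  # nested [B][H][S][S]


def full_bias_score_mod(score, b, h, q_idx, kv_idx):
    # With a real torch tensor this would be written bias[b, h, q_idx, kv_idx]
    return score + bias[b][h][q_idx][kv_idx]


out_score_mod = flex_style_attention(q, k, v, full_bias_score_mod)
out_manual = attention_with_materialized_bias(q, k, v, bias[0][0])
```

Because the outputs agree, any speed difference has to come from memory access rather than from the math.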

Chillee commented 4 weeks ago

@alex-hh Generally speaking, the less memory you have to access from outside the kernel, the better. So loading from a full bias (i.e. size S^2) is going to be slower than loading from a 1d bias (i.e. size S), which is going to be slower than loading from.
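To give a sense of scale, the arithmetic below is a rough model. It counts only the footprint of the bias tensor and ignores caching and repeated reads of small tensors. `bias_bytes` and `qkv_bytes` are hypothetical helpers, and fp16 with B=1, H=16, S=4096, head_dim=64 are assumed sizes. A bias computed from `q_idx` and `kv_idx` alone, such as an ALiBi-style linear penalty, is included as an extra case that reads no bias tensor. Under these assumptions a full [B, H, S, S] bias is 512 MiB, compared with 8 KiB for a 1d bias of length S and 24 MiB for Q, K and V combined.

```python
def bias_bytes(kind, batch, heads, seq_len, bytes_per_elem=2):
    """Footprint of the bias tensor a score_mod reads (rough model, fp16/bf16 by default)."""
    if kind == "full":  # bias[b, h, q_idx, kv_idx]: one value per (b, h, q, kv)
        return batch * heads * seq_len * seq_len * bytes_per_elem
    if kind == "1d":  # bias[kv_idx]: one value per key position, shared across b and h
        return seq_len * bytes_per_elem
    if kind == "computed":  # e.g. score - (q_idx - kv_idx): derived from indices, no tensor
        return 0
    raise ValueError(f"unknown bias kind: {kind}")


def qkv_bytes(batch, heads, seq_len, head_dim, bytes_per_elem=2):
    # Q, K and V together, for comparison
    return 3 * batch * heads * seq_len * head_dim * bytes_per_elem


MiB = 2 ** 20
for kind in ("full", "1d", "computed"):
    print(f"{kind:>8}: {bias_bytes(kind, 1, 16, 4096) / MiB:10.3f} MiB")
print(f"{'q+k+v':>8}: {qkv_bytes(1, 16, 4096, 64) / MiB:10.3f} MiB")
```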

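For sparsity, FlexAttention is fundamentally block-sparse: work is only skipped for whole tiles of the score matrix in which every entry is masked out. So pure random sparsity is unlikely to help much, since almost every tile will still contain some unmasked entries.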
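The block-sparsity point can be illustrated with a simplified model of how tiles are classified. A tile with no allowed entries can be skipped, a fully allowed tile needs no per-element masking, and anything else is partial. `classify_blocks` is a hypothetical helper. The sketch assumes FlexAttention's `mask_mod(b, h, q_idx, kv_idx) -> bool` convention and uses 16 × 16 tiles so it runs quickly (my understanding is that `create_block_mask` defaults to 128). A causal mask and a random mask of similar density (about 50%) are compared.

```python
import random


def classify_blocks(mask_mod, seq_len, block_size):
    """Count (q_block, kv_block) tiles that are empty, full, or partial.

    Assumes seq_len is a multiple of block_size.
    """
    n = seq_len // block_size
    counts = {"empty": 0, "full": 0, "partial": 0}
    for qb in range(n):
        for kb in range(n):
            allowed = sum(
                bool(mask_mod(0, 0, q_idx, kv_idx))
                for q_idx in range(qb * block_size, (qb + 1) * block_size)
                for kv_idx in range(kb * block_size, (kb + 1) * block_size)
            )
            if allowed == 0:
                counts["empty"] += 1  # can be skipped entirely
            elif allowed == block_size * block_size:
                counts["full"] += 1  # computed with no per-element masking
            else:
                counts["partial"] += 1  # computed, then masked element-wise
    return counts


SEQ_LEN, BLOCK = 256, 16


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# Random mask with roughly the same density (~50%) as the causal one
rng = random.Random(0)
keep = [[rng.random() < 0.5 for _ in range(SEQ_LEN)] for _ in range(SEQ_LEN)]


def random_mask(b, h, q_idx, kv_idx):
    return keep[q_idx][kv_idx]


causal_counts = classify_blocks(causal, SEQ_LEN, BLOCK)
random_counts = classify_blocks(random_mask, SEQ_LEN, BLOCK)
print("causal:", causal_counts)
print("random:", random_counts)
```

In this model the causal mask lets 120 of 256 tiles be skipped. The random mask leaves every tile partial, so nothing is skipped even though it allows about as many entries.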

alex-hh commented 4 weeks ago

Thanks for the reply! Got it re the memory.

Regarding block sparsity - does this mean that given a particular mask_mod pattern, there is potentially an optimal way of permuting the inputs before applying flex attention?

drisspg commented 3 weeks ago

Yes, indeed there is; see this thread for some discussion: https://github.com/pytorch-labs/attention-gym/issues/56
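Here is a rough illustration of the permutation idea (not taken from the linked thread). It assumes a hypothetical document-masking pattern in which tokens from two documents are interleaved. `nonempty_tiles`, `same_doc` and the layout are made up for the example. In the original order every tile contains some allowed pair. Sorting tokens by document id makes the mask block-diagonal and halves the number of tiles that must be computed. Outputs produced on the permuted sequence then have to be scattered back with the inverse permutation. Any index-dependent `mask_mod` or `score_mod` should also look positions up through `perm`, as `same_doc_permuted` does.

```python
def nonempty_tiles(allowed, seq_len, block_size):
    """Number of (q_block, kv_block) tiles containing at least one allowed entry."""
    n = seq_len // block_size
    return sum(
        any(
            allowed(q_idx, kv_idx)
            for q_idx in range(qb * block_size, (qb + 1) * block_size)
            for kv_idx in range(kb * block_size, (kb + 1) * block_size)
        )
        for qb in range(n)
        for kb in range(n)
    )


SEQ_LEN, BLOCK = 64, 8

# Hypothetical layout: tokens from two documents interleaved (0, 1, 0, 1, ...)
doc_id = [i % 2 for i in range(SEQ_LEN)]


def same_doc(q_idx, kv_idx):
    return doc_id[q_idx] == doc_id[kv_idx]


# Stable sort of positions by document id; perm[new_pos] = old_pos
perm = sorted(range(SEQ_LEN), key=lambda i: doc_id[i])
inv_perm = [0] * SEQ_LEN
for new_pos, old_pos in enumerate(perm):
    inv_perm[old_pos] = new_pos


def same_doc_permuted(q_idx, kv_idx):
    # Indices refer to the permuted sequence; map back to the original tokens
    return same_doc(perm[q_idx], perm[kv_idx])


before = nonempty_tiles(same_doc, SEQ_LEN, BLOCK)
after = nonempty_tiles(same_doc_permuted, SEQ_LEN, BLOCK)
print(f"non-empty tiles: {before} before permuting, {after} after")

# Results computed on the permuted sequence are restored with inv_perm
tokens = [f"tok{i}" for i in range(SEQ_LEN)]
permuted = [tokens[p] for p in perm]
restored = [permuted[inv_perm[i]] for i in range(SEQ_LEN)]
assert restored == tokens
```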