facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Splitting a BlockDiagonalCausalMask on the query dimension #995

Closed: thibautlavril closed this 5 months ago

thibautlavril commented 5 months ago

❓ Questions and Help

Hello!

I wanted to know whether it is possible to "split" a BlockDiagonalCausalMask (or to create the three equivalent AttentionBias objects) following this pattern:

[Screenshot (2024-03-12): the desired mask, split into three query blocks]

The goal is to split the queries into q1, q2, q3 while keeping the same keys/values. I had a look at BlockDiagonalCausalWithOffsetPaddedKeysMask, but I'm not sure it is made for this purpose.

Thank you very much!

danthe3rd commented 5 months ago

Hi, BlockDiagonalCausalWithOffsetPaddedKeysMask is made for inference (the keys are padded because they live in the KV cache). We don't have a nice way to support what you are asking at the moment. I see two ways:

thibautlavril commented 5 months ago

Thank you very much for your answer!

Yes, I think I am in case 1. I am probably missing something, but I don't know how to generate the sub-biases.

For instance, if the seqlens are [5, 4, 3] and I want to divide the attention into a 3x3 grid of tiles (for ring attention, say), I am not sure how to create a BlockDiagonalMask for tile (i, j).
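To make the tiling concrete, here is a small pure-Python sketch that materializes the dense boolean mask for tile (i, j) of a block-diagonal causal mask. The helper names `document_ids` and `tile_mask` are hypothetical (not part of xformers), and it assumes queries and keys share the same packing (self-attention) and that the packed length divides evenly into `num_chunks` chunks.

```python
def document_ids(seqlens: list[int]) -> list[int]:
    """Map each token position in the packed sequence to its document index."""
    return [doc for doc, n in enumerate(seqlens) for _ in range(n)]


def tile_mask(seqlens: list[int], num_chunks: int, i: int, j: int) -> list[list[bool]]:
    """Dense mask for tile (i, j): rows are query chunk i, columns are key chunk j.

    Assumption: queries and keys use the same packing, and the total length
    splits evenly into num_chunks chunks. True means the query may attend to the key.
    """
    docs = document_ids(seqlens)
    total = len(docs)
    assert total % num_chunks == 0, "packed length must split evenly into chunks"
    chunk = total // num_chunks
    q_range = range(i * chunk, (i + 1) * chunk)
    k_range = range(j * chunk, (j + 1) * chunk)
    # Block-diagonal causal rule, evaluated in global coordinates:
    # same document, and the key does not come after the query.
    return [[docs[q] == docs[k] and k <= q for k in k_range] for q in q_range]


if __name__ == "__main__":
    for row in tile_mask([5, 4, 3], num_chunks=3, i=1, j=0):
        print(["x" if allowed else "." for allowed in row])
```

With seqlens [5, 4, 3] and 3 chunks of 4 tokens, tile (1, 0) keeps only its first query row: queries 5 to 7 belong to the second document, and chunk 0 contains only keys from the first document, so those rows are fully masked.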

Do you have a pointer to an example somewhere?

Thank you again!

danthe3rd commented 5 months ago

So you need to create 9 attention biases for a 3x3 grid? If the use case is RingAttention, this might not be efficient: the work will be imbalanced between your ranks depending on the bias shape, so some GPUs will finish faster and then wait for the slowest one.
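A rough way to see the imbalance is to count the unmasked (query, key) pairs each rank is responsible for. This sketch assumes the usual ring-attention layout where rank r owns query chunk r and eventually sees every key chunk; the helper `attended_pairs_per_rank` is hypothetical. Pair counts are only a proxy: real kernels skip or compute whole blocks, so actual cost also depends on block granularity.

```python
def attended_pairs_per_rank(seqlens: list[int], num_ranks: int) -> list[int]:
    """Count unmasked (query, key) pairs per rank under a block-diagonal causal mask.

    Assumption: rank r owns query chunk r, and the packed length splits evenly.
    """
    docs = [doc for doc, n in enumerate(seqlens) for _ in range(n)]
    total = len(docs)
    assert total % num_ranks == 0, "packed length must split evenly across ranks"
    chunk = total // num_ranks
    counts = []
    for rank in range(num_ranks):
        queries = range(rank * chunk, (rank + 1) * chunk)
        # A query attends to keys in its own document at or before its position.
        counts.append(
            sum(1 for q in queries for k in range(total) if docs[q] == docs[k] and k <= q)
        )
    return counts


if __name__ == "__main__":
    print(attended_pairs_per_rank([12], 3))        # one long causal sequence
    print(attended_pairs_per_rank([5, 4, 3], 3))   # three packed documents
```

For a single causal sequence of 12 tokens over 3 ranks the counts are 10, 26 and 42, so the last rank does about four times the work of the first; with the packed seqlens [5, 4, 3] they happen to be much closer (10, 11, 10), which is why the imbalance depends on the bias shape.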

So you want to do something like this:

[Image: the attention matrix split into a 3x3 grid of tiles, labeled A to I]

For A, E and I it's just a regular BlockDiagonalCausalMask. The B, C and F regions are empty (assuming causality and the same size for the key and query sequences). For D, G and H it's a bit trickier, because some keys are not attended to by any query (which could cause NaNs in the backward pass), or some queries do not attend to any key (which could cause NaNs in the attention output). It might be easier to slice the keys/queries, as I don't know if we have a bias that would work in that case.
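One possible reading of "slice the keys/queries" is to break each tile into per-document pieces and keep only documents that appear in both the query chunk and the key chunk. The sketch below does that with a hypothetical `tile_slices` helper (not an xformers API), under the same assumptions as before: shared packing for queries and keys, and a length that splits evenly into chunks. Below the diagonal (i > j), every key in chunk j precedes every query in chunk i, so each per-document piece is full (non-causal) attention; on the diagonal it stays causal. Because a piece is only emitted when both sides are non-empty, no sliced query is left without a key, which avoids the empty-row case described above.

```python
from typing import Iterator, NamedTuple


class TileSlice(NamedTuple):
    """One per-document piece of tile (i, j), in chunk-local coordinates."""
    doc: int
    q_start: int
    q_end: int  # exclusive
    k_start: int
    k_end: int  # exclusive
    causal: bool


def doc_ranges(seqlens: list[int], start: int, end: int) -> Iterator[tuple[int, int, int]]:
    """Yield (doc, local_start, local_end) for each document overlapping [start, end)."""
    offset = 0
    for doc, n in enumerate(seqlens):
        lo, hi = max(offset, start), min(offset + n, end)
        if lo < hi:
            yield doc, lo - start, hi - start
        offset += n


def tile_slices(seqlens: list[int], num_chunks: int, i: int, j: int) -> list[TileSlice]:
    total = sum(seqlens)
    assert total % num_chunks == 0, "packed length must split evenly into chunks"
    chunk = total // num_chunks
    if j > i:
        return []  # strictly above the diagonal: fully masked under causality
    q_docs = {d: (s, e) for d, s, e in doc_ranges(seqlens, i * chunk, (i + 1) * chunk)}
    k_docs = {d: (s, e) for d, s, e in doc_ranges(seqlens, j * chunk, (j + 1) * chunk)}
    slices = []
    # Only documents present on both sides produce work; others are skipped entirely.
    for doc in sorted(q_docs.keys() & k_docs.keys()):
        (qs, qe), (ks, ke) = q_docs[doc], k_docs[doc]
        slices.append(TileSlice(doc, qs, qe, ks, ke, causal=(i == j)))
    return slices


if __name__ == "__main__":
    for i in range(3):
        for j in range(3):
            print((i, j), tile_slices([5, 4, 3], 3, i, j))
```

For seqlens [5, 4, 3], tile (2, 0) comes out empty, and tile (2, 1) reduces to a single non-causal piece (one query against three keys). Each piece could then be run as an ordinary attention call on the sliced tensors; whether an existing xformers bias can express the whole tile at once is left open here.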

thibautlavril commented 5 months ago

OK, got it. I'll try custom handling for the moment!

Out of curiosity, what approach will xformers take for masking in ring attention?

(Feel free to close the issue)

Thanks again for your answers!

danthe3rd commented 5 months ago

Out of curiosity, what approach will xformers take for masking in ring attention?

We don't plan to support document masking in the first version (not released yet), just regular (causal) attention.