pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention

Support varied input sequence lengths with a fixed block mask #31

Open · tilmto opened this issue 3 months ago

tilmto commented 3 months ago

Thanks for the great repo!

When using a custom-defined attention mask pattern (e.g., the A-shape mask in this work), I noticed that when the input length (e.g., 512) is shorter than the length the block mask was predefined with (e.g., 1024), the generation results may not be correct, even though the attention pattern of the former is just a truncated version of the latter.

So I wonder: does FlexAttention generally support varying input sequence lengths under a fixed block mask, and if so, how does it handle this situation?
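
For concreteness, here is a minimal sketch of the setup being described; the `a_shape` mask_mod, the sink/window sizes, and the tensor shapes are illustrative stand-ins, not taken from the referenced work:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Illustrative A-shape-style mask: causal attention restricted to a few
# "sink" tokens at the start plus a local sliding window (sizes assumed).
def a_shape(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    sink = kv_idx < 4
    local = (q_idx - kv_idx) < 256
    return causal & (sink | local)

# Block mask predefined for sequence length 1024 ...
block_mask = create_block_mask(a_shape, B=None, H=None, Q_LEN=1024, KV_LEN=1024)

# ... while the actual input is only 512 tokens long.
q, k, v = (torch.randn(1, 8, 512, 64, device="cuda") for _ in range(3))

# Reusing the 1024-length mask here is the situation that produced the
# incorrect generations described above, even though the 512-token
# pattern is a truncated version of the 1024-token one.
out = flex_attention(q, k, v, block_mask=block_mask)
```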

drisspg commented 3 months ago

In general, no. The current blessed solution is to call create_block_mask with the new shapes. It is also possible today to slice the BlockMask's inner tensors yourself; the structure is described here: https://github.com/pytorch/pytorch/blob/44dadf25065c73bd1370258e7fb1b421cee4283a/torch/nn/attention/flex_attention.py#L192
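
Continuing the sketch above (same illustrative `a_shape` mask_mod), the blessed path looks roughly like this: rebuild the BlockMask for the sequence length actually seen, rather than reusing the 1024-length one:

```python
# Rebuild the mask for the shapes this input actually has.
seq_len = q.shape[-2]  # 512 for the input above
block_mask = create_block_mask(a_shape, B=None, H=None,
                               Q_LEN=seq_len, KV_LEN=seq_len)
out = flex_attention(q, k, v, block_mask=block_mask)
```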

tilmto commented 3 months ago

Thanks for the prompt response! So, is my understanding correct: if we need to run evaluations on common LM benchmarks, which often contain questions of varying lengths, we should create the block mask on the fly for each input (ideally with _compile=True to speed this up)?

drisspg commented 3 months ago

Yup, that's the best approach. With _compile=True the cost should be relatively low compared to the actual compute, and it gets amortized over all attention calls throughout the model.
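
As a sketch of that amortization (`benchmark` and `model` are hypothetical placeholders): the mask is built once per input and then shared by every attention layer in the model.

```python
# Hypothetical eval loop: one BlockMask per input, reused by all layers.
for input_ids in benchmark:  # inputs of varying lengths
    seq_len = input_ids.shape[-1]
    block_mask = create_block_mask(
        a_shape, B=None, H=None,
        Q_LEN=seq_len, KV_LEN=seq_len,
        _compile=True,  # speeds up on-the-fly mask creation
    )
    # All attention layers reuse this one mask, so the per-input creation
    # cost is amortized over every attention call in the forward pass.
    logits = model(input_ids, block_mask=block_mask)
```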

tilmto commented 3 months ago

Got it! One last question: I find that setting _compile=True sometimes leads to errors about insufficient cache sizes. This often happens with models that use full attention in many layers, but when those layers are replaced with sliding-window attention, everything works fine. Are there any workarounds for this?

drisspg commented 3 months ago

Hmm, this is likely a dynamic shapes thing, @Chillee
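
Not confirmed in this thread as the fix, but two standard torch.compile knobs for recompilation-cache exhaustion under varying shapes are worth noting as a sketch; both are real torch._dynamo APIs, and whether either resolves this particular error is an assumption:

```python
import torch

input_ids = torch.randint(0, 32000, (1, 512))  # hypothetical batch

# Workaround 1: raise torch.compile's recompile cache limit so that each
# new sequence length does not exhaust it (the default limit is small).
torch._dynamo.config.cache_size_limit = 64

# Workaround 2: mark the sequence dimension as dynamic up front, so a
# single shape-generic graph is compiled instead of one specialization
# per length (dim 1 as the sequence dimension is an assumed layout).
torch._dynamo.mark_dynamic(input_ids, 1)
```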