Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Variable memory allocation with varlen kernels #1011

Open CodeCreator opened 2 months ago

CodeCreator commented 2 months ago

Hey!

I'm a big fan of the flash attention varlen kernels; they are fantastic for saving the memory and compute otherwise wasted on pad tokens.

When training with fixed batches of N tokens, I've noticed that memory usage varies substantially depending on cu_seqlens and max_seqlen. I suspect this is due to softmax_lse being allocated in a padded format ([num_seqs, num_heads, max_seqlen]) in https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp#L655, which reintroduces padding for shorter sequences.

I wonder how feasible it would be to store softmax_lse in an unpadded format ([num_tokens, num_heads]) for the varlen kernel (at least when storing activations for the backward pass).
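
To illustrate the difference, here is a rough sketch (not the actual kernel code; the sequence lengths and head count are made up) comparing the element counts of the two layouts for a packed batch with uneven sequence lengths:

```python
import torch

# Hypothetical example: 4 sequences packed into one varlen batch of 8192 tokens.
seqlens = torch.tensor([4096, 2048, 1024, 1024])
num_seqs = len(seqlens)
num_tokens = int(seqlens.sum())   # 8192
max_seqlen = int(seqlens.max())   # 4096
num_heads = 32

# Current padded layout: every sequence occupies max_seqlen slots.
lse_padded = torch.empty(num_seqs, num_heads, max_seqlen, dtype=torch.float32)

# Proposed unpadded layout: one entry per real token.
lse_unpadded = torch.empty(num_tokens, num_heads, dtype=torch.float32)

print(lse_padded.numel())    # 4 * 32 * 4096 = 524,288 elements
print(lse_unpadded.numel())  # 8192 * 32     = 262,144 elements
```

With four uniform 2048-token sequences the two layouts would be the same size, so the extra memory only appears when the lengths are skewed toward one long sequence, which matches the varying memory I'm seeing.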

Do you think this would achieve approximately constant memory use when training with batches of a fixed number of tokens? Thank you!

tridao commented 2 months ago

There's a PR for that, will be merged soon.