google-research / big_vision

Official codebase used to develop Vision Transformer, SigLIP, MLP-Mixer, LiT and more.
Apache License 2.0

Memory Efficient Attention integration #59

Open maximzubkov opened 8 months ago

maximzubkov commented 8 months ago

Hello, big_vision team!

Thanks for your work on the repository. Looking through the code, I noticed that ViT uses classical attention (see line 91 of the ViT implementation). It seems like it should be relatively easy to replace the current attention implementation with a memory-efficient alternative from flaxformer (line 595 in flaxformer) by passing dot_product_attention_multihead as attention_fn to nn.MultiHeadDotProductAttention (line 221 in flax). I think such an improvement is worth considering, since the FlashAttention authors reported up to a 2.4x speedup on long sequences (1k-4k tokens).

What do you think about it? Are there any limitations that make efficient attention integration harder than it seems? I'm not experienced in JAX, so your feedback would be much appreciated.

akolesnikoff commented 8 months ago

I have not run any deep investigations, but currently I expect that it is not worth it:

  1. Speed-wise, XLA is already supposed to do the optimizations that flash attention does. However, if you try it and observe that XLA does a worse job than a manual implementation of flash attention, we can have a look into it. The result may also differ between TPUs and GPUs.
  2. Memory-wise, a single attention matrix for 4k tokens weighs 64 MB (4096 x 4096 float32 entries). By wrapping the whole transformer block inside flax.linen.remat, one can guarantee that no more than 1 (or 2) attention matrices are materialized at any moment in time. So memory-wise it is only helpful when scaling beyond 4k tokens, which is not a typical vision scenario.

maximzubkov commented 8 months ago

Thank you for the detailed explanation! I'll give it a try in my experiments.