Open maximzubkov opened 8 months ago
I have not run any deep investigations, but currently I expect that it is not worth it: with `flax.linen.remat` one can guarantee that no more than 1 (or 2) attention matrices are materialized at any moment in time. So memory-wise it is only helpful when scaling beyond 4k tokens, which is not a typical vision scenario.

Thank you for the detailed explanation! I'll give it a try in my experiments.
Hello, big_vision team!

Thanks for your work on the repository. Looking through the code I noticed that ViT uses classical attention (see line 91 of the ViT implementation). It seems like it should be relatively easy to replace the current attention implementation with a memory-efficient alternative from flaxformer (line 595 in flaxformer), by passing `dot_product_attention_multihead` as `attention_fn` in `nn.MultiHeadDotProductAttention` (line 221 in flax). I think such an improvement is worth considering, since the Flash Attention authors reported up to a 2.4x speedup on long sequences (1k-4k tokens).

What do you think about it? Are there any limitations that make efficient attention integration harder than it seems? I'm not experienced in JAX, so your feedback would be very much appreciated.