lucidrains / flash-attention-jax

Implementation of Flash Attention in Jax
MIT License

support for per-head scales for cosine sim attention #6

Open GallagherCommaJack opened 2 years ago

GallagherCommaJack commented 2 years ago

usually with cosine-sim models I'd train with learned per-head scales for the attention logits. I guess I can get this by multiplying q & k by sqrt(scales) before the dot product, but that's probably less stable
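(For illustration, a minimal plain-JAX sketch of the workaround described above; the function name, shapes, and the downstream `flash_attention` call are assumptions rather than this repo's API, and any scaling or normalization the kernel applies internally would need to be accounted for.)

```python
import jax
import jax.numpy as jnp

def fold_head_scales(q, k, head_scales, eps=1e-8):
    # q, k: (batch, heads, seq, dim); head_scales: (heads,)
    # l2-normalize q and k (the cosine-sim part), then multiply each by
    # sqrt(scale) so a plain dot-product kernel ends up computing
    # scale_h * cos_sim(q, k) per head, with no per-head scale argument needed
    q = q / (jnp.linalg.norm(q, axis=-1, keepdims=True) + eps)
    k = k / (jnp.linalg.norm(k, axis=-1, keepdims=True) + eps)
    s = jnp.sqrt(head_scales)[None, :, None, None]
    return q * s, k * s

# illustrative usage; `flash_attention` stands in for whichever kernel is called
key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (1, 8, 128, 64))
k = jax.random.normal(key, (1, 8, 128, 64))
head_scales = jnp.full((8,), 10.0)
q_s, k_s = fold_head_scales(q, k, head_scales)
# out = flash_attention(q_s, k_s, v)
```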

lucidrains commented 2 years ago

@GallagherCommaJack try keeping it at a constant fixed scale of 10

it worked well for me, as well as for Boris (for Craiyon). In fact, it is his best run
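(As a reference point, a minimal non-flash sketch of cosine-sim attention with a fixed logit scale of 10; the function name and shapes are assumptions, not this repo's API.)

```python
import jax
import jax.numpy as jnp

def cosine_sim_attention_ref(q, k, v, scale=10.0, eps=1e-8):
    # q, k, v: (batch, heads, seq, dim)
    # l2-normalize q and k, then apply one fixed scale to all attention logits
    q = q / (jnp.linalg.norm(q, axis=-1, keepdims=True) + eps)
    k = k / (jnp.linalg.norm(k, axis=-1, keepdims=True) + eps)
    logits = scale * jnp.einsum('b h i d, b h j d -> b h i j', q, k)
    attn = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum('b h i j, b h j d -> b h i d', attn, v)
```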

GallagherCommaJack commented 2 years ago

I am trying to use this with models I've already spent a decent amount of compute training; it would be a lot more work to retrain from scratch

GallagherCommaJack commented 2 years ago

I could, of course, tune with a constant scale, but that seems like a worse option than relying on XLA to fuse here, since the non-cosine-sim version should be drop-in compatible.

lucidrains commented 2 years ago

@GallagherCommaJack ahh, i don't know if i can support that, i'm going all-in on a fixed scale of 10: https://wandb.ai/dalle-mini/dalle-mini/reports/Fix-Swin-v2--VmlldzoyNDA4Mzc3 (blue curve)

GallagherCommaJack commented 2 years ago

hmm really? the scales are just a pointwise op between the dot product and the logits in a normal implementation. why does flash attention make that harder?
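(For illustration, a minimal non-flash sketch of what this comment describes; names and shapes are assumptions. The per-head scale is a single broadcasted multiply sitting between the similarity matrix and the softmax.)

```python
import jax
import jax.numpy as jnp

def cosine_sim_attention_per_head(q, k, v, head_scales, eps=1e-8):
    # identical to a fixed-scale version except that the scale is a learned
    # (heads,) vector, applied as one broadcasted multiply between the
    # dot product and the softmax
    q = q / (jnp.linalg.norm(q, axis=-1, keepdims=True) + eps)
    k = k / (jnp.linalg.norm(k, axis=-1, keepdims=True) + eps)
    sim = jnp.einsum('b h i d, b h j d -> b h i j', q, k)
    logits = head_scales[None, :, None, None] * sim  # the pointwise op in question
    attn = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum('b h i j, b h j d -> b h i d', attn, v)
```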

lucidrains commented 2 years ago

@GallagherCommaJack it isn't difficult, it is just unnecessary

you can always fork it and add it yourself, if you need it for your specific pretrained network

cosine sim attention isn't even faithful to the original flash attention. it is kind of a new direction i'm taking it in: https://github.com/lucidrains/flash-cosine-sim-attention