GallagherCommaJack opened this issue 2 years ago
@GallagherCommaJack try keeping it at a constant fixed scale of 10
it worked well for me as well as for Boris (for Craiyon). In fact, it is his best run
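(For illustration, a minimal sketch of cosine-sim attention with the fixed scale being recommended here — the function name and shapes are hypothetical, not the library's API:)

```python
import torch
import torch.nn.functional as F

def cosine_sim_attention(q, k, v, scale = 10.):
    # q, k, v: (batch, heads, seq, dim)
    # l2-normalize queries and keys so their dot products are cosine similarities in [-1, 1]
    q = F.normalize(q, dim = -1)
    k = F.normalize(k, dim = -1)
    # a fixed scale (10 here, per the recommendation above) replaces the usual 1 / sqrt(dim)
    logits = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    attn = logits.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
```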
I am trying to use this with models I've already spent a decent amount of compute training; it would be a lot more work to retrain from scratch.
I could of course tune with a constant scale, but that seems like a worse option than relying on XLA to fuse here, since the non-cosine-sim version should be drop-in compatible.
@GallagherCommaJack ahh, i don't know if i can support that, i'm going all-in on a fixed scale of 10 https://wandb.ai/dalle-mini/dalle-mini/reports/Fix-Swin-v2--VmlldzoyNDA4Mzc3 (blue curve)
hmm really? in a normal implementation the scales are just a pointwise op on the logits, between the dot product and the softmax. why does flash attention make that harder?
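(A sketch of the point being made — in a plain implementation, a learned per-head scale is a single pointwise multiply; the names here are illustrative, not the library's API:)

```python
import torch

def attention_logits_with_scales(q, k, per_head_scales):
    # q, k: (batch, heads, seq, dim); per_head_scales: (heads,)
    logits = torch.einsum('b h i d, b h j d -> b h i j', q, k)
    # the learned scale is just a pointwise multiply on the logits,
    # applied between the dot product and the softmax
    return logits * per_head_scales.view(1, -1, 1, 1)
```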
@GallagherCommaJack it isn't difficult, it is just unnecessary
you can always fork it and add it yourself, if you need it for your specific pretrained network
cosine sim attention isn't even faithful to the original flash attention. it is kind of a new direction i'm taking it in https://github.com/lucidrains/flash-cosine-sim-attention
usually with cosine-sim models I'd train with learned per-head scales for the attention logits. I guess I can get this by multiplying `q` and `k` by `sqrt(scales)` before the dot product, but that's probably less stable
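(A sketch of that workaround, under the assumption that the scales are positive and that `q` and `k` are already l2-normalized — all names here are hypothetical; the caveat above is that scaling `q` and `k` directly can be less numerically stable than scaling the logits:)

```python
import torch

def fold_scales_into_qk(q, k, per_head_scales):
    # q, k: (batch, heads, seq, dim), assumed already l2-normalized
    # per_head_scales: (heads,), assumed positive
    root = per_head_scales.sqrt().view(1, -1, 1, 1)
    # (sqrt(s) * q) . (sqrt(s) * k) == s * (q . k), so the per-head
    # scale ends up baked into the dot product itself
    return q * root, k * root
```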