Dao-AILab / flash-attention

Fast and memory-efficient exact attention

Logit soft-capping #1016

Open kabachuha opened 2 days ago

kabachuha commented 2 days ago

As you probably know, Google released Gemma 2 yesterday, with superior performance and robustness: https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf

One of the key changes is logit soft-capping. However, it is not implemented in flash-attention, so implementations that rely on flash-attention have to fall back to the slower native PyTorch variant.

I think it would be useful for everyone to include an option for logit soft-capping in this library.

A native PyTorch implementation is provided along with the new Gemma 2 model architecture at https://github.com/google/gemma_pytorch/blob/main/gemma/model.py
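
For context, a minimal sketch of what soft-capping does in a plain PyTorch attention, assuming the tanh-based capping described in the Gemma 2 report (the function name and the cap value here are illustrative; see the linked gemma_pytorch model for the actual implementation):

```python
import torch
import torch.nn.functional as F

def attention_with_softcap(q, k, v, softcap: float = 50.0):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    # Logit soft-capping: squash the attention logits into (-softcap, softcap)
    # with a tanh before the softmax.
    scores = softcap * torch.tanh(scores / softcap)
    probs = F.softmax(scores, dim=-1)
    return torch.matmul(probs, v)
```

Supporting this inside the fused kernel would just mean applying the tanh capping to the scores before the softmax step, so no extra memory traffic should be needed.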

tridao commented 2 days ago

Right, we just need someone to implement it :D

lucidrains commented 1 day ago

@kabachuha @tridao i can take a look at this next week