As you probably know, Google released Gemma 2 yesterday, with superior performance and robustness: https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf
One of the key changes is logit soft-capping. However, it is not implemented in flash-attention, so implementations built on flash-attention have to fall back to the slower native PyTorch variant.
I think it would be useful to everyone if this library offered an option to apply logit soft-capping.
The native PyTorch implementation is provided along with the new Gemma 2 model architecture at https://github.com/google/gemma_pytorch/blob/main/gemma/model.py
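For context, soft-capping squashes the pre-softmax attention scores into the range (-cap, cap) by computing `cap * tanh(scores / cap)`. Below is a minimal sketch of the native variant being discussed; the function names are mine (not from either library), and the cap value of 50.0 is the attention soft-capping constant from the Gemma 2 report:

```python
import torch
import torch.nn.functional as F

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Gemma 2-style soft-capping: smoothly bound logits to (-cap, cap)
    # while keeping the operation differentiable.
    return cap * torch.tanh(logits / cap)

def naive_attention_with_softcap(q, k, v, cap: float = 50.0):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scale = q.size(-1) ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    scores = soft_cap(scores, cap)   # applied before the softmax
    probs = F.softmax(scores, dim=-1)
    return torch.matmul(probs, v)
```

Since the cap is an elementwise transform applied before the softmax, supporting it inside a fused attention kernel would presumably amount to inserting the tanh rescaling ahead of the online softmax step.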