JonasGeiping / cramming

Cramming the training of a (BERT-type) language model into limited compute.
MIT License

Flash Attention #48

Closed: versae closed this issue 3 months ago

versae commented 4 months ago

Hi, great project!

Are there any plans to implement/support Flash Attention 1, 2, or 3, or PyTorch's SDPA?

Cheers.

JonasGeiping commented 4 months ago

Attention backends are selected here: https://github.com/JonasGeiping/cramming/blob/196c9912d8c5b06a05e9a58edd1521e3d38f7c0c/cramming/architectures/attention.py#L18. PyTorch SDPA is already an option, although the gains from Flash Attention are small at the default sequence length of 128. Happy to accept PRs for more backends.
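
For reference, here is a minimal sketch (not the cramming backend code itself) of how PyTorch's SDPA can be steered toward its Flash Attention kernel via the `torch.backends.cuda.sdp_kernel` context manager (PyTorch >= 2.0); the shapes, fp16 dtype, and CUDA device below are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: 128 matches cramming's default sequence length.
batch, heads, seq_len, head_dim = 8, 12, 128, 64
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA dispatch to the Flash Attention kernel only; PyTorch raises an
# error if that kernel is unavailable for these inputs instead of falling back.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
```

Flash Attention's main saving is that it never materializes the full seq_len x seq_len attention matrix, so its advantage grows with sequence length, which is why the benefit is modest at length 128.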