Modalities / modalities

Modalities, a PyTorch-native framework for distributed and reproducible foundation model training.
MIT License

Disable Flash Attention for inference #162

Open rrutmann opened 3 months ago

rrutmann commented 3 months ago

Flash Attention only supports fp16 and bf16, not fp32. Therefore, we should make Flash Attention optional in our codebase so that it can be disabled during inference in exchange for higher precision.
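For context, the fp32-capable fallback is just the textbook attention formula softmax(QKᵀ/√d)·V, which has no dtype restriction. The following is a minimal pure-Python sketch for illustration only. It is not the Modalities implementation, the function name `manual_attention` is hypothetical, and it uses Python's double-precision floats instead of tensors:

```python
import math
from typing import List

Matrix = List[List[float]]


def manual_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Reference scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

    Hypothetical helper. Plain Python floats are double precision,
    so no fp16/bf16 restriction applies here.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out: Matrix = []
    for q_row in q:
        # Scaled dot product of this query with every key.
        scores = [scale * sum(qi * ki for qi, ki in zip(q_row, k_row)) for k_row in k]
        # Numerically stable softmax: subtract the max score before exponentiating.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Output row is the attention-weighted sum of the value vectors.
        out.append(
            [sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))]
        )
    return out
```

With two identical keys, the query attends to both equally, so the output is the mean of the two value vectors.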

flxst commented 3 months ago

Optional flash attention has been implemented in #141 (and its sub-PR #138). Once the changes are merged into main, you will be able to choose between different attention mechanisms (manual, PyTorch flash, Dao flash) in the training config. In principle, it should then also be possible to simply choose a different attention mechanism for inference, but I don't think this has been explicitly tested yet.
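Switching the attention mechanism for inference could look roughly like the sketch below. It assumes a plain dict config. The key `attention_implementation`, its values `manual` / `pytorch_flash` / `dao_flash`, and the helper `inference_attention_config` are all hypothetical stand-ins for the three options mentioned in this thread, not the actual Modalities config schema:

```python
from typing import Any, Dict

# Hypothetical identifiers for the three attention mechanisms discussed in this thread.
FLASH_IMPLS = {"pytorch_flash", "dao_flash"}
VALID_IMPLS = FLASH_IMPLS | {"manual"}


def inference_attention_config(train_cfg: Dict[str, Any], dtype: str) -> Dict[str, Any]:
    """Return a copy of the (hypothetical) training config adapted for inference in `dtype`."""
    impl = train_cfg["attention_implementation"]
    if impl not in VALID_IMPLS:
        raise ValueError(f"unknown attention implementation: {impl!r}")
    cfg = dict(train_cfg)  # shallow copy so the training config is left untouched
    # Flash kernels only accept fp16/bf16, so fall back to manual attention for fp32.
    if dtype == "fp32" and impl in FLASH_IMPLS:
        cfg["attention_implementation"] = "manual"
    cfg["dtype"] = dtype
    return cfg


train_cfg = {"attention_implementation": "dao_flash", "dtype": "bf16"}
print(inference_attention_config(train_cfg, "fp32"))
# {'attention_implementation': 'manual', 'dtype': 'fp32'}
```

Keeping the override in a separate inference step leaves the training config unchanged, so flash attention stays enabled for bf16 training while fp32 inference uses the manual path.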