OpenNMT / OpenNMT-py

Open Source Neural Machine Translation and (Large) Language Models in PyTorch
https://opennmt.net/
MIT License

fixed masked flash attention #2589

Closed l-k-11235 closed 1 week ago

l-k-11235 commented 1 month ago

This PR proposes a fix for the flash attention path in the multi-head attention module. The flash attention block does not support a left-padding mask, so we apply the mask upstream, before that block is called. See https://github.com/Dao-AILab/flash-attention/issues/649

We apply the key_pad_mask to the values stored in the KV cache, at the first step only, for all scenarios (standard, flash, and sdpa mechanisms).
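
To illustrate what applying the mask to the cached values can look like, here is a minimal PyTorch sketch. It is not the code from this PR: the helper name `mask_cached_values`, the padding id `PAD_ID`, and the tensor shapes are assumptions chosen for the example.

```python
import torch

PAD_ID = 0  # hypothetical padding token id, chosen for this example


def mask_cached_values(value, key_pad_mask):
    """Zero the value vectors at padded key positions.

    value:        (batch, heads, key_len, dim_per_head)
    key_pad_mask: (batch, 1, key_len), True where the key is padding
    """
    # (batch, 1, key_len) -> (batch, 1, key_len, 1), so the mask
    # broadcasts over the heads and over the per-head dimension.
    pad = key_pad_mask.unsqueeze(-1)
    return value.masked_fill(pad, 0.0)


# Two left-padded sequences of length 4.
tokens = torch.tensor([[0, 0, 5, 6],
                       [0, 7, 8, 9]])
key_pad_mask = tokens.eq(PAD_ID).unsqueeze(1)  # (2, 1, 4), boolean

# Stand-in for the values written to the cache at the first step.
value = torch.ones(2, 2, 4, 3)  # (batch, heads, key_len, dim_per_head)
masked_value = mask_cached_values(value, key_pad_mask)
```

In this sketch the mask is applied once, when the prompt fills the cache, so the same masked values would be seen by whichever attention path (standard, flash, or sdpa) reads the cache afterwards.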

vince62s commented 1 month ago

I think it would be great to explain in this PR when and how the key_pad_mask needs to be used, and to be clear about the different scenarios (standard, flash, sdpa mechanisms).