Closed — le1nux closed this 3 months ago
According to the documentation, this is an implementation equivalent to FlashAttention-2, but shipped directly with PyTorch as a beta feature. We even saw a 10% throughput increase with the new PyTorch implementation on the DGX boxes.
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
cc @fromm-m correct me if I'm mistaken here... :-)
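For reference, a minimal sketch of how the built-in can be called and pinned to the flash backend (the tensor shapes and the use of the `torch.backends.cuda.sdp_kernel` context manager are my assumptions about the intended usage, not code from this repo):

```python
import torch
import torch.nn.functional as F

# scaled_dot_product_attention expects (batch, n_heads, seq_len, head_dim)
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Restrict dispatch to the FlashAttention kernel; without this context manager,
# PyTorch may silently fall back to the math or memory-efficient backends.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```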
Ok, but note that as part of #74, we did the opposite change and went from `scaled_dot_product_attention` to `flash_attn`, see https://github.com/Modalities/modalities/pull/74/commits/7d27b59d9e71ecdba40ea8a72521ad38d361aea7. I wonder what happened to the GQA part, in particular the previously used `repeat_kv` function, which took care of the mapping between query heads and key / value heads. Isn't that part missing now? The tests in `tests/models/test_causal_self_attention.py` fail, and I think this might be because PyTorch's `scaled_dot_product_attention` does not internally implement GQA, unlike `flash_attn` (see here). Or am I missing something?
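As context, a typical `repeat_kv` helper looks roughly like the sketch below: each key/value head is expanded so it can be shared by a group of query heads before a non-GQA attention kernel is called. This is my reconstruction of the common pattern, not the exact function that was removed:

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times to match the number of query heads.

    x: (batch, n_kv_heads, seq_len, head_dim)
    returns: (batch, n_kv_heads * n_rep, seq_len, head_dim)
    """
    if n_rep == 1:
        return x
    batch, n_kv_heads, seq_len, head_dim = x.shape
    # Insert a repeat dimension next to the kv-head dimension, then flatten it.
    x = x[:, :, None, :, :].expand(batch, n_kv_heads, n_rep, seq_len, head_dim)
    return x.reshape(batch, n_kv_heads * n_rep, seq_len, head_dim)
```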
You are right, flash-attention in PyTorch still does not support GQA. With the newest changes we are now back to Tri Dao's `flash_attn` for now.
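For illustration, `flash_attn_func` handles GQA directly by accepting fewer key/value heads than query heads, so no `repeat_kv` step is needed (the shapes below are illustrative; the actual call site in the repo may differ):

```python
import torch
from flash_attn import flash_attn_func

# flash_attn expects (batch, seq_len, n_heads, head_dim); for GQA/MQA the number
# of query heads just has to be a multiple of the number of key/value heads.
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)  # 8 query heads
k = torch.randn(2, 1024, 2, 64, device="cuda", dtype=torch.float16)  # 2 kv heads
v = torch.randn(2, 1024, 2, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)
```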
We experienced loss convergence to almost 0 while inference performance was rather low. We figured out that there was an information-leakage bug due to view and reshape operations that led to parts of the embeddings of different tokens being mixed.
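Just to illustrate the class of bug described here (a generic example, not the exact code from the repo): when the attention output has shape `(batch, n_heads, seq_len, head_dim)`, reshaping it directly to `(batch, seq_len, n_heads * head_dim)` yields the right shape but mixes values across heads and tokens; the head dimension has to be moved next to `head_dim` first.

```python
import torch

batch, n_heads, seq_len, head_dim = 2, 4, 8, 16
out = torch.randn(batch, n_heads, seq_len, head_dim)

# Incorrect: the shape matches, but each row now mixes data from different tokens/heads.
wrong = out.reshape(batch, seq_len, n_heads * head_dim)

# Correct: transpose heads next to head_dim before flattening.
right = out.transpose(1, 2).contiguous().view(batch, seq_len, n_heads * head_dim)

print(torch.allclose(wrong, right))  # False: the naive reshape scrambles token embeddings
```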