Modalities / modalities

A framework for training multimodal foundation models.
MIT License

fix: fixed tensor reshape operations #115

Closed · le1nux closed this 3 months ago

le1nux commented 3 months ago

We experienced loss convergence to almost 0 while inference performance remained rather low. We figured out that there was an information leakage bug caused by view and reshape operations, which led to parts of the embeddings of different tokens being mixed.
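A minimal sketch of this failure mode (illustrative shapes, not the exact code from this PR): reshaping an attention output of shape (B, n_heads, T, head_dim) straight to (B, T, C) without first transposing the head and token dimensions silently mixes chunks of different tokens into the same output row.

```python
import torch

B, T, n_heads, head_dim = 2, 4, 3, 5
C = n_heads * head_dim

# Attention output in (B, n_heads, T, head_dim) layout, as returned by most attention kernels.
y = torch.arange(B * n_heads * T * head_dim, dtype=torch.float32).reshape(B, n_heads, T, head_dim)

# Correct: bring the token dimension next to the batch dimension before flattening,
# so each row of the result contains all heads of exactly one token.
correct = y.transpose(1, 2).contiguous().view(B, T, C)

# Buggy: this reshape does not raise an error, but each output row now contains
# chunks that belong to different tokens, i.e. information leaks across positions.
buggy = y.reshape(B, T, C)

print(torch.equal(correct, buggy))  # False
```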

le1nux commented 3 months ago

According to the documentation, this is an implementation equivalent to FlashAttention-2, shipped directly with PyTorch as a beta feature. We even saw a 10% throughput increase with the new PyTorch implementation on the DGX boxes.

https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

cc @fromm-m correct me if I'm mistaken here... :-)
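For reference, a minimal usage sketch of the API linked above (arbitrary illustrative shapes); PyTorch selects the fastest available backend, including the FlashAttention-2 kernel, automatically:

```python
import torch
import torch.nn.functional as F

B, n_heads, T, head_dim = 2, 8, 16, 64
q, k, v = (torch.randn(B, n_heads, T, head_dim) for _ in range(3))

# Causal self-attention; the backend (FlashAttention-2, memory-efficient, or math fallback)
# is chosen automatically based on hardware, dtype, and input shapes.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```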

flxst commented 3 months ago

> According to the documentation, this is an implementation equivalent to FlashAttention-2, shipped directly with PyTorch as a beta feature. We even saw a 10% throughput increase with the new PyTorch implementation on the DGX boxes.
>
> https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
>
> cc @fromm-m correct me if I'm mistaken here... :-)

Ok, but note that as part of #74, we made the opposite change and went from scaled_dot_product_attention to flash_attn, see https://github.com/Modalities/modalities/pull/74/commits/7d27b59d9e71ecdba40ea8a72521ad38d361aea7. I wonder what happened to the GQA part, in particular the previously used repeat_kv function, which took care of the mapping between query heads and key/value heads. Isn't that part missing now? The tests in tests/models/test_causal_self_attention.py fail, and I think this might be because PyTorch's scaled_dot_product_attention does not internally implement GQA, unlike flash_attn (see here). Or am I missing something?
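For context, a sketch of what such a repeat_kv helper typically looks like (the signature previously used in this repo may differ): it expands the key/value heads so that plain multi-head attention kernels can consume GQA-shaped inputs.

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times so the layout matches the query heads.

    (B, n_kv_heads, T, head_dim) -> (B, n_kv_heads * n_rep, T, head_dim)
    """
    if n_rep == 1:
        return x
    B, n_kv_heads, T, head_dim = x.shape
    x = x[:, :, None, :, :].expand(B, n_kv_heads, n_rep, T, head_dim)
    return x.reshape(B, n_kv_heads * n_rep, T, head_dim)

# Example: 2 key/value heads shared by 8 query heads -> n_rep = 4
k = torch.randn(2, 2, 16, 64)
print(repeat_kv(k, 4).shape)  # torch.Size([2, 8, 16, 64])
```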

fromm-m commented 3 months ago

> According to the documentation, this is an implementation equivalent to FlashAttention-2, shipped directly with PyTorch as a beta feature. We even saw a 10% throughput increase with the new PyTorch implementation on the DGX boxes.
>
> https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
>
> cc @fromm-m correct me if I'm mistaken here... :-)

> Ok, but note that as part of #74, we made the opposite change and went from scaled_dot_product_attention to flash_attn, see 7d27b59. I wonder what happened to the GQA part, in particular the previously used repeat_kv function, which took care of the mapping between query heads and key/value heads. Isn't that part missing now? The tests in tests/models/test_causal_self_attention.py fail, and I think this might be because PyTorch's scaled_dot_product_attention does not internally implement GQA, unlike flash_attn (see here). Or am I missing something?

You are right, flash attention in PyTorch still does not support GQA. With the newest changes, we are now back to flash_attn from Dao for now.
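For completeness, a sketch of how flash_attn handles GQA natively (requires a CUDA device and fp16/bf16 inputs; shapes are illustrative): the kernel accepts fewer key/value heads than query heads as long as the head counts divide evenly, so no repeat_kv expansion is needed.

```python
import torch
from flash_attn import flash_attn_func

B, T, n_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64

# flash_attn_func expects (B, T, n_heads, head_dim); GQA is handled internally
# when n_kv_heads divides n_heads.
q = torch.randn(B, T, n_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, T, n_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, n_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)

out = flash_attn_func(q, k, v, causal=True)  # (B, T, n_heads, head_dim)
```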