epfLLM / Megatron-LLM

distributed trainer for LLMs

Correctness when enabling FlashAttention + Sequence Parallel at the same time? #99

Closed xingyaoww closed 4 months ago

xingyaoww commented 4 months ago

It just occurred to me recently that the current implementation of FlashAttention does not account for the case of sequence parallelism.

For example, here self.core_attention_flash is called (which is essentially flash_attn.flash_attn_func); however, the q, k, v passed into flash attention are not all-gathered, which means we only compute attention within a chunk of the sequence (say the total seq_len is L and the sequence-parallel degree is 4, then attention is only computed within each L/4 slice). This would cause issues in the trained models (i.e., the trained model stops attending to previous content after L/4 tokens).
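To illustrate the concern, here is a minimal sketch (not the repository's code) of what calling flash attention on sequence-sharded tensors would look like; it assumes the flash-attn 2.x interface flash_attn_func(q, k, v, causal=...) with (batch, seqlen, nheads, headdim) tensors, and the shapes/names are hypothetical:

```python
# Minimal sketch of the concern, not the repository's code.
# Assumes flash-attn 2.x: flash_attn_func(q, k, v, causal=...) with
# tensors of shape (batch, seqlen, nheads, headdim) in fp16 on CUDA.
import torch
from flash_attn import flash_attn_func

batch, seq_len, n_heads, head_dim = 1, 4096, 32, 128
sp_degree = 4                     # hypothetical sequence-parallel degree
local_len = seq_len // sp_degree  # each rank would hold only L/4 tokens

# q, k, v as they would look if they were still sharded along the sequence dim
q_local = torch.randn(batch, local_len, n_heads, head_dim,
                      dtype=torch.float16, device="cuda")
k_local = torch.randn_like(q_local)
v_local = torch.randn_like(q_local)

# If flash attention were really called on sequence-sharded q, k, v like this,
# every token could only attend within its own L/4 slice.
out_local = flash_attn_func(q_local, k_local, v_local, causal=True)
print(out_local.shape)  # (1, local_len, 32, 128): attention restricted to the shard
```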

Their original implementation of ParallelAttention does not have this issue, since they perform an all-gather in the forward pass and a reduce-scatter in the backward pass; see this issue for details.
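For reference, a rough sketch of that forward all-gather / backward reduce-scatter pattern (illustrative only, not this repository's or Megatron-LM's actual code; it assumes torch.distributed is already initialized and Megatron's [s, b, h] layout, so the gather is along dim 0):

```python
# Illustrative sketch of the sequence-parallel communication pattern:
# all-gather the activation in forward, reduce-scatter the gradient in backward.
import torch
import torch.distributed as dist


class GatherAlongSequenceDim(torch.autograd.Function):
    """Gather sequence-parallel shards into the full sequence on every rank."""

    @staticmethod
    def forward(ctx, local_chunk):
        world_size = dist.get_world_size()
        gathered = [torch.empty_like(local_chunk) for _ in range(world_size)]
        dist.all_gather(gathered, local_chunk.contiguous())
        return torch.cat(gathered, dim=0)  # full sequence [s, b, h]

    @staticmethod
    def backward(ctx, grad_full):
        world_size = dist.get_world_size()
        grad_chunks = [g.contiguous() for g in grad_full.chunk(world_size, dim=0)]
        grad_local = torch.empty_like(grad_chunks[dist.get_rank()])
        # Sum the gradient contributions from all ranks and keep the local slice.
        dist.reduce_scatter(grad_local, grad_chunks)
        return grad_local
```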

martinjaggi commented 4 months ago

We haven't modified this aspect compared to NVIDIA Megatron. Are you saying that in NVIDIA Megatron this has been changed since we forked the repository (around April 2023)?

xingyaoww commented 4 months ago

Yes, this is the same implementation as Megatron-LM. Never mind, I just dug in deeper with some interactive debugging and found that the all-gather operation happens implicitly in self.query_key_value, which is a ColumnParallelLinear that takes care of the all-gather. The current implementation should be fine. :)
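For anyone else tracing this: a simplified, forward-only sketch (hypothetical names, not the library's actual ColumnParallelLinear) of why the q, k, v reaching flash attention already span the full sequence when sequence parallelism is on; the projection gathers its sequence-sharded input before the matmul (the matching backward would be the reduce-scatter sketched above):

```python
# Forward-only sketch: a column-parallel QKV projection that all-gathers its
# sequence-sharded input before projecting, so q, k, v cover all tokens.
import torch
import torch.distributed as dist
import torch.nn.functional as F


class SketchColumnParallelQKV(torch.nn.Module):
    def __init__(self, hidden_size, qkv_size_per_partition):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.randn(qkv_size_per_partition, hidden_size) * 0.02)

    def forward(self, hidden_local):  # hidden_local: [s/tp, b, h]
        world_size = dist.get_world_size()
        chunks = [torch.empty_like(hidden_local) for _ in range(world_size)]
        # The implicit all-gather along the sequence dimension: afterwards
        # every rank holds activations for the full sequence [s, b, h].
        dist.all_gather(chunks, hidden_local.contiguous())
        hidden_full = torch.cat(chunks, dim=0)
        # The QKV projection therefore produces q, k, v spanning all tokens,
        # so the per-chunk attention worry above does not apply.
        return F.linear(hidden_full, self.weight)
```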