We haven't modified this aspect compared to NVIDIA Megatron. Are you saying that in NVIDIA Megatron this has been changed since we forked the repository (around April 2023)?
Yes, this is the same implementation as Megatron-LM. Never mind, I just dug deeper with some interactive debugging and found out that the all-gather operation happens implicitly in `self.query_key_value`, which is a `ColumnParallelLinear` that takes care of the all-gather. The current implementation should be fine. :)
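To make the point in the reply above concrete, here is a minimal single-process sketch in plain PyTorch (not the actual Megatron-LM `ColumnParallelLinear`) of where the implicit all-gather lives: with sequence parallelism each rank holds an `L/p` slice of the activations, and the column-parallel linear all-gathers along the sequence dimension before its GEMM, so the attention module downstream already sees the full sequence length. The class name and layout below are illustrative assumptions only, and only the forward-pass gather is shown.

```python
import torch
import torch.distributed as dist


class SequenceParallelColumnLinear(torch.nn.Module):
    """Toy column-parallel linear that gathers sequence-parallel input first."""

    def __init__(self, in_features, out_features_per_rank, group=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features_per_rank, in_features))
        torch.nn.init.normal_(self.weight, std=0.02)
        self.group = group

    def forward(self, x_local):
        # x_local: [seq_len / p, batch, hidden] -- the rank-local sequence slice.
        world_size = dist.get_world_size(self.group) if dist.is_initialized() else 1
        if world_size > 1:
            # All-gather along the sequence dimension so every rank sees the full
            # sequence before the matmul -- the step that is "implicit" inside
            # self.query_key_value. (In Megatron-LM the matching backward pass
            # reduce-scatters the gradient back to the sequence-parallel layout.)
            chunks = [torch.empty_like(x_local) for _ in range(world_size)]
            dist.all_gather(chunks, x_local, group=self.group)
            x_full = torch.cat(chunks, dim=0)  # [seq_len, batch, hidden]
        else:
            x_full = x_local
        return torch.nn.functional.linear(x_full, self.weight)
```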
It just occurred to me recently that the current implementation of FlashAttention does not account for the case of sequence parallelism.
For example, here `self.core_attention_flash` is called (which is essentially `flash_attn.flash_attn_func`); however, the `q`, `k`, `v` passed into flash attention are not all-gathered, which means we only calculate attention over a chunk of the sequence (say the total seq_len is `L` and the sequence-parallel degree is 4, we only calculate attention within an `L/4` slice). This will cause issues in the trained models (i.e., the trained model stops attending to previous content after `L/4` tokens).

Their original implementation of `ParallelAttention` does not have this issue since they perform an all-gather in the forward pass and a reduce-scatter in the backward pass; see this issue for details.
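To illustrate the failure mode described above, here is a small self-contained check in plain PyTorch (using `scaled_dot_product_attention` as a stand-in for `flash_attn.flash_attn_func`): running attention independently on each rank-local chunk of `q`/`k`/`v` amounts to block-diagonal attention, so a token can never attend to anything in an earlier chunk. The tensor sizes and the chunk count of 4 are just for illustration, matching the `L/4` example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L, heads, d = 16, 2, 8  # tiny sizes for illustration
q = torch.randn(1, heads, L, d)
k = torch.randn(1, heads, L, d)
v = torch.randn(1, heads, L, d)

# Correct behavior: causal attention over the full sequence.
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Buggy behavior: each of the 4 sequence-parallel ranks only sees its own L/4
# slice of q, k, v, so attention never crosses chunk boundaries.
chunked = torch.cat(
    [F.scaled_dot_product_attention(qc, kc, vc, is_causal=True)
     for qc, kc, vc in zip(q.chunk(4, dim=2), k.chunk(4, dim=2), v.chunk(4, dim=2))],
    dim=2)

# The first chunk matches (those tokens have nothing earlier to attend to);
# every later chunk differs because cross-chunk attention is lost.
print(torch.allclose(full[..., : L // 4, :], chunked[..., : L // 4, :], atol=1e-6))  # True
print(torch.allclose(full, chunked, atol=1e-6))                                      # False
```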