bigcode-project / Megatron-LM

Ongoing research training transformer models at scale

Double-check the code for key/value gradient reduction in the case of MQA, when tensor-model-parallel > 1, and for distributed optim #11

Closed RaymondLi0 closed 1 year ago

RaymondLi0 commented 1 year ago

Looking again at this code: https://github.com/bigcode-project/Megatron-LM/blob/multi-query-attention/megatron/optimizer/optimizer.py#L269

1 - Aren't the gradients of the biases missing from the reduction?

2 - The distributed optimizer misses this reduction: https://github.com/bigcode-project/Megatron-LM/blob/8169dec7a78a84cefd65ad69c5060f2a1fba15a3/megatron/optimizer/distrib_optimizer.py#L522
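
For reference, here is a minimal sketch (not the repository's actual code) of a reduction that covers both the weight and bias gradients of the shared key/value projection; the name-based parameter matching and the helper name are assumptions:

```python
# Hedged sketch: all-reduce the gradients of the shared (non-tensor-parallel)
# MQA key/value projection across the tensor-model-parallel group, covering
# both weights and biases. Identifying the parameters by the "key_value"
# substring is an assumption about the fork's naming.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_key_value_grads(model, tp_group):
    grads = []
    for name, param in model.named_parameters():
        if "key_value" in name and param.requires_grad and param.grad is not None:
            grads.append(param.grad.data)  # picks up both .weight and .bias grads
    if not grads:
        return
    # Flatten, all-reduce once, and copy the synced values back (the same
    # pattern used elsewhere in optimizer.py for layernorm gradients).
    coalesced = _flatten_dense_tensors(grads)
    torch.distributed.all_reduce(coalesced, group=tp_group)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
```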

RaymondLi0 commented 1 year ago

Also, as Joel pointed out, self.models[0] would not work with interleaved stages https://github.com/bigcode-project/Megatron-LM/blob/multi-query-attention/megatron/optimizer/optimizer.py#L271
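
A rough sketch of how the reduction could handle interleaved stages, reusing the hypothetical helper from the sketch above:

```python
# Hedged sketch: with interleaved pipeline stages each rank holds several model
# chunks, so reducing only self.models[0] would skip the key/value parameters
# that live in the other chunks. Iterating over every chunk avoids that.
for model_chunk in self.models:
    allreduce_key_value_grads(model_chunk, tp_group)
```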

jlamypoirier commented 1 year ago

There is also a crash when combining with sequence parallelism:

2022-12-09 04:21:00,367 (worker_1) :   File "/app/megatron/model/transformer.py", line 415, in forward
2022-12-09 04:21:00,367 (worker_1) :     attention_probs = self.scale_mask_softmax(attention_scores,
2022-12-09 04:21:00,367 (worker_1) :   File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1129, in _call_impl
2022-12-09 04:21:00,367 (worker_1) :     return forward_call(*input, **kwargs)
2022-12-09 04:21:00,367 (worker_1) :   File "/app/megatron/model/fused_softmax.py", line 161, in forward
2022-12-09 04:21:00,367 (worker_1) :     return self.forward_fused_softmax(input, mask)
2022-12-09 04:21:00,367 (worker_1) :   File "/app/megatron/model/fused_softmax.py", line 192, in forward_fused_softmax
2022-12-09 04:21:00,367 (worker_1) :     assert sq == sk, "causal mask is only for self attention"
2022-12-09 04:21:00,367 (worker_1) : AssertionError: causal mask is only for self attention

I did some investigating and found that the key layer skips the sequence-parallel gather because it's a normal linear layer instead of the tensor-parallel one, so it ends up with the wrong sequence length.
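
To illustrate the failure mode, here is a minimal sketch (not necessarily what the fork ends up doing) of one way the missing gather could be restored before the shared key/value projection; the import path and function name are assumptions about this fork's layout:

```python
# Hedged sketch: with sequence parallelism the hidden states reaching the
# attention block are sharded along the sequence dimension. A plain nn.Linear
# for the shared MQA key/value projection therefore sees only seq_len / tp_size
# positions, so sk != sq and the fused causal softmax asserts. One option is to
# gather the full sequence explicitly before the projection.
from megatron.mpu.mappings import gather_from_sequence_parallel_region  # assumed path

def mqa_key_value_forward(kv_linear, hidden_states, sequence_parallel):
    # hidden_states: [local_seq_len, batch, hidden] when sequence-parallel is on
    if sequence_parallel:
        # All-gather along the sequence dimension across the tensor-parallel
        # group, restoring [seq_len, batch, hidden] before the plain linear.
        hidden_states = gather_from_sequence_parallel_region(hidden_states)
    return kv_linear(hidden_states)
```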

RaymondLi0 commented 1 year ago

Nice catch! We could add a check earlier in the code saying that sequence-parallel is not yet supported with MQA, or implement a non-tensor-parallel layer that supports sequence-parallelism?
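
A sketch of the first option (the argument names are assumptions about the fork's flags):

```python
# Hedged sketch of the "fail early" option: reject the unsupported combination
# during argument validation instead of crashing inside the fused softmax.
# args.sequence_parallel and args.attention_head_type are assumed names.
if args.sequence_parallel and args.attention_head_type == "multiquery":
    raise NotImplementedError(
        "Sequence parallelism is not yet supported with multi-query attention."
    )
```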

jlamypoirier commented 1 year ago

> Nice catch! We could add a check earlier in the code saying that sequence-parallel is not yet supported with MQA, or implement a non-tensor-parallel layer that supports sequence-parallelism?

I added a tentative fix to #12