bigcode-project / Megatron-LM

Ongoing research training transformer models at scale

Reduce the tensor-parallel KV output grads #37

Closed jlamypoirier closed 1 year ago

jlamypoirier commented 1 year ago

Edit: the new version of #36 should be better than this.

Alternative to #36. Fixes the KV gradient computation in tensor-parallel: add a reduction on the KV output grads and remove the KV parameter grad reduction.

Compared to #36, this is faster and more closely matches the mathematical justification, which comes from the Q x K multiplication:

A [batch, num_heads * sq, sk] = Q [batch, num_heads * sq, head_size] x K^T [batch, head_size, sk]
G_K [batch, head_size, sk] = Q^T [batch, head_size, num_heads * sq] x G_A [batch, num_heads * sq, sk]
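
Unfolding the num_heads dimension of the same shapes makes the per-head sum explicit (the Q_h and G_{A_h} blocks below are just the per-head slices of Q and G_A above, for a single batch entry):

$$
G_K \;=\; Q^\top G_A \;=\; \sum_{h=1}^{\text{num\_heads}} Q_h^\top\, G_{A_h},
\qquad Q_h \in \mathbb{R}^{sq \times \text{head\_size}},\quad G_{A_h} \in \mathbb{R}^{sq \times sk}
$$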

The sum over num_heads is where we need the reduction: with MQA the query heads are split across tensor-parallel ranks while KV is shared, so each rank only computes a partial sum and the full G_K requires an all-reduce.
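
For reference, a minimal sketch of how such a reduction can be expressed as an autograd op (this is an illustration, not the actual diff of this PR; the names are hypothetical, though Megatron-LM has an analogous copy_to_tensor_model_parallel_region helper):

```python
# Minimal sketch (not the actual PR diff), assuming torch.distributed is
# already initialized. The op is the identity in the forward pass and
# all-reduces the incoming gradient in the backward pass. Applying it to the
# shared KV projection output sums the KV output grads over the heads held
# by the other tensor-parallel ranks, so the separate KV parameter grad
# reduction is no longer needed.
import torch
import torch.distributed as dist


class _CopyToTensorParallelRegion(torch.autograd.Function):
    """Identity in the forward pass; all-reduce of the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, kv):
        return kv

    @staticmethod
    def backward(ctx, grad_kv):
        # Each rank only holds the Q^T x G_A contribution of its own heads,
        # so sum the KV output grads across ranks. Real code would pass
        # group=<tensor-parallel group> instead of using the default group.
        dist.all_reduce(grad_kv, op=dist.ReduceOp.SUM)
        return grad_kv


def copy_to_tensor_parallel_region(kv):
    return _CopyToTensorParallelRegion.apply(kv)
```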

Validation loss after 5000 steps, total runtime, and average step time for a small model:

| Method | Valid loss | Runtime | Avg step time |
|---|---|---|---|
| MHA | 2.604475 | 2820 s | 564 ms |
| MQA, TP=1 | 2.641582 | 2303 s | 461 ms |
| MQA, before fix | 3.740988 | 2538 s | 508 ms |
| MQA, after this fix | 2.640481 | 2528 s | 506 ms |
| MQA, with PR #36, v1 | 2.640432 | 2602 s | 520 ms |
| MQA, with PR #36, v2 | 2.640481 | 2548 s | 510 ms |
| MQA, sequence parallel | 2.640036 | 2654 s | 531 ms |
| MQA, sequence parallel, PR #36 v2 | 2.638824 | 2629 s | 526 ms |

So both implementations work and give almost identical losses, but this PR is faster (the difference would be smaller in a realistic scenario, since this is a very small model with a large TP overhead).

jlamypoirier commented 1 year ago

#36 is now better.