bigcode-project / Megatron-LM

Ongoing research training transformer models at scale

Kv grad allreduce v2 #39

Closed: jlamypoirier closed this 1 year ago

jlamypoirier commented 1 year ago

Pushing my changes to #36 as a separate PR since I can't push to @thomasw21's branch. Same as #36, but with a reduced diff and improved comments.

Method                             Loss        Runtime     Avg / iteration
MHA:                               2.604475    2820 s      564 ms
MQA, TP=1:                         2.641582    2303 s      461 ms
MQA, before fix:                   3.740988    2538 s      508 ms
MQA, PR #36, first version:        2.640432    2602 s      520 ms
MQA, PR #37:                       2.640481    2528 s      506 ms
MQA, PR #36:                       2.640481    2548 s      510 ms
MQA, PR #39 (this PR):             2.640481    2538 s      508 ms
MQA, PR #37, sequence parallel:    2.640036    2654 s      531 ms
MQA, PR #36, sequence parallel:    2.638824    2629 s      526 ms
MQA, PR #39, sequence parallel:    2.638824    2622 s      524 ms

So this PR matches PR #36 up to statistical variation in runtime, matches PR #37 with tensor parallelism, and is marginally faster than PR #37 with sequence parallelism, though that speedup is hard to distinguish from statistical variation.
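For readers landing here without the context of #36: in the usual MQA + tensor-parallel setup, the query heads are sharded across TP ranks while the single key/value projection is replicated on every rank. Autograd then gives each rank only the gradient contribution from its own slice of query heads, so the KV gradients must be summed across the TP group; skipping that sum is what produces the degraded loss in the "before fix" row above. Below is a minimal sketch of the idea in plain PyTorch; `allreduce_kv_grads` and `tp_group` are illustrative names, and this is not the actual change in this PR, which is integrated into Megatron-LM's parallel layers.

```python
import torch
import torch.distributed as dist


def allreduce_kv_grads(kv_proj: torch.nn.Linear,
                       tp_group: dist.ProcessGroup) -> None:
    """Sum the replicated KV projection's gradients over the TP group.

    Each tensor-parallel rank only back-propagates through its own slice of
    query heads, so its locally computed KV weight/bias gradients are partial
    contributions; the true gradient is their sum across ranks.
    (Hypothetical helper for illustration, not the PR's actual code.)
    """

    def _sum_across_tp(grad: torch.Tensor) -> torch.Tensor:
        # All-reduce in place, then hand the summed gradient back to autograd.
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_group)
        return grad

    for param in kv_proj.parameters():  # weight, and bias if present
        param.register_hook(_sum_across_tp)
```

Equivalently, the sum can be folded into a custom autograd function in the layer's backward pass; either way the loss should recover to the TP=1 baseline, which is what the fixed rows in the table show.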