Edit: the new version of #36 should be better than this.
Alternative to #36
Fix the kv gradient computation in tensor-parallel mode: add a reduction on the kv output grads and remove the reduction on the kv parameter grads.
Compared to #36, this is faster and more closely matches the mathematical justification, which comes from the Q x K multiplication:
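(A sketch of the relevant gradient in my own notation, not taken from the code: $Q_h$ are the queries of head $h$, $K$ the shared keys, $S_h = Q_h K^\top$ the per-head scores, $L$ the loss.)

$$
S_h = Q_h K^\top
\quad\Longrightarrow\quad
\frac{\partial L}{\partial K} = \sum_{h=1}^{n_{\text{heads}}} \left(\frac{\partial L}{\partial S_h}\right)^{\top} Q_h ,
$$

and analogously for $V$.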
The sum over num_heads is where we need the reduction.
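A minimal sketch of where the reduction could live, assuming a PyTorch / Megatron-style setup; this is illustrative only, and `_ReduceKVGrad` / `reduce_kv_output_grads` are placeholder names rather than code from this PR:

```python
import torch
import torch.distributed as dist


class _ReduceKVGrad(torch.autograd.Function):
    """Identity in forward; all-reduce of the incoming gradient in backward.

    Each tensor-parallel rank holds a subset of the query heads but the same
    shared kv, so the kv output grads must be summed across ranks.
    """

    @staticmethod
    def forward(ctx, kv, group):
        ctx.group = group
        return kv

    @staticmethod
    def backward(ctx, grad_kv):
        # Sum the kv output gradients over the tensor-parallel group.
        grad_kv = grad_kv.contiguous()
        dist.all_reduce(grad_kv, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_kv, None


def reduce_kv_output_grads(kv, group=None):
    # Placeholder helper: apply right after the kv projection, before the
    # per-rank attention heads consume kv.
    return _ReduceKVGrad.apply(kv, group)
```

With the all-reduce on the output grads, every rank backpropagates the full kv gradient through its (replicated) kv projection, so the separate reduction on the kv weight grads becomes redundant.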
Validation loss after 5000 steps, total runtime, and average step time for a small model:
| Method | Loss | Runtime | Avg. step time |
| --- | --- | --- | --- |
| MHA | 2.604475 | 2820 s | 564 ms |
| MQA, TP=1 | 2.641582 | 2303 s | 461 ms |
| MQA, before fix | 3.740988 | 2538 s | 508 ms |
| MQA, after this fix | 2.640481 | 2528 s | 506 ms |
| MQA, with PR #36, v1 | 2.640432 | 2602 s | 520 ms |
| MQA, with PR #36, v2 | 2.640481 | 2548 s | 510 ms |
| MQA, sequence parallel | 2.640036 | 2654 s | 531 ms |
| MQA, sequence parallel, PR #36 v2 | 2.638824 | 2629 s | 526 ms |
So both implementations work and give nearly identical losses, but this PR is faster (the difference would be smaller in a real scenario, since this is a very small model with a huge TP overhead).