Pushing my changes to #36 as a separate PR since I can't push to @thomasw21's branch.
Same as #36, but with a reduced diff and improved comments.

| Method | Loss | Runtime | Avg |
| --- | --- | --- | --- |
| MHA | 2.604475 | 2820 s | 564 ms |
| MQA, TP=1 | 2.641582 | 2303 s | 461 ms |
| MQA, before fix | 3.740988 | 2538 s | 508 ms |
| MQA, PR 36, first version | 2.640432 | 2602 s | 520 ms |
| MQA, PR 37 | 2.640481 | 2528 s | 506 ms |
| MQA, PR 36 | 2.640481 | 2548 s | 510 ms |
| MQA, PR 39 | 2.640481 | 2538 s | 508 ms |
| MQA, PR 37, sequence parallel | 2.640036 | 2654 s | 531 ms |
| MQA, PR 36, sequence parallel | 2.638824 | 2629 s | 526 ms |
| MQA, PR 39, sequence parallel | 2.638824 | 2622 s | 524 ms |

So this PR gives the same loss as PR 36, and the runtimes differ only by statistical variation. With TP it matches PR 37. With SP it is marginally faster than PR 37, though the speedup is hard to distinguish from statistical variation.
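
For reference, here is a minimal sketch of how the relative runtime differences can be read off the table. The helper name `relative_speedup` is hypothetical and not part of this PR; the numbers are the total runtimes in seconds from the rows above. These are single runs, so the percentages say nothing about significance.

```python
def relative_speedup(baseline_s: float, candidate_s: float) -> float:
    """Percentage by which the candidate is faster than the baseline (negative if slower)."""
    return 100.0 * (baseline_s - candidate_s) / baseline_s


# Total runtimes in seconds, copied from the table.
runtimes = {
    "PR 36": 2548,
    "PR 37": 2528,
    "PR 39": 2538,
    "PR 36, sequence parallel": 2629,
    "PR 37, sequence parallel": 2654,
    "PR 39, sequence parallel": 2622,
}

if __name__ == "__main__":
    pairs = [
        ("PR 36", "PR 39"),
        ("PR 37", "PR 39"),
        ("PR 36, sequence parallel", "PR 39, sequence parallel"),
        ("PR 37, sequence parallel", "PR 39, sequence parallel"),
    ]
    for baseline, candidate in pairs:
        diff = relative_speedup(runtimes[baseline], runtimes[candidate])
        print(f"{candidate} vs {baseline}: {diff:+.2f}%")
```

All of the differences come out at roughly one percent or less, which is why they are hard to separate from run-to-run noise.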