Closed — wanchaol closed this pull request 3 months ago.
@wanchaol has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@wanchaol merged this pull request in pytorch-labs/float8_experimental@7c7cbae21d76d7e1f1f326f2701c08243171ab67.
When applying Sequence Parallel to a module with more than 2 linear layers for the input projection, we often want to transform from Shard to Replicate once (a single all-gather) and then reuse the all-gathered result. For fp8, we need to do the casting before the Shard -> Replicate transform so that the all-gather itself runs in fp8.
This PR subclasses PrepareModuleInput to add the fp8 casting logic, so that we run an fp8 all-gather instead of a bf16 all-gather followed by a cast for computation.
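As a rough sketch of why casting before the all-gather pays off: the Shard -> Replicate transform moves the full activation across TP ranks, so its payload scales with the element size, and fp8 halves it relative to bf16. The shapes and TP degree below are hypothetical, chosen only for illustration (this ignores the small per-tensor scale that fp8 also communicates):

```python
def allgather_payload_bytes(batch, seq, dim, tp_degree, bytes_per_elem):
    """Per-rank receive volume for an all-gather along the sequence dim.

    Each rank already holds one seq/tp_degree shard and receives the
    remaining (tp_degree - 1) shards from its peers.
    """
    shard_elems = batch * (seq // tp_degree) * dim
    return shard_elems * (tp_degree - 1) * bytes_per_elem

# Hypothetical activation shape and TP degree, not from the benchmark.
batch, seq, dim, tp = 8, 8192, 4096, 4
bf16 = allgather_payload_bytes(batch, seq, dim, tp, 2)  # bf16: 2 bytes/elem
fp8 = allgather_payload_bytes(batch, seq, dim, tp, 1)   # fp8: 1 byte/elem

print(f"bf16 all-gather payload: {bf16 / 2**20:.0f} MiB")
print(f"fp8  all-gather payload: {fp8 / 2**20:.0f} MiB")
print(f"reduction: {1 - fp8 / bf16:.0%}")
```

Casting after the all-gather would move the bf16 volume; casting before it moves half as many bytes, which is where the eager-mode win largely comes from.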
It also adjusts the test cases to cover the real FFN case for sequence parallel.
torchtitan perf benchmarks (8 H100 devgpu, Llama3 8b, 2-way DP, 4-way TP):
So even in eager mode we get around a 20% perf improvement with every all-gather running in fp8, and compiled fp8 all-gather perf is more than doubled (102% more WPS) :)