pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

add PrepareFloat8ModuleInput for sequence parallel #275

Closed · wanchaol closed this 3 months ago

wanchaol commented 3 months ago

When applying Sequence Parallel to a module with two or more linear layers for the input projection, we often want to transform the input from Shard to Replicate once (i.e. all-gather once) and then reuse the all-gathered result. For fp8, we need to do the casting before the Shard -> Replicate redistribution so that the all-gather itself can run in fp8.
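
A minimal sketch of that ordering difference, assuming a DTensor-sharded activation and a hypothetical `cast_fn` that produces the fp8 representation (this is illustrative, not the library's actual internals):

```python
from torch.distributed._tensor import DTensor, Replicate


def allgather_then_cast(x: DTensor, cast_fn) -> DTensor:
    # Plain PrepareModuleInput path: the Shard -> Replicate redistribution
    # (an all-gather) moves bf16 bytes, and the cast happens afterwards.
    x = x.redistribute(placements=[Replicate()])
    return cast_fn(x)


def cast_then_allgather(x: DTensor, cast_fn) -> DTensor:
    # PrepareFloat8ModuleInput path: cast first, so the all-gather moves
    # fp8 bytes (roughly half the traffic of bf16) and the replicated fp8
    # result can be reused by every linear layer that consumes this input.
    x = cast_fn(x)
    return x.redistribute(placements=[Replicate()])
```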

This PR subclasses PrepareModuleInput as PrepareFloat8ModuleInput, adding the fp8 casting logic so that we run an fp8 all-gather instead of a bf16 all-gather followed by a cast for the computation.

It also adjusts the test cases to exercise the real FFN case for sequence parallel.
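
For context, here is a hedged sketch of how the new style could be combined with the existing Float8 parallel styles in a tensor-parallel plan for a Llama-style FFN (w1/w3 share the same input, so the fp8 all-gather happens once and is reused). It assumes PrepareFloat8ModuleInput is exported from float8_experimental.float8_tensor_parallel alongside the existing styles; the module names follow torchtitan's Llama naming and the keyword arguments mirror PrepareModuleInput, so both are illustrative and may differ from the exact code in this repo:

```python
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module

from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)

tp_plan = {
    # Cast the FFN input to fp8 once, then redistribute Shard(1) -> Replicate()
    # so the single all-gather runs in fp8 and its result is reused by w1 and w3.
    "feed_forward": PrepareFloat8ModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": Float8ColwiseParallel(),
    "feed_forward.w3": Float8ColwiseParallel(),
    "feed_forward.w2": Float8RowwiseParallel(output_layouts=Shard(1)),
}

# model = parallelize_module(model, tp_mesh, tp_plan)  # tp_mesh: a 1-D DeviceMesh
```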

torchtitan perf benchmarks (8x H100 devgpu, Llama 3 8B, 2-way DP, 4-way TP):

So even in eager mode we get around a 20% perf improvement with every all-gather running in fp8, and the compiled fp8 all-gather perf is more than doubled (102% more WPS) :)

facebook-github-bot commented 3 months ago

@wanchaol has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 3 months ago

@wanchaol merged this pull request in pytorch-labs/float8_experimental@7c7cbae21d76d7e1f1f326f2701c08243171ab67.