Open zigzagcai opened 1 week ago
Hi @zigzagcai, we do support TP and SP implemented via DTensor (https://pytorch.org/docs/stable/distributed.tensor.parallel.html). We have designed the float8 APIs to be orthogonal to TP/SP, so for the user the workflow would be:
If you are ok with distributed communications happening in high precision, then (2) and (3) are independent. If you are interested in doing the all-gathers in low precision, then (2) and (3) interact to coordinate the low-precision casting.
Here is an e2e example of (2) and (3) interacting with all-gathers happening in low precision: https://github.com/pytorch/ao/blob/4f1fc4c2d5eace4153c6e62123f5584e772fff4c/test/float8/test_dtensor.py#L206 . If you were to take that example and replace Float8ColwiseParallel with ColwiseParallel, etc., then you'd get to a point where (2) and (3) are independent.
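To make that swap concrete, here is a minimal sketch of how the two variants of a tensor-parallel plan might be built. The helper name build_tp_plan and the submodule names w1 and w2 are hypothetical, and the import path torchao.float8.float8_tensor_parallel is an assumption that may differ between torchao versions; the linked test is the authoritative reference.

```python
def build_tp_plan(low_precision_all_gather: bool) -> dict:
    """Return a plan for torch's parallelize_module.

    Assumes a hypothetical block with two linear submodules named
    `w1` (column-sharded) and `w2` (row-sharded).
    """
    if low_precision_all_gather:
        # float8-aware styles: TP/SP and float8 coordinate on casting,
        # so the all-gathers can happen in low precision.
        # Assumed import path; check it against your torchao version.
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
        )
        return {"w1": Float8ColwiseParallel(), "w2": Float8RowwiseParallel()}

    # Plain DTensor styles: communication stays in high precision and
    # float8 conversion is independent of the TP/SP plan.
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
    return {"w1": ColwiseParallel(), "w2": RowwiseParallel()}


# Usage inside an already-initialized distributed job (not executed here),
# after the model's linear layers have been converted to Float8Linear:
#
#   from torch.distributed.device_mesh import init_device_mesh
#   from torch.distributed.tensor.parallel import parallelize_module
#
#   mesh = init_device_mesh("cuda", (world_size,))
#   parallelize_module(model, mesh, build_tp_plan(low_precision_all_gather=True))
```

Only the parallel styles in the plan change between the two modes; the rest of the training setup is meant to stay the same.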
A full e2e training example of how all of this fits together is https://github.com/pytorch/torchtitan.
Let me keep this issue open to track adding more information to the README about how this works. Thanks for the question.
We know that Transformer_Engine supports FP8 training with data parallel + tensor parallel + sequence parallel (https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/advanced_optimizations.html#Multi-GPU-training). However, when I checked the source code of swap_linear_modules and Float8Linear, as well as the torchao documentation and discussions, I could only find support for FP8 with FSDP (as far as I know).
So, does torchao also support tensor parallelism and sequence parallelism with FP8Linear?
Thanks!