pytorch / ao

PyTorch native quantization and sparsity for training and inference

Does Float8Linear support Tensor Parallelism and Sequence Parallelism? #1198

Open · zigzagcai opened this issue 1 week ago

zigzagcai commented 1 week ago

We know that Transformer Engine supports FP8 training with data parallel + tensor parallel + sequence parallel: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/advanced_optimizations.html#Multi-GPU-training

However, when I checked the source code of swap_linear_modules and Float8Linear, and the torchao documentation and discussions, I could only find support for FP8 with FSDP (as far as I know).

So, does torchao also have support for tensor parallelism and sequence parallelism with Float8Linear?

Thanks!

vkuzo commented 1 week ago

Hi @zigzagcai, we do support TP and SP, implemented via DTensor (https://pytorch.org/docs/stable/distributed.tensor.parallel.html). We have designed the float8 APIs to be orthogonal to TP/SP, so for the user the workflow would be as follows (sketched in code below the list):

  1. start with high precision model
  2. convert parts of model to float8 (torchao.float8)
  3. define distributed strategy for model (https://pytorch.org/docs/stable/distributed.tensor.parallel.html)
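To make that concrete, here is a minimal sketch of steps (1)-(3) using the stock DTensor plans, i.e. the case where communications stay in high precision. The toy FeedForward module and its dimensions are hypothetical, the snippet assumes a job launched with torchrun, and the conversion entry point (convert_to_float8_training here) may be named differently in older torchao versions:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
from torchao.float8 import convert_to_float8_training

# (1) start with a high precision model; this toy module is hypothetical
class FeedForward(nn.Module):
    def __init__(self, dim=1024, hidden=4096):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

# assumes a distributed job launched with torchrun
dist.init_process_group("nccl")
model = FeedForward().cuda()

# (2) swap eligible nn.Linear modules for Float8Linear
convert_to_float8_training(model)

# (3) define the distributed strategy; with the stock
# ColwiseParallel/RowwiseParallel plans, the TP/SP communications
# happen in high precision, so steps (2) and (3) stay independent
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
parallelize_module(
    model,
    mesh,
    {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)
```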

If you are OK with distributed communications happening in high precision, then (2) and (3) are independent. If you are interested in doing the all-gathers in low precision, then (2) and (3) interact to coordinate on low precision casting.

Here is an e2e example of (2) and (3) interacting, with the all-gathers happening in low precision: https://github.com/pytorch/ao/blob/4f1fc4c2d5eace4153c6e62123f5584e772fff4c/test/float8/test_dtensor.py#L206. If you were to take that example and replace Float8ColwiseParallel with ColwiseParallel, etc., you'd get to a point where (2) and (3) are independent.
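For the interacting case, here is a sketch of the float8-aware plan swap using the Float8ColwiseParallel / Float8RowwiseParallel classes from the linked test (the import path below is an assumption for recent torchao versions, and "w1"/"w2" refer to the hypothetical model in the earlier sketch; the linked test additionally uses PrepareFloat8ModuleInput for the sequence parallel input cast):

```python
from torch.distributed.tensor.parallel import parallelize_module
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)

# same converted model and mesh as in the sketch above; only the plan
# changes, so the TP all-gathers now carry float8 data instead of
# high precision data, i.e. steps (2) and (3) coordinate on the cast
parallelize_module(
    model,
    mesh,
    {"w1": Float8ColwiseParallel(), "w2": Float8RowwiseParallel()},
)
```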

A full e2e training example of how all of this fits together is https://github.com/pytorch/torchtitan.

Let me keep this issue open to track adding more information to the README about how this works. Thanks for the question.