NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
Apache License 2.0

FSDP: How to do all-gather using FP8? #1188

Open · vgoklani opened this issue 2 months ago

vgoklani commented 2 months ago

FSDP2 supports all-gather using FP8:


Wondering whether we could do this directly with TransformerEngine instead of torchao?


denera commented 2 months ago

Hi @vgoklani -- TE modules can be initialized inside a with te.fp8_model_init(): context so that their primary weights are allocated in FP8 (as te.Float8Tensor objects), instead of being allocated at higher precision alongside separate FP8 buffers for compute.
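To make "primary weights in FP8" concrete, here is a toy, pure-Python model of what such a tensor holds: one E4M3 byte per element plus a single per-tensor scale. ToyFloat8Tensor, encode_e4m3 and decode_e4m3 are hypothetical names for illustration; this is not te.Float8Tensor's implementation or API, and the scale recipe (scale = 448 / amax) is a simplifying assumption.

```python
import math
from dataclasses import dataclass
from typing import List

FP8_E4M3_MAX = 448.0  # largest finite value in the E4M3 ("fn") format


def encode_e4m3(x: float) -> int:
    """Round x to the nearest E4M3 value (ties to even) and return its byte."""
    sign = 0x80 if x < 0 else 0
    a = min(abs(x), FP8_E4M3_MAX)  # saturate instead of overflowing
    if a == 0.0:
        return sign
    frac, exp = math.frexp(a)      # a = frac * 2**exp with 0.5 <= frac < 1
    e = exp - 1                    # so 2**e <= a < 2**(e + 1)
    if e < -6:
        # Subnormal range: value = m * 2**-9. m == 8 becomes the smallest
        # normal, whose byte is also 8, so no special case is needed.
        return sign | round(a * 2**9)
    m = round((a / 2**e - 1.0) * 8)  # 3 mantissa bits; m == 8 carries into the exponent
    return sign | (((e + 7) << 3) + m)


def decode_e4m3(b: int) -> float:
    """Inverse of encode_e4m3 (the NaN byte 0x7F/0xFF is never produced above)."""
    sign = -1.0 if b & 0x80 else 1.0
    e_field, m = (b >> 3) & 0xF, b & 0x7
    if e_field == 0:
        return sign * m * 2.0**-9                     # subnormal
    return sign * (1 + m / 8) * 2.0**(e_field - 7)    # normal, exponent bias 7


@dataclass
class ToyFloat8Tensor:
    """Hypothetical FP8 tensor: raw uint8 codes plus one per-tensor scale_inv."""
    data: bytes        # one E4M3 byte per element (the "uint8 data underneath")
    scale_inv: float   # multiply decoded values by this to recover real values

    @classmethod
    def from_float(cls, values: List[float]) -> "ToyFloat8Tensor":
        amax = max((abs(v) for v in values), default=0.0)
        scale = FP8_E4M3_MAX / amax if amax > 0 else 1.0
        return cls(bytes(encode_e4m3(v * scale) for v in values), 1.0 / scale)

    def dequantize(self) -> List[float]:
        return [decode_e4m3(b) * self.scale_inv for b in self.data]


w = ToyFloat8Tensor.from_float([0.5, -1.0, 0.25, 2.0])
print(len(w.data), w.dequantize())
```

The point of the sketch is the storage layout: the weight itself is a byte buffer, and no higher-precision copy lives in the module.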

I don't believe anyone has tried this in practice, but at least in principle, FSDP2's per-parameter sharding should work out of the box on the torch.uint8 data underneath our te.Float8Tensor objects.
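The reasoning behind "should work out of the box" can be sketched in plain Python: FSDP2 shards each parameter along dim 0, and the collectives only move bytes, so splitting and re-concatenating the uint8 payload of a per-tensor-scaled FP8 weight reproduces it exactly. This is a simulation under stated assumptions (rows divide evenly, no padding, no real process group, the scalar scale_inv replicated on every rank); shard_dim0 and all_gather_dim0 are hypothetical helpers, not FSDP2 APIs.

```python
from typing import List

Row = bytes  # one row of a 2-D FP8 parameter: one raw E4M3 byte per element


def shard_dim0(rows: List[Row], world_size: int) -> List[List[Row]]:
    """Split a 2-D byte tensor into contiguous row blocks, one block per rank."""
    assert len(rows) % world_size == 0, "sketch assumes rows divide evenly"
    per_rank = len(rows) // world_size
    return [rows[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def all_gather_dim0(shards: List[List[Row]]) -> List[Row]:
    """Simulated all-gather: concatenate every rank's block in rank order."""
    return [row for shard in shards for row in shard]


# A 4x3 FP8 weight as raw bytes. Its per-tensor scale_inv is a single scalar
# (assumed replicated on every rank), so it is not part of the sharded payload.
weight = [bytes([0x38, 0x40, 0xB8]), bytes([0x30, 0x7E, 0x00]),
          bytes([0x44, 0xC0, 0x3C]), bytes([0x08, 0x01, 0xFE])]

shards = shard_dim0(weight, world_size=2)  # rank 0 gets rows 0-1, rank 1 rows 2-3
rebuilt = all_gather_dim0(shards)
print(rebuilt == weight)
```

Because the scale is per tensor rather than per shard, every rank decodes the gathered bytes with the same scale_inv; per-shard scales would break this byte-for-byte equivalence.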

There are two things to be mindful of here:

  1. Do not use the precompute_float8_dynamic_scale_for_fsdp(model) API from the linked example; TE already does this internally. You only need to pass the process group for amax reductions (typically the global/world group) to the te.fp8_autocast() context through its fp8_group argument.
  2. In the absence of native FP8 support in PyTorch, you cannot apply the optimizer step directly to the FP8 model parameters. Consequently, te.fp8_model_init() is intended to be used with higher-precision "master" copies of the model parameters held by the optimizer.
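Both points can be illustrated with a toy, pure-Python simulation. Part 1 shows why the amax process group matters: a MAX reduction gives every rank the same amax and therefore the same scale, so shards quantized on different ranks stay consistent. Part 2 shows why master weights are needed: E4M3 keeps only 3 mantissa bits, so near 1.0 neighboring values are 0.0625 to 0.125 apart, and a small update applied directly to the FP8 value is rounded away, while the same updates accumulate in a higher-precision copy. The per-rank amax values and the step size are made up, and the recipe scale = 448 / amax (no margin, no amax history) is a simplifying assumption, not TE's exact DelayedScaling logic.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite E4M3 value


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating at +/-448."""
    a = min(abs(x), FP8_E4M3_MAX)
    if a == 0.0:
        return 0.0
    e = max(math.frexp(a)[1] - 1, -6)  # binade exponent, floored at the subnormal range
    step = 2.0 ** (e - 3)              # 3 mantissa bits -> 8 steps per binade
    return math.copysign(round(a / step) * step, x)


# Part 1: amax reduction over the process group (simulated with a list).
local_amax = [0.9, 1.7, 1.2, 0.4]      # hypothetical per-rank amax values
global_amax = max(local_amax)          # what a MAX all-reduce would return
scale = FP8_E4M3_MAX / global_amax     # one scale shared by every rank

# Part 2: why optimizer steps need a higher-precision master copy.
# Scale is taken as 1 here to keep the numbers readable.
step_size = 1e-3                       # hypothetical lr * grad per step
fp8_weight = round_to_e4m3(1.0)
master_weight = 1.0                    # Python floats are FP64; real masters are often FP32
for _ in range(100):
    fp8_weight = round_to_e4m3(fp8_weight - step_size)  # 0.999 rounds back to 1.0
    master_weight -= step_size                          # updates accumulate
fp8_from_master = round_to_e4m3(master_weight)          # recast after the update

print(global_amax, fp8_weight, master_weight, fp8_from_master)
```

The final recast mirrors the workflow in item 2: the optimizer updates the master copy, and the FP8 weight is re-derived from it rather than updated in place.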

If you experiment with TE + FSDP2, please share your experience. We already support PyTorch's native FSDP, but in that setup TE modules carry extra FP8 buffers for compute while FSDP communication remains in higher precision. It would be great to extend our FSDP support to te.fp8_model_init() + FSDP2.
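To put rough numbers on "communication remains in higher precision", here is a back-of-the-envelope sketch. With equal shards, each rank already holds 1/world_size of a parameter and must receive the remaining (world_size - 1)/world_size, so halving the element size halves the all-gather volume. BF16 as the higher precision, the helper name, and the 1-billion-parameter figure are illustrative assumptions; scale metadata, padding, and the gradient reduce-scatter are ignored.

```python
def all_gather_bytes_received(num_params: int, world_size: int, bytes_per_elem: int) -> float:
    """Bytes one rank receives to rebuild a full parameter from equal shards.

    Received = num_params * bytes_per_elem * (world_size - 1) / world_size.
    """
    return num_params * bytes_per_elem * (world_size - 1) / world_size


bf16 = all_gather_bytes_received(10**9, world_size=8, bytes_per_elem=2)
fp8 = all_gather_bytes_received(10**9, world_size=8, bytes_per_elem=1)
print(bf16, fp8)
```

For 8 ranks this is 1.75 GB per rank per full all-gather in BF16 versus 0.875 GB in FP8, which is the saving FP8 all-gather targets.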

timmoon10 commented 2 months ago

Adding to this: at least in principle, FSDP2 support should just be a matter of implementing the fsdp_pre_all_gather and fsdp_post_all_gather methods on Float8Tensor.
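The shape of that contract can be sketched without PyTorch: the pre-all-gather hook hands the collective plain bytes plus non-sharded metadata (here the per-tensor scale_inv), and the post-all-gather hook rewraps the gathered bytes as an FP8 tensor. This is only a plain-Python analogy under that assumption; ToyFP8Param, pre_all_gather, post_all_gather and simulated_fsdp_all_gather are hypothetical names. The real FSDP2 hooks are methods on a torch.Tensor subclass, operate on torch tensors, and their exact signatures may differ between PyTorch versions.

```python
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass
class ToyFP8Param:
    """Hypothetical FP8 parameter shard: raw E4M3 bytes plus a per-tensor scale_inv."""
    data: bytes
    scale_inv: float

    def pre_all_gather(self) -> Tuple[Tuple[bytes, ...], Any]:
        # Analogous role to fsdp_pre_all_gather: give the collective plain bytes
        # and keep non-shardable state (the scale) as local metadata.
        return (self.data,), self.scale_inv

    @staticmethod
    def post_all_gather(outputs: Tuple[bytes, ...], metadata: Any) -> "ToyFP8Param":
        # Analogous role to fsdp_post_all_gather: rewrap the gathered bytes.
        (gathered,) = outputs
        return ToyFP8Param(gathered, metadata)


def simulated_fsdp_all_gather(local_shards: List[ToyFP8Param]) -> ToyFP8Param:
    """Stand-in for FSDP2's all-gather: concatenate each rank's payload in rank order."""
    payloads, metas = zip(*(s.pre_all_gather() for s in local_shards))
    # The amax reduction is what makes the per-tensor scale agree across ranks.
    assert len(set(metas)) == 1, "per-tensor scale must match on every rank"
    gathered = b"".join(p[0] for p in payloads)
    return ToyFP8Param.post_all_gather((gathered,), metas[0])


shards = [ToyFP8Param(bytes([0x38, 0x40]), 0.5), ToyFP8Param(bytes([0xB8, 0x7E]), 0.5)]
full = simulated_fsdp_all_gather(shards)
print(full)
```

Only the uint8 payload crosses the wire; the scale travels as metadata, which is why a consistent per-tensor scale is a precondition for this approach.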