vgoklani opened this issue 2 months ago
Hi @vgoklani -- TE modules can be initialized under the `with te.fp8_model_init():` context to allocate their primary weights in FP8 (as `te.Float8Tensor`s) instead of allocating them at higher precision and maintaining separate FP8 buffers for compute. I don't believe anyone has tried this in practice, but at least in principle, FSDP2's per-parameter sharding should work out-of-the-box with the `torch.uint8` data underneath our `te.Float8Tensor`s.
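For concreteness, a minimal (untested) sketch of that combination with a single TE layer. Note that FSDP2's `fully_shard` entry point has moved between `torch.distributed._composable.fsdp` and `torch.distributed.fsdp` across recent PyTorch releases, so adjust the import for your version, and this assumes `torch.distributed` is already initialized:

```python
import torch
import transformer_engine.pytorch as te
from torch.distributed.fsdp import fully_shard  # FSDP2 per-parameter API

# Allocate primary weights directly in FP8 (te.Float8Tensor with uint8 storage)
# instead of keeping high-precision params plus separate FP8 compute buffers.
with te.fp8_model_init(enabled=True):
    model = te.Linear(4096, 4096, bias=True)

# FSDP2 shards each parameter individually; in principle it only needs to see
# the torch.uint8 data underneath the Float8Tensor parameters.
fully_shard(model)
```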
There are two things to be mindful of here:
1. You do not need the `precompute_float8_dynamic_scale_for_fsdp(model)` API from the linked example, because TE already does this internally. You simply need to pass the process group for amax reductions (typically the global/world group) into the `te.fp8_autocast()` context (a short sketch follows below this comment).
2. `te.fp8_model_init()` is intended to be used with higher precision "master" copies of the model parameters in the optimizer.

If you experiment with TE + FSDP2, please share your experiences. We already support PyTorch's native FSDP, but that path has TE modules carrying extra FP8 buffers for the compute while FSDP communication stays in higher precision. It would be great to extend our FSDP support to `te.fp8_model_init()` + FSDP2.
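A sketch of point 1, assuming the module from the snippet above and an already-initialized world process group; the recipe here is just a placeholder default:

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# No precompute_float8_dynamic_scale_for_fsdp() equivalent is needed -- TE
# reduces amaxes / updates scales itself; just hand it the reduction group.
recipe = DelayedScaling()          # default delayed-scaling FP8 recipe
amax_group = dist.group.WORLD      # typically the global/world group

inp = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=amax_group):
    out = model(inp)               # `model` from the sketch above
out.sum().backward()
```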
Adding to this, FSDP2 support should just be a matter of implementing the `fsdp_pre_all_gather` and `fsdp_post_all_gather` methods on `Float8Tensor`, at least in principle.
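For reference, a rough skeleton of what those two hooks look like on a wrapper tensor subclass. Everything here is a hypothetical stand-in for `te.Float8Tensor` (the class, its `_data`/`_scale` layout, the stubbed dispatch), and the hook signatures have shifted across PyTorch releases, so treat it as a sketch of the protocol rather than working code:

```python
from typing import Any, Optional, Tuple
import torch

class ToyFloat8Tensor(torch.Tensor):
    """Hypothetical stand-in for te.Float8Tensor: uint8 payload + scale."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: torch.Tensor):
        # Wrapper subclass: the outer tensor advertises a compute dtype while
        # the raw uint8 bytes and the FP8 scale live as inner attributes.
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=torch.bfloat16, device=data.device
        )
        self._data = data
        self._scale = scale
        return self

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # A real subclass would dispatch to FP8 kernels or dequantize here.
        raise NotImplementedError(f"{func} is out of scope for this sketch")

    def fsdp_pre_all_gather(self, mesh) -> Tuple[Tuple[torch.Tensor, ...], Any]:
        # Give FSDP2 the raw uint8 shard to all-gather, plus any metadata
        # needed to rebuild the subclass afterwards (here: just the scale).
        return (self._data,), (self._scale,)

    def fsdp_post_all_gather(
        self,
        all_gather_outputs: Tuple[torch.Tensor, ...],
        metadata: Any,
        param_dtype: torch.dtype,
        *,
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
        (scale,) = metadata
        if out is not None:
            # FSDP2 reuses a preallocated unsharded tensor on later iterations.
            return
        # Return the rebuilt FP8 tensor plus the inner tensors FSDP2 should
        # free when it reshards after forward/backward.
        return ToyFloat8Tensor(data, scale), (data,)
```

This is essentially the protocol the torchao/torchtitan float8 all-gather path linked below relies on.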
FSDP2 supports all-gather using FP8:
https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323
Wondering if we could do this directly using TransformerEngine instead of torch-ao?
Thanks!