NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

FSDP support #401

Open yongyanrao opened 1 year ago

yongyanrao commented 1 year ago

I was wondering if PyTorch's FullyShardedDataParallel (FSDP) is supported by TransformerEngine, especially whether FP8 can work with FSDP. Thank you in advance.

timmoon10 commented 1 year ago

In the current scheme, Transformer Engine modules use standard parameter tensors in standard dtypes (FP32/BF16/FP16). Optimizers typically require higher precision than FP8 to get good learning behavior. I don't see anything that would disrupt FSDP and I've been able to get TE and FSDP working together in some quick experiments.
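
For reference, a minimal sketch of that kind of quick experiment (assuming a single-node torchrun launch, with te.Linear as a stand-in model; this is not an official recipe):

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# TE parameters are ordinary FP32/BF16/FP16 tensors, so FSDP can shard them as usual.
model = FSDP(te.Linear(1024, 1024))

x = torch.randn(32, 1024, device="cuda")
with te.fp8_autocast(enabled=True):  # FP8 applies to the compute, not to the sharded parameters
    y = model(x)
y.sum().backward()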

yongyanrao commented 1 year ago

Hi Tim, thank you for the response. Did you try FP8 + FSDP yet?

MatthieuToulemont commented 1 year ago

Hello, @timmoon10 which FSDP did you use? Fairscale's?

yongyanrao commented 1 year ago

Hi, I was referring to PyTorch's FullyShardedDataParallel.

MatthieuToulemont commented 1 year ago

Thank you!

timmoon10 commented 1 year ago

Yep, I used PyTorch FSDP with TE FP8. Be advised I haven't done full convergence experiments, just some basic sanity checking.

jramapuram commented 1 year ago

What is the recommendation for MixedPrecision when using FP8 with FSDP @timmoon10 ?

import torch
from torch.distributed.fsdp import MixedPrecision

precision = ?

MixedPrecision(
    param_dtype=precision,  # should we be forcing a dtype here?
    reduce_dtype=torch.float32,  # reduce in FP32, as with AMP?
    buffer_dtype=precision,  # or should buffers stay in FP32?
    cast_forward_inputs=is_amp_enabled,  # should we cast the forward samples?
)

timmoon10 commented 1 year ago

Transformer Engine manages FP8 casting internally (see transformer_engine.pytorch.fp8_autocast) and it can run into problems when combined with other mixed precision tools like torch.autocast or torch.distributed.fsdp.MixedPrecision. For the moment, FSDP mixed precision should only be used to handle the case where param_dtype and reduce_dtype do not match, e.g.:

import torch
from torch.distributed.fsdp import MixedPrecision
import transformer_engine.pytorch as te

precision = torch.bfloat16
model = te.Linear(32, 32, params_dtype=precision)

fsdp_mixed_precision = MixedPrecision(
    param_dtype=precision,
    reduce_dtype=torch.float32,
    buffer_dtype=precision,
)
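
Wrapping that module and running a forward pass could then look roughly like this (assuming torch.distributed is already initialized, e.g. via torchrun; te.Linear places its parameters on the current CUDA device by default):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# FSDP shards the BF16 parameters and reduces gradients in FP32;
# the FP8 casting for compute stays inside TE.
fsdp_model = FSDP(model, mixed_precision=fsdp_mixed_precision)

x = torch.randn(32, 32, device="cuda", dtype=precision)
with te.fp8_autocast(enabled=True):
    out = fsdp_model(x)
out.sum().backward()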

Edit: On second thought, I'm skeptical about this post. I don't remember my main concerns, and I've since been able to get fp8_autocast working with torch.autocast and torch.distributed.fsdp.MixedPrecision. See https://github.com/NVIDIA/TransformerEngine/issues/438#issuecomment-1743693330.

PiotrDabkowski commented 1 year ago

@timmoon10 I think the community would greatly benefit from an end-to-end example of how to train a 7B model on H100 node(s). This is the main purpose of this lib - to train large models quickly. Yet there is not a single example of how to do it - instead we have an MNIST example... Ideally, a simple pure PyTorch+TE file without any "framework" dependencies like compose or fabric.

ksivaman commented 1 year ago

@PiotrDabkowski, please take a look at NeMo, which provides the exact scripts and tools to be able to do this.

jramapuram commented 1 year ago

@ksivaman: AFAIK NeMo uses tensor parallelism, which requires manual changes to layers and architecture, whereas FSDP is pretty generic and easier to use.

naveenkumarmarri commented 1 year ago

@jramapuram were you able to train FSDP with FP8?

denera commented 10 months ago

@yongyanrao #596 recently added support for deferred initialization via device='meta' to improve FSDP support for large models. This feature delays memory allocation on device until the FSDP wrap internally calls reset_parameters() after sharding the model weights, which ensures that TE modules are not duplicated on every device upon initialization.

We also introduced a small FSDP example in the repo as part of this effort to demonstrate how to use FSDP with TE modules. It works out of the box with fp8_autocast(), with the requirement that fp8_autocast(..., fp8_group=...) receives the same process group as FSDP (both default to the world group if none is given).
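
For illustration, a rough sketch of how those two pieces fit together (not the repo example itself; the layer sizes and torchrun-style setup here are placeholders):

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Deferred initialization: parameters live on the 'meta' device, so no GPU memory
# is allocated until FSDP shards the module and calls reset_parameters().
layer = te.TransformerLayer(4096, 16384, 32, device="meta")
model = FSDP(layer, device_id=torch.cuda.current_device())

# fp8_autocast must use the same process group that FSDP shards over;
# both default to the world group when no group is passed.
x = torch.randn(128, 8, 4096, device="cuda")
with te.fp8_autocast(enabled=True, fp8_group=dist.group.WORLD):
    out = model(x)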

denizokt commented 10 months ago

@denera Thank you for the feature! This will be super helpful for us too. Just a question: are there any additional steps we need to take to make primary_weights_in_fp8 work with FSDP? Or should it work out of the box?

Thank you!

denera commented 10 months ago

@denizokt fp8_model_init() is not supported with FSDP at the moment.

NCCL itself does not support 8-bit floats (see this discussion for more detail), and FSDP needs to upcast TE FP8 weights to FP16 for all-gathers and reduce-scatters, which it cannot do until PyTorch natively supports FP8 tensors.

The workaround for this limitation is what already happens in TE when primary_weights_in_fp8 = False. TE modules maintain their own FP8 weight copies, update them from the primary FP16/BF16 weights during the forward pass, and stash the FP8 transposed weights in the PyTorch autograd context to re-use them during the backward pass.
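
Purely as a conceptual illustration of that pattern (this is not TE's implementation, which uses real FP8 kernels, per-tensor scaling factors, and cached transposed weights; torch.float8_e4m3fn here just stands in for the FP8 copy):

import torch
import torch.nn.functional as F

class ToyLinearWithFp8Copy(torch.nn.Module):
    """Toy sketch: the primary weight stays in BF16 (what the optimizer and
    FSDP's NCCL collectives see); a quantized copy is refreshed each forward."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(
            0.02 * torch.randn(out_features, in_features, dtype=torch.bfloat16)
        )

    def forward(self, x):
        # Straight-through sketch: the forward uses FP8-quantized values while
        # gradients flow to the BF16 primary weight (requires a PyTorch build
        # with float8 dtypes).
        w_q = self.weight.detach().to(torch.float8_e4m3fn).to(torch.bfloat16)
        w_used = self.weight + (w_q - self.weight.detach())
        return F.linear(x, w_used)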

denizokt commented 10 months ago

@denera Thank you for the information, this makes a lot of sense.