yongyanrao opened this issue 1 year ago
In the current scheme, Transformer Engine modules use standard parameter tensors in standard dtypes (FP32/BF16/FP16). Optimizers typically require higher precision than FP8 to get good learning behavior. I don't see anything that would disrupt FSDP and I've been able to get TE and FSDP working together in some quick experiments.
Hi Tim, thank you for the response. Have you tried FP8 + FSDP yet?
Hello @timmoon10, which FSDP did you use? Fairscale's?
Hi, I was referring to PyTorch's FullyShardedDataParallel.
Thank you!
Yep, I used PyTorch FSDP with TE FP8. Be advised I haven't done full convergence experiments, just some basic sanity checking.
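Roughly, the check was just wrapping a TE module in FSDP and running a forward/backward pass under fp8_autocast. A minimal sketch (reconstructed rather than the exact script, and assuming a torchrun launch with one GPU per rank and FP8-capable hardware):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import transformer_engine.pytorch as te

# Assumes launch via torchrun, which sets the env vars init_process_group needs.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Parameters live in a standard dtype (BF16 here); FP8 is only used for compute.
model = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
model = FSDP(model)

x = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda")
with te.fp8_autocast(enabled=True):  # default delayed-scaling recipe
    out = model(x)
out.sum().backward()
```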
What is the recommendation for MixedPrecision when using FP8 with FSDP, @timmoon10?
from torch.distributed.fsdp import MixedPrecision

precision = ?
MixedPrecision(
    param_dtype=precision,  # should we be forcing a dtype here?
    reduce_dtype=torch.float32,  # reduce in FP32 as with AMP?
    buffer_dtype=precision,  # or should buffers stay in FP32?
    cast_forward_inputs=is_amp_enabled,  # should we cast the forward samples?
)
Transformer Engine manages FP8 casting internally (see transformer_engine.pytorch.fp8_autocast) and it can run into problems when combined with other mixed-precision tools like torch.autocast or torch.distributed.fsdp.MixedPrecision. For the moment, FSDP mixed precision should only be used to handle the case where param_dtype and reduce_dtype do not match, e.g.:
import torch
from torch.distributed.fsdp import MixedPrecision
import transformer_engine.pytorch as te

precision = torch.bfloat16
model = te.Linear(32, 32, params_dtype=precision)
fsdp_mixed_precision = MixedPrecision(
    param_dtype=precision,
    reduce_dtype=torch.float32,
    buffer_dtype=precision,
)
Edit: On second thought, I'm skeptical about this post. I don't remember my main concerns, and I've since been able to get fp8_autocast working with torch.autocast and torch.distributed.fsdp.MixedPrecision. See https://github.com/NVIDIA/TransformerEngine/issues/438#issuecomment-1743693330.
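For reference, the combination looks roughly like the sketch below. This is only an illustration of the setup I mean, not a vetted recipe; it assumes the process group has already been initialized (e.g. via torchrun) and that FP8-capable hardware is available:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
import transformer_engine.pytorch as te

# Assumes dist.init_process_group() has already been called (e.g. via torchrun).
model = FSDP(
    te.Linear(1024, 1024).cuda(),
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    ),
)

x = torch.randn(32, 1024, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    with te.fp8_autocast(enabled=True):
        out = model(x)
out.sum().backward()
```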
@timmoon10 I think the community would greatly benefit from an e2e example of how to train a 7B model on H100 node(s). That is the main purpose of this lib: to train large models quickly. Yet there is not a single example of how to do it; instead we have an MNIST example... Ideally, a simple pure PyTorch+TE file without any "framework" dependencies like compose or fabric.
@PiotrDabkowski, please take a look at NeMo, which provides the exact scripts and tools to be able to do this.
@ksivaman: AFAIK NeMo uses tensor parallelism, which requires manual changes to layers and architecture, whereas FSDP is fairly generic and easier to use.
@jramapuram were you able to train FSDP with FP8?
@yongyanrao #596 recently added support for deferred initialization via device='meta' to improve FSDP support for large models. This feature delays memory allocation on device until the FSDP wrap internally calls reset_parameters() after sharding the model weights, which ensures that TE modules are not duplicated on every device upon initialization.
We also introduced a small FSDP example in the repo as part of this effort to demonstrate how to use FSDP with TE modules. It works out of the box with fp8_autocast(), with the requirement that fp8_autocast(..., fp8_group=...) gets the same process group as FSDP (both will default to the world group if none is given).
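To make that concrete, here is a minimal sketch of the flow. It illustrates the two points above (device='meta' construction and matching fp8_group) rather than reproducing the repo example, and it assumes a torchrun launch with one GPU per rank and FP8-capable hardware:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Deferred initialization: no device memory is allocated for the weights yet.
model = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    params_dtype=torch.bfloat16,
    device="meta",
)

# The FSDP wrap shards the weights and calls reset_parameters() internally,
# materializing each shard directly on its own GPU.
model = FSDP(model, device_id=torch.cuda.current_device())

# Give fp8_autocast the same process group as FSDP (here the world group,
# which both default to anyway when none is given).
x = torch.randn(128, 4, 1024, dtype=torch.bfloat16, device="cuda")
with te.fp8_autocast(enabled=True, fp8_group=dist.group.WORLD):
    out = model(x)
out.sum().backward()
```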
@denera Thank you for the feature! This will be super helpful for us too. Just a question: are there any additional steps needed to make primary_weights_in_fp8 work with FSDP, or should it work out of the box?
Thank you!
@denizokt fp8_model_init() is not supported with FSDP at the moment.
NCCL itself does not support 8-bit floats (see this discussion for more detail), and FSDP needs to upcast TE FP8 weights to FP16 for all-gathers and reduce-scatters, which it cannot do until PyTorch starts natively supporting FP8 tensors.
The workaround for this limitation is what already happens in TE when primary_weights_in_fp8 = False. TE modules maintain their own FP8 weight copies, update them from the primary FP16/BF16 weights during the forward pass, and stash the transposed FP8 weights into the PyTorch autograd context to reuse during the backward pass.
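In other words, with FSDP you construct TE modules normally and let TE manage the FP8 copies. A short illustrative sketch (not taken from the repo or docs):

```python
import torch
import transformer_engine.pytorch as te

# Works with FSDP today: primary weights stay in BF16, and TE maintains the
# FP8 copies internally, refreshing them each forward pass under fp8_autocast().
fsdp_friendly = te.Linear(1024, 1024, params_dtype=torch.bfloat16)

# Not currently supported with FSDP: the parameters themselves are FP8 tensors,
# which FSDP/NCCL cannot all-gather or reduce-scatter natively yet.
with te.fp8_model_init(enabled=True):
    fp8_primary_weights = te.Linear(1024, 1024)
```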
@denera Thank you for the information, this makes a lot of sense.
I was wondering if PyTorch's FullyShardedDataParallel (FSDP) is supported by TransformerEngine, and especially whether FP8 can work with FSDP. Thank you in advance.