@timmoon10, I like the idea of a drop-in optimizer replacement. Right now, distributed fused Adam sets `_step_supports_amp_scaling`, so neither `scaler.unscale_(optim)` nor `optim.unscale_grads(grad_scaler=scaler)` is called by PyTorch or PyTorch Lightning: the assumption behind `_step_supports_amp_scaling` is that gradient unscaling is done in the optimizer step function, so gradient clipping needs to be delayed to the optimizer step function too. To support the idea you mentioned, `_step_supports_amp_scaling` would need to be removed, but I think that would break other use cases and hurt performance, because gradient unscaling would then be explicit instead of fused with the step kernel.
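For illustration, here is a minimal sketch (not Apex or PyTorch source; `model`, `optimizer`, `scaler`, and `x` are toy placeholders) of how the two flows differ:

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
scaler = torch.cuda.amp.GradScaler()
x = torch.randn(4, 8, device="cuda")

# --- Standard AMP flow (optimizer without _step_supports_amp_scaling) ---
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
with torch.autocast("cuda"):
    loss = model(x).sum()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                        # grads are unscaled in place
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip unscaled grads
scaler.step(optimizer)
scaler.update()

# --- Flow with _step_supports_amp_scaling=True (e.g. distributed fused Adam) ---
# GradScaler.step() skips unscale_() and lets optimizer.step() unscale inside its
# fused kernel, so the per-parameter .grad tensors are never unscaled in place.
# Clipping between backward() and step() would therefore act on *scaled* grads,
# which is why both unscaling and clipping have to happen inside step().
```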
I've implemented my proposed API at https://github.com/timmoon10/apex/commit/0fa8e3a3a231e655f7e1690dfaf9f2f7c53be4e6, although I haven't been able to test it yet.
NeMo GPT avoided these issues because it implemented a custom `GradScaler` that calls `DistributedFusedAdam.unscale_grads` within `GradScaler.unscale_`: https://github.com/NVIDIA/NeMo/blob/c5738263d8b4bedb0957374116d3e90746a51c37/nemo/collections/nlp/parts/nlp_overrides.py#L1235. See https://github.com/NVIDIA/apex/pull/1512 and https://github.com/NVIDIA/NeMo/pull/4900.
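The pattern looks roughly like the following simplified sketch (the class name is illustrative, and the real NeMo `GradScaler` also handles inf/NaN bookkeeping and per-optimizer scaler state):

```python
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

class DistOptGradScaler(torch.cuda.amp.GradScaler):  # illustrative name, not NeMo's
    def unscale_(self, optimizer):
        if isinstance(optimizer, DistributedFusedAdam):
            # Let the optimizer unscale the gradients it owns (possibly already
            # reduce-scattered into its own buckets) instead of touching p.grad.
            optimizer.unscale_grads(grad_scaler=self)
        else:
            super().unscale_(optimizer)
```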
`_step_supports_amp_scaling` is needed because otherwise `GradScaler.unscale_` would attempt to access the parameters' `.grad` tensors, which have probably already been reduce-scattered and set to `None`. The only way I can see to avoid this is to disable overlapping grad reduce-scatters with backward compute.
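To make the failure mode concrete, a vanilla unscale pass looks roughly like this (paraphrased, not the actual PyTorch implementation):

```python
import torch

def naive_unscale(optimizer: torch.optim.Optimizer, scale: float) -> None:
    """Simplified view of what a stock GradScaler.unscale_() does per optimizer
    (the real code also tracks infs/NaNs and gradient dtypes)."""
    inv_scale = 1.0 / scale
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is None:
                # With overlapped reduce-scatter the grads have already been
                # bucketed, reduced, and freed, so this skips every parameter...
                continue
            p.grad.mul_(inv_scale)  # ...and the optimizer's sharded grads stay scaled
```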
But there's also the issue that when `_step_supports_amp_scaling` is set, `GradScaler.unscale_` will never be called from PyTorch or PyTorch Lightning. I saw you tried to unscale here: nlp_overrides.py#L1202, but that function was never called; I confirmed this with Stable Diffusion and LLM.
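For reference, the relevant dispatch is roughly the following (a paraphrase from memory, not the real `GradScaler.step` source; how the scale and inf state are handed to the optimizer varies by PyTorch version):

```python
import torch

def scaler_step_sketch(scaler: torch.cuda.amp.GradScaler,
                       optimizer: torch.optim.Optimizer):
    """Rough paraphrase of the branch in GradScaler.step() that matters here."""
    if getattr(optimizer, "_step_supports_amp_scaling", False):
        # unscale_() is bypassed entirely: the optimizer is trusted to unscale
        # and to check for infs/NaNs inside its own step().
        return optimizer.step()
    # Otherwise: unscale_() runs (unless the user already called it), gradients
    # are checked for infs, and the step is skipped on overflow.
    scaler.unscale_(optimizer)
    return optimizer.step()
```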
Overlapping the reduce-scatter with backprop is quite important for performance, so I think keeping it is necessary.
I see, we need `_step_supports_amp_scaling=False` specifically when using `nemo.collections.nlp.parts.nlp_overrides.GradScaler`. However, `_step_supports_amp_scaling=True` is needed for correct behavior with `torch.amp.GradScaler`. I think the cleanest solution is to set `_step_supports_amp_scaling=False` in NeMo's distopt wrapper. That helps keep the NeMo-specific logic separate from the general PyTorch logic in Apex. Reverting the changes to the grad clipping logic (e.g. with https://github.com/timmoon10/apex/commit/0fa8e3a3a231e655f7e1690dfaf9f2f7c53be4e6) is needed to preserve correct behavior with `torch.amp.GradScaler`.
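Concretely, that could look something like this sketch (the wrapper class name is illustrative, not NeMo's actual class):

```python
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

class NeMoDistributedFusedAdam(DistributedFusedAdam):  # illustrative wrapper name
    """Distopt wrapper that opts out of the fused unscale-in-step path so that
    NeMo's custom GradScaler can drive unscaling via unscale_grads()."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Plain Apex keeps _step_supports_amp_scaling=True so torch.amp.GradScaler
        # still works; only the NeMo wrapper flips it off.
        self._step_supports_amp_scaling = False
```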
This PR enhances distributed fused Adam by:
@timmoon10 @crcrpar Please help review, thanks.