NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Enhance Distributed Fused Adam #1794

Closed. alpha0422 closed this PR 7 months ago.

alpha0422 commented 7 months ago

This PR enhances distributed fused Adam by:

  1. Supporting the NHWC layout (required by some Conv-related models, e.g. Diffusion models; see the usage sketch below);
  2. Fixing the gradient clipping bug;
  3. Supporting CUDA graphs.
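
For context, a minimal usage sketch of the NHWC path, assuming a multi-GPU job launched with torchrun; the model, hyperparameters, and shapes below are illustrative and not part of this PR:

```python
import os
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

torch.distributed.init_process_group(backend="nccl")
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)

# Keep the Conv model and its inputs in channels_last (NHWC) memory format.
model = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1).to(device)
model = model.to(memory_format=torch.channels_last)
optimizer = DistributedFusedAdam(model.parameters(), lr=1e-4)

x = torch.randn(8, 3, 224, 224, device=device).to(memory_format=torch.channels_last)
loss = model(x).sum()
loss.backward()
optimizer.step()       # reduce-scatters gradients and applies the sharded Adam update
optimizer.zero_grad()
```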

@timmoon10 @crcrpar Please help review, thanks.

alpha0422 commented 7 months ago

@timmoon10, I like the idea of a drop-in optimizer replacement. Right now, distributed fused Adam sets _step_supports_amp_scaling, so neither scaler.unscale_(optim) nor optim.unscale_grads(grad_scaler=scaler) is called from PyTorch or PyTorch Lightning: the assumption behind _step_supports_amp_scaling is that gradient unscaling happens inside the optimizer step function, so gradient clipping has to be delayed into the step function as well.
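
For reference, a simplified sketch of the two code paths involved (this is the observable behavior, not the exact PyTorch/Lightning internals):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

def training_step(model, optimizer, loss):
    scaler.scale(loss).backward()

    if not getattr(optimizer, "_step_supports_amp_scaling", False):
        # Ordinary optimizer: grads are unscaled explicitly here, so gradient
        # clipping can run on the unscaled .grad tensors before the step.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # With _step_supports_amp_scaling set (as distributed fused Adam does),
    # scaler.step() skips its own unscaling/inf checks and delegates to
    # optimizer.step(), so unscaling and clipping both have to happen there.

    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
```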

To support the idea you mentioned, _step_supports_amp_scaling would need to be removed, but I think that would break other use cases and hurt performance, because gradient unscaling would then be explicit rather than fused with the step kernel.

timmoon10 commented 7 months ago

I've implemented my proposed API at https://github.com/timmoon10/apex/commit/0fa8e3a3a231e655f7e1690dfaf9f2f7c53be4e6, although I haven't been able to test it yet.

NeMo GPT avoided these issues because it implemented a custom GradScaler that called DistributedFusedAdam.unscale_grads within GradScaler.unscale_: https://github.com/NVIDIA/NeMo/blob/c5738263d8b4bedb0957374116d3e90746a51c37/nemo/collections/nlp/parts/nlp_overrides.py#L1235. See https://github.com/NVIDIA/apex/pull/1512 and https://github.com/NVIDIA/NeMo/pull/4900.
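
Roughly, that workaround looks like the following sketch (not the actual NeMo code, and it glosses over GradScaler's per-optimizer state bookkeeping; the class name is hypothetical):

```python
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class DistOptGradScaler(torch.cuda.amp.GradScaler):  # hypothetical name
    def unscale_(self, optimizer):
        if isinstance(optimizer, DistributedFusedAdam):
            # The reduce-scattered grads are no longer visible as param.grad,
            # so let the optimizer unscale its internal (sharded) grad buffers.
            optimizer.unscale_grads(grad_scaler=self)
        else:
            super().unscale_(optimizer)
```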

_step_supports_amp_scaling is needed because otherwise GradScaler.unscale_ would attempt to access the parameters' .grads, which have probably already been reduce-scattered and set to None. The only way I can see to avoid this is to disable overlapping grad reduce-scatters with backward compute.

alpha0422 commented 7 months ago

But there's also the issue that when _step_supports_amp_scaling is set, GradScaler.unscale_ is never called from PyTorch or PyTorch Lightning. I saw you tried to unscale here: nlp_overrides.py#L1202, but that function was never called; I confirmed this with both Stable Diffusion and LLM models.

Overlapping the reduce-scatter with backprop is quite important for performance, so I think keeping it is necessary.

timmoon10 commented 7 months ago

I see, so we need _step_supports_amp_scaling=False specifically when using nemo.collections.nlp.parts.nlp_overrides.GradScaler. However, _step_supports_amp_scaling=True is needed for correct behavior with torch.amp.GradScaler. I think the cleanest solution is to set _step_supports_amp_scaling=False in NeMo's distopt wrapper; that keeps the NeMo-specific logic separate from the general PyTorch logic in Apex. Reverting the changes to the grad clipping logic (e.g. with https://github.com/timmoon10/apex/commit/0fa8e3a3a231e655f7e1690dfaf9f2f7c53be4e6) is also needed to preserve correct behavior with torch.amp.GradScaler.
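
A minimal sketch of what that could look like on the NeMo side (the wrapper class name is hypothetical, and NeMo's actual distopt wrapper carries more logic):

```python
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class NeMoDistributedFusedAdam(DistributedFusedAdam):  # hypothetical wrapper
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NeMo's custom GradScaler.unscale_ calls unscale_grads itself, so the
        # optimizer should not advertise step-time unscaling to the stock
        # torch.amp.GradScaler code path.
        self._step_supports_amp_scaling = False
```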