NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Add multi_tensor_unscale_l2norm_cuda #1727

Closed minitu closed 1 year ago

minitu commented 1 year ago

This PR adds multi_tensor_unscale_l2norm_cuda, a fused kernel that combines gradient unscaling (with AMP) and the L2-norm computation of the gradients. To preserve the original precision of the gradients (especially FP16), unscaling is accounted for only in the norm computation; the gradients themselves are left unmodified.

nWEIdia commented 1 year ago

We manually verified this PR and it works. Please go ahead and merge this PR.