NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Add multi_tensor_unscale_l2norm_cuda #1727

Closed minitu closed 9 months ago

minitu commented 10 months ago

This PR adds multi_tensor_unscale_l2norm_cuda, which fuses gradient unscaling (used with AMP loss scaling) with computation of the L2 norm of the gradients. To preserve the gradients' original precision (especially FP16), the unscaling is only folded into the norm computation and is not applied to the gradient tensors themselves.
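For reference, a minimal PyTorch sketch of the semantics the fused kernel is meant to provide: compute the norm as if the gradients had been divided by the loss scale, while leaving the gradient tensors untouched. The function name `unscaled_l2_norm_reference` and the `inv_scale` parameter are illustrative only and are not the actual CUDA kernel or its call signature.

```python
import torch

def unscaled_l2_norm_reference(grads, inv_scale):
    """Conceptual (unfused) reference for multi_tensor_unscale_l2norm_cuda:
    return the L2 norm of the gradients as if they had been unscaled by
    inv_scale, without modifying the gradient tensors (so their original
    precision, e.g. FP16, is preserved).

    Note: hypothetical helper for illustration; the real kernel fuses this
    into a single multi-tensor CUDA pass.
    """
    # Accumulate the squared norm in FP32 for numerical stability.
    total_sq = torch.zeros((), dtype=torch.float32, device=grads[0].device)
    for g in grads:
        total_sq += (g.float() * inv_scale).pow(2).sum()
    return total_sq.sqrt()

# Hypothetical usage: FP16 gradients produced under AMP loss scaling.
grads = [torch.randn(1024, device="cuda", dtype=torch.float16) for _ in range(4)]
inv_scale = 1.0 / 65536.0  # inverse of the current loss scale
norm = unscaled_l2_norm_reference(grads, inv_scale)
```

A single fused kernel avoids launching separate unscale and norm kernels and avoids materializing an unscaled copy of the gradients, which is the motivation for combining the two steps.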

nWEIdia commented 9 months ago

We manually verified this PR and it works. Please go ahead and merge this PR.