Addresses a hardware-limitation edge case (observed, e.g., on AMD ROCm GPUs) that occurs when backpropagating gradients through a loss tensor containing more than `2e8` elements.
N.B.: on NVIDIA GPUs such as the A100, this does not appear to be an issue, but on AMD GPUs such as the MI250X line, this `2e8`-element threshold is required.
I've opened a separate ROCm-related issue with PyTorch regarding this edge case.
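For illustration only, here is a minimal sketch of one way to keep each reduction over a flattened loss tensor at or below the `2e8`-element threshold: split the element range into contiguous chunks. The helper name `chunk_ranges` and the constant `MAX_LOSS_ELEMENTS` are hypothetical. This is not necessarily how this PR handles the edge case, and whether chunking alone avoids the ROCm backward failure is an assumption.

```python
from typing import List, Tuple

# Threshold from the description: loss tensors with more than 2e8 elements
# hit the backward-pass failure on AMD GPUs (e.g., MI250X).
MAX_LOSS_ELEMENTS = int(2e8)


def chunk_ranges(numel: int, max_elems: int = MAX_LOSS_ELEMENTS) -> List[Tuple[int, int]]:
    """Split the index range [0, numel) into contiguous (start, end) ranges.

    Hypothetical helper: every range covers at most `max_elems` elements,
    so a flattened tensor can be processed one slice at a time.
    """
    if numel < 0 or max_elems <= 0:
        raise ValueError("numel must be >= 0 and max_elems must be > 0")
    # Integer ceiling division avoids float rounding for very large counts.
    num_chunks = -(-numel // max_elems)
    return [
        (i * max_elems, min((i + 1) * max_elems, numel))
        for i in range(num_chunks)
    ]
```

With PyTorch, each range could be applied to `loss.reshape(-1)[start:end]` and reduced separately, for example with `.sum()`.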