NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
BSD 3-Clause "New" or "Revised" License

fix groupnorm int32 index overflow #1845

Open tlogn opened 2 months ago

tlogn commented 2 months ago

This PR fixes an int32 overflow in the groupnorm index calculation that occurs when hwc is large, because hwc is stored as an int. The problem can be reproduced with the code below. @crcrpar please review, thanks!

import torch
from apex.contrib.group_norm import GroupNorm as ApexGroupNorm

# 32 groups over 128 channels, with fused SiLU activation
layer = ApexGroupNorm(32, 128, dtype=torch.bfloat16, device='cuda', act='silu')

# Large channels-last input: H * W * C exceeds the int32 range
x = torch.randn(1, 128, 16128, 1200, dtype=torch.bfloat16, device='cuda').to(memory_format=torch.channels_last)

o = layer(x)
print(o[0][0][0][0])
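
For reference, the arithmetic behind the overflow can be checked in plain Python. This is only a sketch of the numbers involved, not the kernel code: `wrap_int32` is a hypothetical helper that mimics two's-complement wraparound, and it assumes hwc is the product H * W * C of the reproducer's input shape.

```python
INT32_MAX = 2**31 - 1  # largest value a signed 32-bit int can hold


def wrap_int32(value: int) -> int:
    """Simulate two's-complement wraparound of a signed 32-bit integer."""
    return (value + 2**31) % 2**32 - 2**31


# Shape from the reproducer: (N, C, H, W) = (1, 128, 16128, 1200)
c, h, w = 128, 16128, 1200

hwc = h * w * c              # exact value, Python ints do not overflow
print(hwc)                   # 2477260800
print(hwc > INT32_MAX)       # True
print(wrap_int32(hwc))       # -1817706496, what a 32-bit int would hold
```

Under that assumption, an offset of this size no longer fits in 32 bits, so index arithmetic done in int wraps to a negative value.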
tlogn commented 1 month ago

@crcrpar thanks for your advice. I've added a test case and tested it on an H100.

With the main branch, the new test fails:

Mismatched elements: 1866930499 / 2477260800 (75.4%)
Greatest absolute difference: nan at index (0, 0, 2146, 1184) (up to 0.04 allowed)
Greatest relative difference: nan at index (0, 0, 2146, 1184) (up to 0 allowed)

----------------------------------------------------------------------
Ran 16 tests in 8.897s

FAILED (failures=1)

With the fixed branch, all tests pass once atol is raised to 7e-2, which accounts for the potentially larger accumulated error in the reduction:

................
----------------------------------------------------------------------
Ran 16 tests in 8.079s

OK
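
To make the tolerance change concrete, here is a minimal sketch of an elementwise closeness check of the form |actual - expected| <= atol + rtol * |expected|, with rtol = 0 as in the log above. The helper `is_close` and the sample values are hypothetical and are not taken from the test suite.

```python
import math


def is_close(actual: float, expected: float, atol: float, rtol: float = 0.0) -> bool:
    """Return True if actual is within atol + rtol * |expected| of expected."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Hypothetical values with an absolute error of about 0.05
expected, actual = 1.0, 1.05

print(is_close(actual, expected, atol=4e-2))    # False: exceeds the old tolerance
print(is_close(actual, expected, atol=7e-2))    # True: within the adjusted tolerance
print(is_close(math.nan, expected, atol=7e-2))  # False: NaN never compares as close
```

The last line shows why raising atol alone would not hide the failure on the main branch: a NaN output fails the comparison at any tolerance.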
tlogn commented 1 week ago

@crcrpar Hi there, are there any remaining problems?