NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
BSD 3-Clause "New" or "Revised" License

Use torch.testing.assert_close instead of get_max_diff in test_lamb.py #1806

Closed Fuzzkatt closed 3 months ago

Fuzzkatt commented 3 months ago

Previously, several tests in test_lamb.py used the custom get_max_diff comparison function, which produced numerical mismatches in CI. This PR switches them to the standard torch.testing.assert_close, after which they pass CI cleanly.
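
For reviewers, here is a minimal sketch of how the two comparison styles can disagree. It is an illustration under assumptions, not the code from test_lamb.py: `max_abs_rel_diff` is a hypothetical stand-in for a max-difference helper, the tensors and tolerances are made up, and the PR does not say which elements actually caused the CI mismatches. `torch.testing.assert_close` is the real PyTorch function; it accepts a pair of tensors when `|actual - expected| <= atol + rtol * |expected|` holds elementwise.

```python
import torch


def max_abs_rel_diff(ref_params, tst_params):
    """Hypothetical stand-in for a max-difference helper (not the apex code).

    Returns the largest absolute and largest relative elementwise difference
    across all parameter pairs.
    """
    max_abs = 0.0
    max_rel = 0.0
    for p_ref, p_tst in zip(ref_params, tst_params):
        diff = (p_ref - p_tst).abs()
        max_abs = max(max_abs, diff.max().item())
        # Pure relative error: becomes large when the reference is near zero.
        max_rel = max(max_rel, (diff / p_ref.abs()).max().item())
    return max_abs, max_rel


# Illustrative values: the second element is tiny, so a negligible absolute
# error turns into a large relative error.
ref = [torch.tensor([1.0, 1e-6, 100.0])]
tst = [torch.tensor([1.0, 2e-6, 100.0])]

max_abs, max_rel = max_abs_rel_diff(ref, tst)
print(f"max abs diff: {max_abs:.2e}, max rel diff: {max_rel:.2e}")

# assert_close combines both tolerances, so the near-zero element is covered
# by atol and the comparison passes. The tolerances here are assumptions.
torch.testing.assert_close(tst[0], ref[0], rtol=1e-5, atol=1e-5)
```

A check that thresholds the relative difference alone would reject these tensors, while the combined tolerance accepts them and still raises `AssertionError` on a genuine mismatch.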

Fuzzkatt commented 3 months ago

cc @eqy @crcrpar for review