Previously, several tests in `test_lamb.py` used the custom `get_max_diff` comparison function, which caused numerical-mismatch failures to appear in CI. This PR switches them to the standard `torch.testing.assert_close`; with this change, they pass cleanly in CI.
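Below is a rough sketch of the comparison pattern. The helper name `assert_params_close` and the parameter lists are hypothetical and are not taken from `test_lamb.py`. `torch.testing.assert_close` checks `|actual - expected| <= atol + rtol * |expected|` elementwise. When `rtol` and `atol` are both left as `None`, it picks default tolerances based on the dtype. It also checks that shape, dtype and device match.

```python
import torch


def assert_params_close(actual_params, expected_params, rtol=None, atol=None):
    """Hypothetical helper: compare two lists of tensors (e.g. optimizer params).

    If rtol and atol are both None, torch.testing.assert_close uses
    dtype-dependent default tolerances; otherwise both must be given.
    """
    assert len(actual_params) == len(expected_params), "parameter count mismatch"
    for actual, expected in zip(actual_params, expected_params):
        # Raises AssertionError on failure, reporting the number of mismatched
        # elements and the greatest absolute/relative difference.
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
```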