NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

deprecate uses of torch.cuda.amp #1813

Closed: Fuzzkatt closed this 3 months ago

Fuzzkatt commented 3 months ago

Deprecates usage of torch.cuda.amp across the apex codebase.
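
The migration is largely mechanical: each `torch.cuda.amp` entry point gains an explicit `"cuda"` device argument under `torch.amp`. The regex-based rewrite below is a minimal sketch of that mapping, not how this PR was produced; `migrate_amp_calls` and its rule table are hypothetical. It assumes the device-generic signatures `torch.amp.autocast(device_type, ...)`, `torch.amp.GradScaler(device, ...)`, and the keyword-only `device_type` on `torch.amp.custom_fwd`/`custom_bwd`. It only touches fully qualified `torch.cuda.amp.` call sites (not `from torch.cuda.amp import ...`), and it assumes `autocast` arguments are passed by keyword, because the positional order differs between the two APIs.

```python
import re

# Hypothetical rewrite rules: (pattern, replacement), applied in order.
# Zero-argument forms come before the forms that splice "cuda" in front
# of existing arguments, so empty calls do not end up with a trailing comma.
_RULES = [
    # autocast: device_type becomes the first positional argument.
    (r"torch\.cuda\.amp\.autocast\(\s*\)", 'torch.amp.autocast("cuda")'),
    (r"torch\.cuda\.amp\.autocast\(", 'torch.amp.autocast("cuda", '),
    # GradScaler: the device string is the first parameter.
    (r"torch\.cuda\.amp\.GradScaler\(\s*\)", 'torch.amp.GradScaler("cuda")'),
    (r"torch\.cuda\.amp\.GradScaler\(", 'torch.amp.GradScaler("cuda", '),
    # custom_fwd: device_type is keyword-only in the torch.amp version.
    (r"torch\.cuda\.amp\.custom_fwd\(\s*\)", 'torch.amp.custom_fwd(device_type="cuda")'),
    (r"torch\.cuda\.amp\.custom_fwd\(", 'torch.amp.custom_fwd(device_type="cuda", '),
    (r"torch\.cuda\.amp\.custom_fwd\b", 'torch.amp.custom_fwd(device_type="cuda")'),
    # custom_bwd: the old decorator was only used bare.
    (r"torch\.cuda\.amp\.custom_bwd\b", 'torch.amp.custom_bwd(device_type="cuda")'),
]


def migrate_amp_calls(source: str) -> str:
    """Rewrite fully qualified torch.cuda.amp call sites in `source` to torch.amp."""
    for pattern, replacement in _RULES:
        source = re.sub(pattern, replacement, source)
    return source


if __name__ == "__main__":
    before = (
        "scaler = torch.cuda.amp.GradScaler()\n"
        "with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n"
        "    loss = model(x)\n"
    )
    print(migrate_amp_calls(before))
    # scaler = torch.amp.GradScaler("cuda")
    # with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    #     loss = model(x)
```

A text rewrite like this cannot see aliased imports or positional `autocast(False)` calls (old first parameter `enabled`, new second parameter `dtype`), so its output still needs review and a test run.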

Fuzzkatt commented 3 months ago

Just tested all of the changed test files on the 24.06 NGC container. All pass except test_conv_bias_relu.py: torch.amp.custom_fwd and torch.amp.custom_bwd are recent additions (https://github.com/pytorch/pytorch/pull/126531) that missed the PyTorch submodule cut for 24.06, although they are merged upstream.
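
Code that must run on both older builds (such as the 24.06 container) and newer ones could choose the decorators at import time. The helper below is a hypothetical compatibility shim, not something this PR adds. It assumes the newer `torch.amp.custom_fwd(fwd=None, *, device_type, cast_inputs=None)` and `torch.amp.custom_bwd(bwd=None, *, device_type)` signatures, and falls back to the CUDA-only `torch.cuda.amp` versions when the `torch.amp` ones are missing. The `torch_module` parameter exists only so the helper can be exercised with a stub.

```python
import functools


def amp_custom_decorators(torch_module=None, device_type="cuda"):
    """Return (custom_fwd, custom_bwd) that work on old and new PyTorch builds.

    Hypothetical helper. Newer builds expose device-generic
    torch.amp.custom_fwd/custom_bwd with a keyword-only ``device_type``;
    older builds only have the CUDA-specific torch.cuda.amp versions.
    """
    if torch_module is None:
        import torch as torch_module  # deferred import; callers may pass a stub
    amp = torch_module.amp
    if hasattr(amp, "custom_fwd") and hasattr(amp, "custom_bwd"):
        # Bind device_type once so call sites keep the old spelling:
        # @custom_fwd, @custom_fwd(cast_inputs=...), and @custom_bwd.
        return (
            functools.partial(amp.custom_fwd, device_type=device_type),
            functools.partial(amp.custom_bwd, device_type=device_type),
        )
    # Fallback for builds that predate the torch.amp decorators.
    legacy = torch_module.cuda.amp
    return legacy.custom_fwd, legacy.custom_bwd
```

A module would then call `custom_fwd, custom_bwd = amp_custom_decorators()` once and decorate its `torch.autograd.Function` methods as before. Binding `device_type` with `functools.partial` keeps both the bare and the parameterized decorator forms working, because the wrapped function's first positional argument is still optional.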

Fuzzkatt commented 3 months ago

Retested on the latest nightly container; test_conv_bias_relu.py now passes as well.