ROCm / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
BSD 3-Clause "New" or "Revised" License

Unskip some unit tests related to issue #82 #98

Closed hubertlu-tw closed 1 year ago

hubertlu-tw commented 1 year ago

Because the regression has been resolved in PyTorch 1.12.1 and later, the unit tests that were skipped because of https://github.com/ROCmSoftwarePlatform/apex/issues/82 can be unskipped:

test_adam_option (test_fused_optimizer.TestFusedAdam): due to https://github.com/pytorch/pytorch/issues/80809#issuecomment-1175211598
test_multi_device (test_fused_optimizer.TestFusedAdam): due to https://github.com/pytorch/pytorch/issues/80809#issuecomment-1175211598
test_float (test_fused_optimizer.TestFusedAdam): due to https://github.com/pytorch/pytorch/issues/80809#issuecomment-1175211598
test_bfloat16 (test_fused_optimizer.TestFusedAdam): due to https://github.com/pytorch/pytorch/issues/80809#issuecomment-1175211598
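For illustration only, here is a minimal sketch of how a skip could be tied to the PyTorch version instead of being applied unconditionally. The helper names (`parse_version`, `regression_fixed`, `skip_if_regressed`) are hypothetical and are not taken from the apex test suite. The 1.12.1 threshold is the one stated above.

```python
import re
import unittest

# First PyTorch release in which the regression is reported as resolved.
FIXED_IN = (1, 12, 1)


def parse_version(version: str) -> tuple:
    """Extract (major, minor, patch) from strings such as '1.12.1+rocm5.2'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # A missing patch component (e.g. '2.0') is treated as 0.
    return tuple(int(part or 0) for part in match.groups())


def regression_fixed(version: str) -> bool:
    """True if the given PyTorch version is 1.12.1 or later."""
    return parse_version(version) >= FIXED_IN


def skip_if_regressed(version: str):
    """Hypothetical decorator: skip a test only on affected PyTorch versions."""
    return unittest.skipIf(
        not regression_fixed(version),
        "optimizer regression, see pytorch/pytorch#80809",
    )
```

In a test module this might be applied as `@skip_if_regressed(torch.__version__)` above each of the four `TestFusedAdam` tests, so that they run on 1.12.1 and later and remain skipped on older builds.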

In addition, test_state_dict (test_checkpointing.TestCheckpointing) failed because of the following error:

Traceback (most recent call last):
  File "/apex/tests/L0/run_amp/test_checkpointing.py", line 252, in test_state_dict
    optimizer.step()
  File "/opt/conda/lib/python3.7/site-packages/torch/optim/optimizer.py", line 109, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/optim/adam.py", line 171, in step
    capturable=group['capturable'])
  File "/opt/conda/lib/python3.7/site-packages/torch/optim/adam.py", line 226, in adam
    capturable=capturable)
  File "/opt/conda/lib/python3.7/site-packages/torch/optim/adam.py", line 255, in _single_tensor_adam
    assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."
AssertionError: If capturable=False, state_steps should not be CUDA tensors.

We can resolve this error by ensuring that torch.optim.Adam uses capturable=True when it runs in a CUDA graph.
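One way the capturable flag could be turned on for an optimizer that already exists is sketched below. This is an assumption about how the suggestion might be applied, not the change made in this PR. The helper `enable_capturable` is hypothetical, and a plain stand-in object is used in place of a real optimizer so that the snippet runs without a GPU. It relies only on the fact, visible in the traceback, that Adam reads the flag from `group['capturable']`.

```python
from types import SimpleNamespace


def enable_capturable(optimizer) -> int:
    """Set capturable=True on every param group that exposes the flag.

    Returns the number of groups that were changed. Groups without a
    'capturable' key (older PyTorch versions) are left untouched.
    """
    changed = 0
    for group in optimizer.param_groups:
        if "capturable" in group and not group["capturable"]:
            group["capturable"] = True
            changed += 1
    return changed


# Stand-in for an optimizer (assumption: only the .param_groups list matters here).
fake_optimizer = SimpleNamespace(
    param_groups=[
        {"lr": 1e-3, "capturable": False},
        {"lr": 1e-4, "capturable": False},
    ]
)
changed = enable_capturable(fake_optimizer)
print(changed)
```

With a real optimizer the helper would be called after `optimizer.load_state_dict(...)` and before `optimizer.step()`. Note that with capturable=True, Adam expects the parameters and the step state to be CUDA tensors, so this only makes sense when the model lives on the GPU.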

hubertlu-tw commented 1 year ago

jenkins: retest this please