Problem Description
After the recent update to PyTorch 2.3.0, petals fails with an import error when importing from the torch.cuda.amp.grad_scaler module. The specific error message is:
ImportError: cannot import name '_refresh_per_optimizer_state' from 'torch.cuda.amp.grad_scaler' site-packages/torch/cuda/amp/grad_scaler.py)
The issue is caused by changes in PyTorch 2.3.0 that are incompatible with the current codebase.
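One possible direction for making the codebase work on both sides of the version boundary is a fallback import. The sketch below uses a hypothetical helper, import_from_first (not part of petals or PyTorch), that returns a name from the first module that defines it. It assumes, without this being confirmed in this issue, that PyTorch 2.3.0 moved the GradScaler internals, including _refresh_per_optimizer_state, to torch.amp.grad_scaler, while earlier releases keep them in torch.cuda.amp.grad_scaler.

```python
import importlib
from typing import Any, Iterable, List


def import_from_first(name: str, module_paths: Iterable[str]) -> Any:
    """Return attribute `name` from the first module in `module_paths` that defines it.

    Hypothetical helper: tries each module in order, skipping modules that
    cannot be imported or that do not define `name`.
    """
    failures: List[str] = []
    for path in module_paths:
        try:
            module = importlib.import_module(path)
        except ImportError as exc:  # module missing, e.g. torch.amp.grad_scaler on older torch
            failures.append(f"{path}: {exc}")
            continue
        if hasattr(module, name):
            return getattr(module, name)
        failures.append(f"{path}: no attribute {name!r}")
    raise ImportError(f"cannot import {name!r}; tried " + "; ".join(failures))


if __name__ == "__main__":
    # Assumed locations: torch.amp.grad_scaler for PyTorch >= 2.3,
    # torch.cuda.amp.grad_scaler for earlier releases.
    try:
        refresh = import_from_first(
            "_refresh_per_optimizer_state",
            ["torch.amp.grad_scaler", "torch.cuda.amp.grad_scaler"],
        )
        print("resolved from:", refresh.__module__)
    except ImportError as exc:
        print("not resolved:", exc)
```

Because _refresh_per_optimizer_state is a private (underscore-prefixed) name, PyTorch makes no compatibility promise for it, so a fallback like this may need updating again in later releases.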
Quick Workaround
Resolve this issue by pinning the PyTorch version in setup.cfg, changing 'torch>=1.12' to 'torch==2.2.2', the last version known to work without this problem. This keeps installs stable while I investigate the changes in PyTorch 2.3.0 and make the codebase compatible with 2.3.0 and later.
Steps to Reproduce
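As a reproduction aid, the sketch below reports the installed torch version and whether _refresh_per_optimizer_state can still be found in torch.cuda.amp.grad_scaler. The helper names (torch_release, old_import_works) are hypothetical, and the 2.3 threshold only mirrors the versions reported in this issue (2.2.2 works, 2.3.0 fails).

```python
import importlib
from importlib import metadata
from typing import Optional, Tuple


def torch_release(version: str) -> Tuple[int, ...]:
    """Parse the numeric release part of a version, e.g. '2.3.0+cu121' -> (2, 3, 0)."""
    release = []
    for piece in version.split("+")[0].split("."):  # drop local suffix like '+cu121'
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        release.append(int(digits))
        if len(digits) != len(piece):  # pre-release suffix such as '0rc1': stop here
            break
    return tuple(release)


def old_import_works() -> Optional[bool]:
    """Whether the name from the error still exists in the pre-2.3 module.

    Returns None if the module cannot be imported at all (e.g. torch is not installed).
    """
    try:
        module = importlib.import_module("torch.cuda.amp.grad_scaler")
    except ImportError:
        return None
    return hasattr(module, "_refresh_per_optimizer_state")


if __name__ == "__main__":
    try:
        installed = metadata.version("torch")
    except metadata.PackageNotFoundError:
        print("torch is not installed")
    else:
        print("torch version:", installed)
        print("at or above 2.3 (reported as failing):", torch_release(installed) >= (2, 3))
        print("old import path works:", old_import_works())
```

In an environment matching this report, one would expect the version check to print True and the import check to print False, and the reverse with torch==2.2.2.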
Information
Action
Submitting a pull request that changes install_requires in our setup.cfg as described above, pending team feedback on this issue.