bigscience-workshop / petals

🌸 Run LLMs at home, BitTorrent-style. Fine-tuning and inference up to 10x faster than offloading
https://petals.dev

Error with PyTorch 2.3.0: Missing '_refresh_per_optimizer_state' in 'torch.cuda.amp.grad_scaler' #576

Closed: Priyanshupareek closed this issue 2 months ago

Priyanshupareek commented 2 months ago

Problem Description

After the recent update to PyTorch 2.3.0, Petals fails with an import error when importing from the torch.cuda.amp.grad_scaler module. The specific error message is: ImportError: cannot import name '_refresh_per_optimizer_state' from 'torch.cuda.amp.grad_scaler' (site-packages/torch/cuda/amp/grad_scaler.py)

The error comes from a change in PyTorch 2.3.0: the GradScaler internals were refactored into torch.amp.grad_scaler, so the private helper _refresh_per_optimizer_state is no longer importable from the torch.cuda.amp.grad_scaler path that the current codebase relies on.
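
A minimal sketch of a version-agnostic import, assuming the helper now lives in torch.amp.grad_scaler on PyTorch 2.3.0+ (this is my reading of the refactor, not a confirmed fix):

```python
# Hypothetical compatibility shim: try the old import path first (PyTorch <= 2.2.x)
# and fall back to the new location of the GradScaler internals (PyTorch >= 2.3.0).
try:
    from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
except ImportError:
    from torch.amp.grad_scaler import _refresh_per_optimizer_state
```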

Quick Workaround

This issue can be worked around by pinning the PyTorch version in setup.cfg: change 'torch>=1.12' to 'torch==2.2.2', the last version known to work without this problem. That keeps installs stable while I investigate the changes in the new PyTorch release and make the codebase compatible with PyTorch 2.3.0 or later.
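
For reference, the pinned requirement in setup.cfg would look roughly like this (only the torch entry is shown; the surrounding dependencies stay as they are):

```ini
# setup.cfg (excerpt): pin torch to the last known-good release instead of a floor
[options]
install_requires =
    torch==2.2.2
```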

Steps to Reproduce

  1. Install Petals with pip as suggested in the README.
  2. Start the server (see the command sketch below).
  3. Observe the import error.
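
The reproduction roughly looks like this; the model name is just the example from the project README, and any supported model should hit the same import path:

```sh
# Install Petals from PyPI; this pulls in the latest PyTorch (currently 2.3.0)
pip install petals

# Start a server; the ImportError is raised during startup
python -m petals.cli.run_server petals-team/StableBeluga2
```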


Action

I will submit a pull request modifying the install_requires entry in setup.cfg as described above, pending team feedback on this issue.