pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

BUG: ImportError: cannot import name 'CPUOffloadPolicy' #1985

Closed: tginart closed this issue 1 week ago.

tginart commented 1 week ago
(xenv-0017-tts) aginart@ip-10-1-89-181:~/dev/fun_projects/generation_projects/torchtune$ tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
Traceback (most recent call last):
  File "/fsx/home/aginart/.conda/envs/xenv-0017-tts/bin/tune", line 8, in <module>
    sys.exit(main())
  File "/fsx/home/aginart/.conda/envs/xenv-0017-tts/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/fsx/home/aginart/.conda/envs/xenv-0017-tts/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/fsx/home/aginart/.conda/envs/xenv-0017-tts/lib/python3.10/site-packages/torchtune/_cli/run.py", line 196, in _run_cmd
    self._run_single_device(args, is_builtin=is_builtin)
  File "/fsx/home/aginart/.conda/envs/xenv-0017-tts/lib/python3.10/site-packages/torchtune/_cli/run.py", line 102, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "/fsx/home/aginart/.conda/envs/xenv-0017-tts/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/fsx/home/aginart/.conda/envs/xenv-0017-tts/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/fsx/home/aginart/.conda/envs/xenv-0017-tts/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/fsx/home/aginart/.conda/envs/xenv-0017-tts/lib/python3.10/site-packages/recipes/lora_finetune_single_device.py", line 22, in <module>
    from torchtune import config, modules, training, utils
  File "/fsx/home/aginart/.conda/envs/xenv-0017-tts/lib/python3.10/site-packages/torchtune/training/__init__.py", line 8, in <module>
    from torchtune.training._distributed import (
  File "/fsx/home/aginart/.conda/envs/xenv-0017-tts/lib/python3.10/site-packages/torchtune/training/_distributed.py", line 17, in <module>
    from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
ImportError: cannot import name 'CPUOffloadPolicy' from 'torch.distributed._composable.fsdp' (/fsx/home/aginart/.conda/envs/xenv-0017-tts/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/__init__.py)
felipemello1 commented 1 week ago

hey @tginart, can you run "pip list" or "conda list" and share the torch version you are using? It's possible that you are using an older (or newer?) version. Installing the most recent torch, e.g. the nightlies or 2.5.1, should solve it.
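To gather the information asked for here in one go, a minimal sketch using only the Python standard library (importlib and importlib.metadata) could look like the following. It prints the versions pip has recorded, the file "import torch" would actually load, and whether the exact import from the traceback succeeds. Comparing the recorded version with the load path may reveal that the interpreter is picking up a different install than the one pip reports. The helper names installed_version, module_location and can_import are hypothetical and are not part of torch or torchtune.

```python
# Hypothetical diagnostic helpers (not part of torch or torchtune).
import importlib
import importlib.util
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Version pip recorded for a distribution, or None if it isn't installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def module_location(module_name: str) -> Optional[str]:
    """File a top-level module would be loaded from, or None if it can't be found."""
    spec = importlib.util.find_spec(module_name)
    return spec.origin if spec is not None else None


def can_import(module_name: str, attr: str) -> bool:
    """True if module_name imports and exposes attr (e.g. a class or function)."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    for dist in ("torch", "torchtune", "torchao"):
        print(f"{dist}: {installed_version(dist)}")
    # Where "import torch" resolves; a stale copy elsewhere on sys.path would show up here
    print("torch loaded from:", module_location("torch"))
    # The exact import that fails in the traceback above
    print(
        "CPUOffloadPolicy importable:",
        can_import("torch.distributed._composable.fsdp", "CPUOffloadPolicy"),
    )
```

Run it with the same interpreter that "tune" uses (here, the python from the conda env shown in the traceback), since a different interpreter may see a different set of packages.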

tginart commented 1 week ago

Pip list for the broken env:

Package                  Version
------------------------ ------------
absl-py                  2.1.0
accelerate               1.1.0
aiohappyeyeballs         2.4.3
aiohttp                  3.10.10
aiosignal                1.3.1
annotated-types          0.7.0
antlr4-python3-runtime   4.9.3
anyio                    4.6.2.post1
asttokens                2.4.1
attrs                    24.2.0
blobfile                 3.0.0
certifi                  2024.8.30
chardet                  5.2.0
charset-normalizer       3.4.0
click                    8.1.7
colorama                 0.4.6
DataProperty             1.0.1
datasets                 3.1.0
datatrove                0.3.0
decorator                5.1.1
dill                     0.3.8
distro                   1.9.0
docker-pycreds           0.4.0
eval_type_backport       0.2.0
evaluate                 0.4.3
executing                2.1.0
filelock                 3.13.1
frozenlist               1.5.0
fsspec                   2024.2.0
gitdb                    4.0.11
GitPython                3.1.43
h11                      0.14.0
httpcore                 1.0.6
httpx                    0.27.2
huggingface-hub          0.26.2
humanize                 4.11.0
idna                     3.10
ipython                  8.29.0
jedi                     0.19.1
Jinja2                   3.1.3
jiter                    0.7.0
joblib                   1.4.2
jsonlines                4.0.0
lm_eval                  0.4.5
loguru                   0.7.2
lxml                     5.3.0
markdown-it-py           3.0.0
MarkupSafe               2.1.5
matplotlib-inline        0.1.7
mbstrdecoder             1.1.3
mdurl                    0.1.2
more-itertools           10.5.0
mpmath                   1.3.0
msgspec                  0.18.6
multidict                6.1.0
multiprocess             0.70.16
networkx                 3.2.1
ninja                    1.11.1.1
nltk                     3.9.1
numexpr                  2.10.1
numpy                    1.26.3
nvidia-cublas-cu12       12.4.5.8
nvidia-cuda-cupti-cu12   12.4.127
nvidia-cuda-nvrtc-cu12   12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12        9.1.0.70
nvidia-cufft-cu12        11.2.1.3
nvidia-curand-cu12       10.3.5.147
nvidia-cusolver-cu12     11.6.1.9
nvidia-cusparse-cu12     12.3.1.170
nvidia-nccl-cu12         2.21.5
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.4.127
objprint                 0.2.3
omegaconf                2.3.0
openai                   1.54.3
orjson                   3.10.11
packaging                24.1
pandas                   2.2.3
parso                    0.8.4
pathvalidate             3.2.1
peft                     0.13.2
pexpect                  4.9.0
pillow                   10.4.0
pip                      24.2
platformdirs             4.3.6
plotext                  5.3.2
portalocker              2.10.1
prompt_toolkit           3.0.48
propcache                0.2.0
protobuf                 5.28.3
psutil                   6.1.0
ptyprocess               0.7.0
pure_eval                0.2.3
pyarrow                  18.0.0
pybind11                 2.13.6
pycryptodomex            3.21.0
pydantic                 2.9.2
pydantic_core            2.23.4
Pygments                 2.18.0
pynvml                   11.5.3
pytablewriter            1.2.0
python-dateutil          2.9.0.post0
pytz                     2024.2
PyYAML                   6.0.2
regex                    2024.9.11
requests                 2.32.3
rich                     13.9.4
rouge_score              0.1.2
sacrebleu                2.4.3
safetensors              0.4.5
scikit-learn             1.5.2
scipy                    1.14.1
sentencepiece            0.2.0
sentry-sdk               2.18.0
setproctitle             1.3.3
setuptools               75.1.0
shellingham              1.5.4
six                      1.16.0
smmap                    5.0.1
sniffio                  1.3.1
sqlitedict               2.1.0
stack-data               0.6.3
sympy                    1.13.1
tabledata                1.3.3
tabulate                 0.9.0
tcolorpy                 0.1.6
threadpoolctl            3.5.0
tiktoken                 0.8.0
together                 1.3.3
tokenizers               0.20.3
torch                    2.5.1
torchao                  0.6.1+cu124
torchtune                0.3.1
torchvision              0.20.1
tqdm                     4.66.6
tqdm-multiprocess        0.0.11
traitlets                5.14.3
transformers             4.46.2
triton                   3.1.0
typepy                   1.3.2
typer                    0.12.5
typing_extensions        4.12.2
tzdata                   2024.2
urllib3                  2.2.3
viztracer                0.17.0
wandb                    0.18.5
wcwidth                  0.2.13
wheel                    0.44.0
word2number              1.1
xformers                 0.0.28.post2
xxhash                   3.5.0
yarl                     1.17.1
zstandard                0.23.0

Pip list for the env that ended up working:

Package                Version
---------------------- -----------
aiohappyeyeballs       2.4.3
aiohttp                3.10.10
aiosignal              1.3.1
antlr4-python3-runtime 4.9.3
asttokens              2.4.1
attrs                  24.2.0
bitsandbytes           0.44.1
blobfile               3.0.0
Brotli                 1.0.9
certifi                2024.8.30
charset-normalizer     3.3.2
click                  8.1.7
datasets               3.1.0
decorator              5.1.1
dill                   0.3.8
docker-pycreds         0.4.0
executing              2.1.0
filelock               3.13.1
frozenlist             1.5.0
fsspec                 2024.9.0
gitdb                  4.0.11
GitPython              3.1.43
gmpy2                  2.1.2
huggingface-hub        0.26.2
idna                   3.7
ipython                8.29.0
jedi                   0.19.2
Jinja2                 3.1.4
lxml                   5.3.0
MarkupSafe             2.1.3
matplotlib-inline      0.1.7
mkl_fft                1.3.11
mkl_random             1.2.8
mkl-service            2.4.0
mpmath                 1.3.0
multidict              6.1.0
multiprocess           0.70.16
networkx               3.2.1
numpy                  2.0.1
omegaconf              2.3.0
packaging              24.2
pandas                 2.2.3
parso                  0.8.4
pexpect                4.9.0
pillow                 10.4.0
pip                    24.2
platformdirs           4.3.6
prompt_toolkit         3.0.48
propcache              0.2.0
protobuf               5.28.3
psutil                 6.1.0
ptyprocess             0.7.0
pure_eval              0.2.3
pyarrow                18.0.0
pycryptodomex          3.21.0
Pygments               2.18.0
PySocks                1.7.1
python-dateutil        2.9.0.post0
pytz                   2024.2
PyYAML                 6.0.2
regex                  2024.11.6
requests               2.32.3
safetensors            0.4.5
sentencepiece          0.2.0
sentry-sdk             2.18.0
setproctitle           1.3.3
setuptools             75.1.0
six                    1.16.0
smmap                  5.0.1
stack-data             0.6.3
sympy                  1.13.1
tiktoken               0.8.0
tokenizers             0.20.3
torch                  2.5.1
torchao                0.6.1
torchaudio             2.5.1
torchtune              0.3.1
torchvision            0.20.1
tqdm                   4.67.0
traitlets              5.14.3
transformers           4.46.2
triton                 3.1.0
typing_extensions      4.11.0
tzdata                 2024.2
urllib3                2.2.3
wandb                  0.18.6
wcwidth                0.2.13
wheel                  0.44.0
xxhash                 3.5.0
yarl                   1.17.1
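Both lists above report torch 2.5.1 and torchtune 0.3.1, and the thread does not say which difference between the two environments made the import work. One way to narrow it down is to diff the two outputs. The following is a minimal sketch, assuming each list was saved as the default column output of "pip list"; parse_pip_list and diff_envs are hypothetical helpers, not a pip feature. Names are normalized PEP 503-style so that spellings such as "typing_extensions" and "typing-extensions" compare equal.

```python
# Hypothetical helpers for comparing two "pip list" dumps (not a pip feature).
import re
from typing import Dict, List, Tuple


def _normalize(name: str) -> str:
    # PEP 503-style normalization: runs of "-", "_", "." become "-", then lowercase
    return re.sub(r"[-_.]+", "-", name).lower()


def parse_pip_list(text: str) -> Dict[str, str]:
    """Parse the default column output of "pip list" into {normalized name: version}."""
    packages: Dict[str, str] = {}
    for line in text.splitlines():
        parts = line.split()
        # Skip blank lines, the "Package Version" header and the dashed separator row
        if len(parts) < 2 or parts[0] == "Package" or set(parts[0]) == {"-"}:
            continue
        packages[_normalize(parts[0])] = parts[1]
    return packages


def diff_envs(
    a: Dict[str, str], b: Dict[str, str]
) -> Tuple[List[str], List[str], Dict[str, Tuple[str, str]]]:
    """Return (only in a, only in b, {name: (version in a, version in b)} where they differ)."""
    only_a = sorted(set(a) - set(b))
    only_b = sorted(set(b) - set(a))
    changed = {n: (a[n], b[n]) for n in sorted(set(a) & set(b)) if a[n] != b[n]}
    return only_a, only_b, changed


if __name__ == "__main__":
    # Small excerpts of the two lists above; paste the full outputs in practice
    broken = """Package Version
------- -------
numpy 1.26.3
torch 2.5.1
torchao 0.6.1+cu124
xformers 0.0.28.post2
"""
    working = """Package Version
------- -------
numpy 2.0.1
torch 2.5.1
torchao 0.6.1
torchaudio 2.5.1
"""
    only_broken, only_working, changed = diff_envs(
        parse_pip_list(broken), parse_pip_list(working)
    )
    print("only in broken env:", only_broken)    # ['xformers']
    print("only in working env:", only_working)  # ['torchaudio']
    print("different versions:", changed)
    # {'numpy': ('1.26.3', '2.0.1'), 'torchao': ('0.6.1+cu124', '0.6.1')}
```

A diff like this only lists candidates; it does not by itself show which package (if any) caused the ImportError.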
felipemello1 commented 1 week ago

It seems to be solved. Please feel free to reopen if it isn't.