pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

ImportError in LLaMA Training Script #412

Open viai957 opened 1 week ago

viai957 commented 1 week ago

When attempting to run the training script for LLaMA with the following command:

    CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh

an ImportError is encountered. The specific error message is:

    ImportError: cannot import name 'Partial' from 'torch.distributed._tensor' (/apps/torchtitan/torchtitan/lib/python3.10/site-packages/torch/distributed/_tensor/__init__.py)

The training script should start without any import errors and utilize the specified configuration file to train the model across 8 GPUs.

The script fails to run due to an ImportError indicating that Partial cannot be imported from torch.distributed._tensor. The error traceback is as follows:

    Traceback (most recent call last):
      File "/apps/torchtitan/train.py", line 34, in <module>
        from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
      File "/apps/torchtitan/torchtitan/models/__init__.py", line 7, in <module>
        from torchtitan.models.llama import llama2_configs, llama3_configs, Transformer
      File "/apps/torchtitan/torchtitan/models/llama/__init__.py", line 10, in <module>
        from torchtitan.models.llama.model import ModelArgs, Transformer
      File "/apps/torchtitan/torchtitan/models/llama/model.py", line 17, in <module>
        from torchtitan.models.norms import create_norm
      File "/apps/torchtitan/torchtitan/models/norms.py", line 17, in <module>
        from torch.distributed._tensor import Partial, Replicate, Shard
    ImportError: cannot import name 'Partial' from 'torch.distributed._tensor' (/apps/torchtitan/torchtitan/lib/python3.10/site-packages/torch/distributed/_tensor/__init__.py)

kwen2501 commented 1 week ago

Partial used to be named _Partial; it was only recently made public. Upgrading your PyTorch version should get you past this import error. Sorry about the breakage.
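
For anyone who cannot upgrade right away, one possible stopgap is to look the symbol up under both its new and its old name. The snippet below is only a sketch: `import_first_available` is a hypothetical helper written for this thread, not part of torchtitan or PyTorch, and it only papers over the rename described above. Other torchtitan code may still depend on newer PyTorch changes.

```python
import importlib


def import_first_available(module_name, candidate_names):
    """Return the first attribute in candidate_names that module_name exposes.

    Hypothetical helper for illustration; not part of torchtitan or PyTorch.
    """
    module = importlib.import_module(module_name)
    for name in candidate_names:
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(
        f"none of {candidate_names!r} found in {module_name!r}"
    )


if __name__ == "__main__":
    try:
        # Prefer the public name, fall back to the older private one.
        Partial = import_first_available(
            "torch.distributed._tensor", ["Partial", "_Partial"]
        )
        print("resolved:", Partial)
    except ImportError as exc:
        # Raised if torch is missing or exposes neither name.
        print("could not resolve Partial:", exc)
```

The helper takes the module name as a string so the lookup order stays explicit and the same function can be reused for other renamed symbols.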

viai957 commented 1 week ago

I tried upgrading PyTorch to torch >= 2.3.1, but the problem persists.

awgu commented 1 week ago

@viai957 Unfortunately, I think you need to use a nightly release, not simply torch >= 2.3.1. The challenge is that much of the code in torchtitan relies on code in the pytorch repo that is still changing, so torchtitan generally requires a nightly version.

See the Preview (Nightly) option in https://pytorch.org/get-started/locally/.
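
To confirm which kind of build is actually installed, a quick check of the version string can help. This is a rough sketch that assumes nightly builds carry a `.dev` segment in their version (for example something like `2.5.0.dev20240617+cu121`, an illustrative value), while stable releases do not. `is_nightly_build` is a hypothetical helper name, not a PyTorch API.

```python
def is_nightly_build(version: str) -> bool:
    """Heuristic: treat any version string containing '.dev' as a nightly."""
    return ".dev" in version


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed in this environment")
    else:
        # torch.__version__ behaves like a plain string for this check.
        print(torch.__version__, "nightly:", is_nightly_build(torch.__version__))
```

If this reports a stable version after installing, the environment that runs train.py is probably not the one where the nightly was installed.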