allenai / OLMo

Modeling, training, eval, and inference code for OLMo
https://allenai.org/olmo
Apache License 2.0

No module named 'torch.distributed.device_mesh' #559

Closed: prakamya-mishra closed this issue 6 months ago

prakamya-mishra commented 6 months ago

🐛 Describe the bug

@2015aroras I cloned the repo at the most recent release tag using: git clone https://github.com/allenai/OLMo.git --branch=v0.3.0

When running the training script on a single MI250 node, I get the following error:

 (olmo) root@node:/dockerx/Projects/Experiments/OLMO# source train.sh --node-count="1" --config="OLMo-1B"
 Traceback (most recent call last):
   File "/dockerx/Repositories/OLMo/scripts/train.py", line 14, in <module>
     from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 ModuleNotFoundError: No module named 'torch.distributed.device_mesh'
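A quick way to confirm whether an environment provides the module named in the traceback, without launching the full training script, is to ask Python's import system for the module's spec. The helper below is a minimal sketch; `module_available` is a hypothetical name and is not part of OLMo or PyTorch.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if `name` can be located by the import system.

    importlib.util.find_spec imports parent packages first, so checking
    'torch.distributed.device_mesh' imports torch and torch.distributed,
    but only locates device_mesh without executing it.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        # Raised when a parent package (for example 'torch') is missing.
        return False


if __name__ == "__main__":
    # Prints False on builds that lack the public device_mesh module.
    print(module_available("torch.distributed.device_mesh"))
```

Running this inside the `olmo` environment prints `False` when `torch.distributed.device_mesh` cannot be found, which matches the `ModuleNotFoundError` above.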

train.sh:

 OLMO_REPO='/dockerx/Repositories/OLMo'
 export HF_DATASETS_CACHE="/shared/hf_cache/"

 for arg in "$@"; do
     case "$arg" in
         --node-count=*) NODE_COUNT="${arg#*=}" ;;
         --config=*) CONFIG="${arg#*=}" ;;
     esac
 done

 HSA_FORCE_FINE_GRAIN_PCIE=1 OMP_NUM_THREADS=128 NCCL_DEBUG=INFO NCCL_SOCKET_IFNAME=eth NCCL_P2P_DISABLE=1 torchrun --nproc_per_node=8 $OLMO_REPO/scripts/train.py $OLMO_REPO/configs/$CONFIG.yaml

Environment setup:

  1. Pull the Docker image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2
  2. Start the container and create the olmo environment:
    1. conda create --name olmo --clone py_3.10
    2. cd <OLMo v0.3.0 git clone repo>
    3. pip install -e .[all]

Versions

Python 3.10.13 absl-py==2.1.0 -e git+https://github.com/allenai/OLMo.git@a52d0584cacbf7facf43eb2492a7c3247136ff72#egg=ai2_olmo aiohttp==3.9.3 aiosignal==1.3.1 annotated-types==0.6.0 antlr4-python3-runtime==4.9.3 apex @ file:///var/lib/jenkins/apex appdirs==1.4.4 asgiref==3.7.2 astunparse==1.6.3 async-timeout==4.0.3 attrs==23.2.0 audioread==3.0.1 backports.tarfile==1.1.1 beaker-gantry==0.22.3 beaker-py==1.26.6 black==23.12.1 boltons==24.0.0 boto3==1.34.92 botocore==1.34.92 build==1.2.1 cached_path==1.6.2 cachetools==5.3.2 certifi==2024.2.2 cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 click-help-colors==0.9.4 colorama==0.4.6 coremltools==5.0b5 cryptography==42.0.5 Cython==3.0.8 datasets==2.19.0 decorator==5.1.1 dill==0.3.8 Django==5.0.1 docker==6.1.3 docker-pycreds==0.4.0 docutils==0.21.2 exceptiongroup==1.2.0 execnet==2.0.2 expecttest==0.1.6 face==20.1.1 filelock==3.9.0 flatbuffers==2.0 frozenlist==1.4.1 fsspec==2024.3.1 ftfy==6.2.0 future==0.18.3 geojson==2.5.0 ghstack==0.7.1 gitdb==4.0.11 GitPython==3.1.43 glom==23.5.0 google-api-core==2.18.0 google-auth==2.27.0 google-auth-oauthlib==1.0.0 google-cloud-core==2.4.1 google-cloud-storage==2.16.0 google-crc32c==1.5.0 google-resumable-media==2.7.0 googleapis-common-protos==1.63.0 grpcio==1.60.1 huggingface-hub==0.21.4 hypothesis==5.35.1 idna==3.6 image==1.5.33 imageio==2.33.1 importlib_metadata==7.1.0 iniconfig==2.0.0 isort==5.12.0 jaraco.classes==3.4.0 jaraco.context==5.3.0 jaraco.functools==4.0.1 jeepney==0.8.0 Jinja2==3.1.2 jmespath==0.10.0 joblib==1.3.2 junitparser==2.1.1 keyring==25.1.0 lazy_loader==0.3 librosa==0.10.1 lightning-utilities==0.11.2 lintrunner==0.10.7 llvmlite==0.38.1 lxml==5.1.0 Markdown==3.5.2 markdown-it-py==3.0.0 MarkupSafe==2.1.5 mdurl==0.1.2 mkl-fft==1.3.1 mkl-random @ file:///home/builder/ci_310/mkl_random_1641843545607/work mkl-service==2.4.0 more-itertools==10.2.0 mpmath==1.3.0 msgpack==1.0.7 msgspec==0.18.6 multidict==6.0.5 multiprocess==0.70.16 mypy==1.3.0 mypy-extensions==1.0.0 
necessary==0.4.3 networkx==2.8.8 nh3==0.2.17 numba==0.55.2 numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1653325310407/work oauthlib==3.2.2 omegaconf==2.3.0 opt-einsum==3.3.0 optionloop==1.0.7 packaging==23.2 pandas==2.2.2 pathspec==0.12.1 petname==2.6 Pillow==9.5.0 pkginfo==1.10.0 platformdirs==4.2.0 pluggy==1.4.0 pooch==1.8.0 proto-plus==1.23.0 protobuf==3.20.2 psutil==5.9.8 pyarrow==16.0.0 pyarrow-hotfix==0.6 pyasn1==0.5.1 pyasn1-modules==0.3.0 pycparser==2.21 pydantic==2.7.1 pydantic_core==2.18.2 Pygments==2.15.0 pyproject_hooks==1.0.0 pytest==8.1.1 pytest-cpp==2.3.0 pytest-flakefinder==1.1.0 pytest-rerunfailures==13.0 pytest-shard==0.1.2 pytest-sphinx==0.6.3 pytest-xdist==3.3.1 python-dateutil==2.8.2 pytz==2024.1 PyWavelets==1.5.0 PyYAML @ file:///croot/pyyaml_1698096049011/work readme_renderer==43.0 regex==2024.4.16 requests==2.31.0 requests-oauthlib==1.3.1 requests-toolbelt==1.0.0 requirements-parser==0.9.0 rfc3986==2.0.0 rich==13.7.1 rockset==1.0.3 rsa==4.9 ruff==0.4.2 s3transfer==0.10.1 safetensors==0.4.3 scikit-image==0.20.0 scikit-learn==1.4.0 scipy==1.8.1 SecretStorage==3.3.3 sentry-sdk==2.0.0 setproctitle==1.3.3 six @ file:///tmp/build/80754af9/six_1644875935023/work smart-open==7.0.4 smashed==0.21.5 smmap==5.0.1 sortedcontainers==2.4.0 soundfile==0.12.1 soxr==0.3.7 sqlparse==0.4.4 sympy==1.12 tb-nightly==2.13.0a20230426 tensorboard==2.13.0 tensorboard-data-server==0.7.2 threadpoolctl==3.2.0 tifffile==2024.1.30 tokenizers==0.19.1 tomli==2.0.1 torch @ file:///var/lib/jenkins/pytorch/dist/torch-2.1.2%2Bgit98a6632-cp310-cp310-linux_x86_64.whl#sha256=ef76dfb10d50367e9dd72b0ec9725a5b2922636adaeca5e26b3c3be419085586 torchmetrics==1.3.2 torchvision==0.16.1+fdea156 tqdm==4.66.1 transformers==4.40.1 triton @ git+https://github.com/ROCmSoftwarePlatform/triton@066e7deb0b8394c73ca67a4d31cde708018b1f13#subdirectory=python trouting==0.3.3 twine==5.0.0 types-setuptools==69.5.0.20240423 typing_extensions==4.9.0 tzdata==2024.1 
unittest-xml-reporting==3.2.0 urllib3==1.26.18 wandb==0.16.6 wcwidth==0.2.13 websocket-client==1.8.0 Werkzeug==3.0.1 wrapt==1.16.0 xdoctest==1.1.0 xxhash==3.4.1 yarl==1.9.4 z3-solver==4.12.2.0 zipp==3.17.0

dumitrac commented 6 months ago

@prakamya-mishra I see that device_mesh requires PyTorch 2.2 (link). Are you able to upgrade from 2.1.2 and confirm whether that resolves it?
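To check an installed build against the 2.2 requirement mentioned above, the version string can be compared on its major and minor components. This is a sketch that assumes, as stated in this comment, that 2.2 is the minimum version. ROCm builds report local suffixes such as `2.1.2+git98a6632`, so the helper only looks at the leading `major.minor`. The names `torch_major_minor` and `supports_public_device_mesh` are hypothetical.

```python
import re

# Assumed minimum (major, minor) version, taken from the comment above.
MIN_DEVICE_MESH_VERSION = (2, 2)


def torch_major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from a version string like '2.1.2+git98a6632'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def supports_public_device_mesh(version: str) -> bool:
    # Tuple comparison handles cases like (2, 10) >= (2, 2) correctly.
    return torch_major_minor(version) >= MIN_DEVICE_MESH_VERSION


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed")
    else:
        print(torch.__version__, supports_public_device_mesh(torch.__version__))
```

For the build listed in this report, `supports_public_device_mesh("2.1.2+git98a6632")` returns `False`.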

2015aroras commented 6 months ago

@dumitrac I think we need to pull https://github.com/allenai/OLMo/commit/885a1f0d41a51a5528617be0195804bf40a0a114 into main to fix this.
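One common way to keep a script importable on older PyTorch builds is to guard the import and fail only when the feature is actually used. The sketch below illustrates that pattern; it is a hypothetical shim and not necessarily what commit 885a1f0 does. `HAS_DEVICE_MESH` and `require_device_mesh` are illustrative names.

```python
# Hypothetical compatibility shim; not the actual contents of commit 885a1f0.
try:
    from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
except ImportError:
    # Older builds (e.g. PyTorch 2.1.x) lack this public module; a missing
    # torch install also lands here, since ModuleNotFoundError is an ImportError.
    DeviceMesh = None
    init_device_mesh = None

HAS_DEVICE_MESH = init_device_mesh is not None


def require_device_mesh() -> None:
    """Raise a clear error only on code paths that really need DeviceMesh."""
    if not HAS_DEVICE_MESH:
        raise RuntimeError(
            "This code path needs torch.distributed.device_mesh (PyTorch 2.2 or newer)."
        )
```

With this pattern, training configurations that never touch `DeviceMesh` can still run on the ROCm PyTorch 2.1.2 image, while configurations that do need it fail with an explicit message instead of an import error at startup.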

prakamya-mishra commented 6 months ago

@dumitrac @2015aroras I am using the rocm/pytorch Docker image from Docker Hub, and I could not find any PyTorch 2.2 image tag. Is there a PyTorch 2.2 release for ROCm?

dumitrac commented 6 months ago

I couldn't find a PyTorch 2.2 Docker image for ROCm either. In that case, @2015aroras's PR (#561) looks like the quickest way to unblock you. Thank you both.

prakamya-mishra commented 6 months ago

Thank you! I will reach out to the ROCm team to get an estimate of when they plan to release a PyTorch 2.2 image.

prakamya-mishra commented 6 months ago

Sorry, I will close the issue once the PR has been merged and I have verified that it works.