pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Failed to import torch_xla by following the GPU instructions on an H100 node (A3-High) #7720

Open hosseinsarshar opened 1 month ago

hosseinsarshar commented 1 month ago

🐛 Bug

I followed the instructions in the README.md and the GPU instructions in this file, and failed to import the torch_xla module.

I eventually fixed the issue with a hacky workaround; the steps are shared below.

To Reproduce

Environment

Machine: A3-High
Image project: ml-images
Image family: tf-2-15-gpu-debian-11
Python runtime: Python 3.10 (a new conda environment is created for this)
CUDA: 12.1
NCCL: 2.20.5

Followed these steps:

conda create -n pyxla python=3.10 -y
conda activate pyxla

echo "export PATH=\$PATH:/usr/local/cuda-12.1/bin" >> ~/.bashrc
echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> ~/.bashrc
source ~/.bashrc
conda activate pyxla
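As a quick sanity check of the steps above, one can verify that the CUDA 12.1 directories actually made it into the environment before going further. This is a hedged sketch of my own (not part of torch_xla); the paths are the ones from this report, so adjust them if your CUDA install lives elsewhere.

```python
import os

# Paths taken from the setup steps in this report (assumption: CUDA 12.1
# installed under /usr/local/cuda-12.1).
CUDA_BIN = "/usr/local/cuda-12.1/bin"
CUDA_LIB = "/usr/local/cuda-12.1/lib64"

def path_contains(var, entry, env=None):
    """True if `entry` is one of the os.pathsep-separated items in env[var]."""
    env = os.environ if env is None else env
    return entry in env.get(var, "").split(os.pathsep)

if __name__ == "__main__":
    print("PATH ok:", path_contains("PATH", CUDA_BIN))
    print("LD_LIBRARY_PATH ok:", path_contains("LD_LIBRARY_PATH", CUDA_LIB))
```

If either check prints False, the exports from ~/.bashrc were not picked up in the current shell.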

Then I installed the required libraries mentioned here and here.

pip3 install torch==2.3.0
# GPU whl for python 3.10 + cuda 12.1
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl

pip install torch~=2.3.0 torch_xla~=2.3.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.3.0-py3-none-any.whl
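Since the commands above install torch and torch_xla from two different sources, a small helper of my own (not a torch_xla API) can confirm the two packages ended up on the same release line; mismatched wheels are a common source of import errors.

```python
from importlib.metadata import PackageNotFoundError, version

def release_line(pkg):
    """Return (major, minor) for an installed package, or None if missing."""
    try:
        return tuple(int(part) for part in version(pkg).split(".")[:2])
    except PackageNotFoundError:
        return None

if __name__ == "__main__":
    torch_line = release_line("torch")
    xla_line = release_line("torch_xla")
    print("torch:", torch_line, "torch_xla:", xla_line)
    if torch_line and xla_line and torch_line != xla_line:
        print("WARNING: torch and torch_xla release lines differ")
```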

Adding the env variables

export XLA_REGISTER_INSTALLED_PLUGINS=1
export GPU_NUM_DEVICES=1 
export PJRT_DEVICE=CUDA
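The exports above can be checked with a small preflight helper of my own (not a torch_xla API) before attempting the import; the expected values are simply the ones used in this report.

```python
import os

# Values taken from the exports in this report (assumptions, not torch_xla
# requirements in general).
REQUIRED = {
    "XLA_REGISTER_INSTALLED_PLUGINS": "1",
    "GPU_NUM_DEVICES": "1",
    "PJRT_DEVICE": "CUDA",
}

def check_pjrt_env(env=None):
    """Return a list of problems; an empty list means the env looks right."""
    env = os.environ if env is None else env
    problems = []
    for name, expected in REQUIRED.items():
        value = env.get(name)
        if value is None:
            problems.append(f"{name} is not set (expected {expected!r})")
        elif value != expected:
            problems.append(f"{name}={value!r}, expected {expected!r}")
    return problems

if __name__ == "__main__":
    for problem in check_pjrt_env():
        print("WARNING:", problem)
```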

Running a simple import:

python -c "import torch_xla as xla; import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices())"

Fails with:

File "/opt/conda/envs/pyxla/lib/python3.10/site-packages/torch_xla/__init__.py", line 7, in <module>
  import _XLAC
ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory

Copying the libpython3.10.so.1.0 file to /usr/lib fixed the import error:

sudo cp /opt/conda/envs/pyxla/lib/libpython3.10.so.1.0 /usr/lib

But then I ran into a dependency error for numpy, which was not mentioned in any of the requirements. After installing the numpy package, I was able to successfully fetch the CUDA devices:

python -c "import torch_xla as xla; import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices())"

# ['xla:0', 'xla:1', 'xla:2', 'xla:3', 'xla:4', 'xla:5', 'xla:6', 'xla:7']

In my opinion, this can be simplified and some of the setup steps can be added to the torch-xla setup.

zpcore commented 1 month ago

If you are in a conda environment, you may need to add the following to the command line:

LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/lib:/usr/local/lib:${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib"
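To sketch what that suggestion resolves: the loader fails because libpython3.10.so.1.0 lives under the conda env's lib/ directory, which is not on the default search path. This helper of my own (not part of torch_xla or conda) prints an export line for the active environment, as an alternative to copying the library into /usr/lib.

```python
import os
import sys
from pathlib import Path

def conda_lib_dir(env=None):
    """Return the active env's lib/ directory as a Path, or None if absent."""
    env = os.environ if env is None else env
    # CONDA_PREFIX is set by `conda activate`; fall back to the interpreter's
    # own prefix when it is missing.
    prefix = env.get("CONDA_PREFIX", sys.prefix)
    lib = Path(prefix) / "lib"
    return lib if lib.is_dir() else None

if __name__ == "__main__":
    lib = conda_lib_dir()
    if lib is not None:
        # A line suitable for ~/.bashrc or eval in the current shell.
        print(f'export LD_LIBRARY_PATH="${{LD_LIBRARY_PATH}}:{lib}"')
```

Keeping the fix in LD_LIBRARY_PATH rather than /usr/lib avoids leaking one env's libpython into every process on the machine.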