pytorch/xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

PyTorch/XLA compatible with AMD GPUs? #6255

Open clintg6 opened 11 months ago

clintg6 commented 11 months ago

Questions

Does the current version of PyTorch/XLA offer support for ROCm with AMD GPUs?

JackCaoG commented 11 months ago

AFAIK we didn't test AMD GPUs. They are supported by XLA:GPU, but we might need a separate build or something. @vanbasten23 should have more info.

clintg6 commented 11 months ago

Seems like it could, since ROCm is listed as a device type in device.h. However, when I look at the XLA:GPU guide I only see a mention of CUDA. Similarly, I don't see a ROCm wheel or Docker image in the official list.

vanbasten23 commented 11 months ago

Jack is right. AFAICT, we have never tested AMD GPUs at this point, and we would need a special build for ROCm (currently we only support the XLA_CUDA=1 flag).

clintg6 commented 11 months ago

@vanbasten23 What would we need to do to build it? Is there anything I can do to facilitate it?

vanbasten23 commented 11 months ago

For the CUDA build, we build torch_xla with XLA_CUDA=1 python setup.py install, which adds --config=cuda (https://github.com/pytorch/xla/blob/64d5807b9b6399522081a725b438b8276b7d7aa2/setup.py#L285) when it builds OpenXLA.

So similarly for ROCm, I would imagine we need a new flag in torch_xla such as XLA_ROCM. When set, it should add --config=rocm and pass it to OpenXLA.
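
Roughly, and only as an illustrative sketch (the helper and function names below are hypothetical, not the actual setup.py code), the flag handling could look like this:

import os

def _check_env_flag(name, default=''):
  # Treat common truthy strings ('1', 'true', 'on', ...) as "flag set".
  return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']

def _gpu_bazel_flags():
  # Hypothetical helper: collect the Bazel --config flags forwarded to the OpenXLA build.
  flags = []
  if _check_env_flag('XLA_CUDA'):
    flags.append('--config=cuda')
  if _check_env_flag('XLA_ROCM'):  # proposed new flag, mirroring XLA_CUDA
    flags.append('--config=rocm')
  return flags

The build would then be invoked as XLA_ROCM=1 python setup.py install, analogous to the CUDA path.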

clintg6 commented 11 months ago

I'll test this and report back, thanks!

miladm commented 11 months ago

Thanks for offering your contribution, @clintg6

fengyang0317 commented 6 days ago

I created a torch_xla_rocm_plugin similar to the CUDA plugin and compiled it using --config=rocm. The program can find the ROCm XLA devices, and tensors can be placed on the ROCm XLA device.
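
For reference, a minimal sketch of the kind of placement check that works (it assumes the plugin is already registered and selected in the environment):

import torch
import torch_xla.core.xla_model as xm

# With the ROCm plugin registered, xla_device() should resolve to a ROCm-backed XLA device.
device = xm.xla_device()
t = torch.randn(2, 2).to(device)
print(t.device)  # e.g. xla:0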

But when running a sample training job, it failed. I had added /opt/rocm/lib/llvm/bin, which contains ld.lld, to PATH.

python test_train_spmd_imagenet.py --fake_data
==> Preparing data..
Epoch 1 train begin 13:18:36
Traceback (most recent call last):
  File "/home/aiscuser/xla/test/spmd/test_train_spmd_imagenet.py", line 389, in <module>
    accuracy = train_imagenet()
  File "/home/aiscuser/xla/test/spmd/test_train_spmd_imagenet.py", line 364, in train_imagenet
    train_loop_fn(train_loader, epoch)
  File "/home/aiscuser/xla/test/spmd/test_train_spmd_imagenet.py", line 318, in train_loop_fn
    with xp.StepTrace('train_imagenet'):
  File "/home/aiscuser/.conda/envs/xla/lib/python3.10/site-packages/torch_xla/debug/profiler.py", line 170, in __exit__
    xm.mark_step()
  File "/home/aiscuser/.conda/envs/xla/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1046, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: INTERNAL: unable to find ld.lld in PATH: No such file or directory
malloc(): unsorted double linked list corrupted
Aborted

fengyang0317 commented 5 days ago

After setting LLVM_PATH, it worked.
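
For anyone hitting the same error, a minimal sketch of the workaround (the exact path is an assumption based on the PATH entry mentioned above; adjust it to your ROCm install):

import os

# Assumption: LLVM_PATH should point at the ROCm LLVM install whose bin/ directory
# contains ld.lld, and should be set before XLA compiles anything.
os.environ.setdefault('LLVM_PATH', '/opt/rocm/lib/llvm')

import torch_xla  # import after the environment is configured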