clintg6 opened 11 months ago
AFAIK we didn't test AMD GPU, it is supported by XLA:GPU but we might need a separate build or something. @vanbasten23 should have more info.
Seems like it could, since it's listed as a deviceType in device.h. However, when I look at the XLA:GPU guide I only see mentions of CUDA. Similarly, I don't see a ROCm wheel or Docker image in the official list.
Jack is right. AFAICT, we have never tested AMD GPUs at this point, and we would need a special build for ROCm (currently we only support the XLA_CUDA=1 flag).
@vanbasten23 What would we need to do to build it? Is there anything I can do to facilitate it?
For the CUDA build, we build torch_xla with `XLA_CUDA=1 python setup.py install`, which adds `--config=cuda` (https://github.com/pytorch/xla/blob/64d5807b9b6399522081a725b438b8276b7d7aa2/setup.py#L285) when it builds OpenXLA.
So similarly for ROCm, I would imagine we need a new flag in torch_xla such as `XLA_ROCM`. When set, it should add `--config=rocm` and pass it to OpenXLA.
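In case it helps, here's a minimal sketch of how such an environment flag could map to a Bazel `--config` option. All names here are hypothetical illustrations, not the actual setup.py code; only `XLA_CUDA` exists today, and `XLA_ROCM` is the proposed addition:

```python
import os

# Hypothetical mapping from build-time environment variables to the
# Bazel --config flags passed through when building OpenXLA.
_ENV_TO_CONFIG = {
    "XLA_CUDA": "cuda",  # existing flag
    "XLA_ROCM": "rocm",  # proposed flag for AMD GPUs
}


def _check_env_flag(name: str, default: str = "0") -> bool:
    """Treat '1', 'true', or 'yes' (case-insensitive) as enabled."""
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


def bazel_config_flags() -> list:
    """Return the --config=... flags implied by the current environment."""
    return [
        "--config=" + cfg
        for env, cfg in _ENV_TO_CONFIG.items()
        if _check_env_flag(env)
    ]


if __name__ == "__main__":
    os.environ["XLA_ROCM"] = "1"
    print(bazel_config_flags())
```

The real change would also need the ROCm toolchain wired up on the Bazel side, but the Python-side plumbing is roughly this shape.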
I'll test this and report back, thanks!
Thanks for offering your contribution, @clintg6!
I created a `torch_xla_rocm_plugin` similar to the CUDA plugin and compiled it using `--config=rocm`. The program can find the ROCm XLA devices, and tensors can be placed on the ROCm XLA device.
But when running a sample training job, it failed. I had added /opt/rocm/lib/llvm/bin, which contains ld.lld, to PATH.
```
python test_train_spmd_imagenet.py --fake_data
==> Preparing data..
Epoch 1 train begin 13:18:36
Traceback (most recent call last):
  File "/home/aiscuser/xla/test/spmd/test_train_spmd_imagenet.py", line 389, in <module>
    accuracy = train_imagenet()
  File "/home/aiscuser/xla/test/spmd/test_train_spmd_imagenet.py", line 364, in train_imagenet
    train_loop_fn(train_loader, epoch)
  File "/home/aiscuser/xla/test/spmd/test_train_spmd_imagenet.py", line 318, in train_loop_fn
    with xp.StepTrace('train_imagenet'):
  File "/home/aiscuser/.conda/envs/xla/lib/python3.10/site-packages/torch_xla/debug/profiler.py", line 170, in __exit__
    xm.mark_step()
  File "/home/aiscuser/.conda/envs/xla/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1046, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: INTERNAL: unable to find ld.lld in PATH: No such file or directory
malloc(): unsorted double linked list corrupted
Aborted
```
After setting LLVM_PATH, it worked.
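For anyone hitting the same `unable to find ld.lld in PATH` error, here is a small sketch of how one might verify the linker is discoverable before launching training. The `find_ld_lld` helper is hypothetical (not part of torch_xla); `/opt/rocm/lib/llvm/bin` is just the directory mentioned above:

```python
import os
import shutil


def find_ld_lld(extra_dirs=()):
    """Search PATH, then any extra directories, for the ld.lld linker
    that XLA's ROCm backend shells out to. Returns the full path or None."""
    search_path = os.pathsep.join([os.environ.get("PATH", ""), *extra_dirs])
    return shutil.which("ld.lld", path=search_path)


if __name__ == "__main__":
    # Check PATH plus the ROCm LLVM directory mentioned in this thread.
    print(find_ld_lld(["/opt/rocm/lib/llvm/bin"]))
```

If this prints `None`, the XLA runtime will fail the same way, so exporting the directory (via PATH or LLVM_PATH) before running is the fix.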
Questions
Does the current version of PyTorch/XLA offer support for ROCm with AMD GPUs?