Open · Mohamed-Dhouib opened this issue 1 year ago

Description & Motivation

I've experimented with PyTorch XLA on multiple NVIDIA A100 GPUs and observed that in most cases training is faster. So it would be really nice to have the option to use XLA for training in PyTorch Lightning.

The main advantage is faster training.

Additional context

Here is a code link to train a simple CNN on the MNIST dataset using XLA (on 2 GPUs): https://github.com/Dhouib-med/Test-XLA/blob/17e5b6bd6c77fffa67818462856277a57877ff3b/test_xla.py. The main parts were taken from https://github.com/pytorch/xla. This wheel needs to be installed, along with the matching torch and torchvision versions (1.13 and 0.14): https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp37-cp37m-linux_x86_64.whl @justusschock

cc @borda @justusschock @awaelchli @carmocca @JackCaoG @steventk-g @Liyang90
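The linked script is not reproduced here; below is a condensed, hypothetical sketch of the raw torch_xla multi-GPU training pattern it follows. The model and data are stand-ins, and the environment variable applies to the XRT-era wheel referenced above:

```python
import os
os.environ.setdefault("GPU_NUM_DEVICES", "2")  # XRT-era GPU selection; newer wheels use PJRT_DEVICE

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()  # one XLA:GPU device per process
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for step in range(10):
        # Random stand-ins for MNIST batches; the real script uses a DataLoader.
        data = torch.randn(32, 1, 28, 28, device=device)
        target = torch.randint(0, 10, (32,), device=device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        # All-reduces gradients across XLA devices, steps the optimizer, and
        # (with barrier=True) flushes the lazily traced graph each iteration.
        xm.optimizer_step(optimizer, barrier=True)


if __name__ == "__main__":
    xmp.spawn(_mp_fn, nprocs=2)  # two processes, one per GPU
```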
Let's do it! To be clear, this would be enabled with: Trainer(accelerator='cuda'|'gpu', strategy='xla')
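For concreteness, a minimal sketch of what that could look like in user code once supported. This is hypothetical, since the feature does not exist yet, and the module and dataloader are placeholders:

```python
import lightning.pytorch as pl  # assumption: current 2.x package layout

# Hypothetical usage once XLA-on-GPU lands: the accelerator picks the
# hardware (CUDA), while the strategy routes execution through torch_xla.
trainer = pl.Trainer(
    accelerator="cuda",  # or "gpu"
    strategy="xla",
    devices=2,
)
# trainer.fit(MyLightningModule(), train_dataloaders=train_loader)  # placeholders
```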
@carmocca I assume we could reuse a lot of our current XLA strategy for TPUs.
That would be part of the goal.
I like it, and I think it won't even be that hard! The abstractions of strategy and accelerator are already in place and are meant to support exactly this kind of relationship between a communication layer (XLA) and an accelerator (GPU/TPU).
The first step towards this will be to simply rename our TPUSpawnStrategy to XLAStrategy (which is what we already planned to do, and have already done in lightning_lite).
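A rough sketch of how that pairing could look at the class level. Note this is exactly the composition the issue proposes and does not work yet, and the import paths assume the post-rename package layout:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import CUDAAccelerator
from lightning.pytorch.strategies import XLAStrategy  # renamed from TPUSpawnStrategy

# Proposed (not yet working) pairing: CUDA hardware combined with the XLA
# communication/compilation layer, mirroring today's TPU + XLAStrategy setup.
trainer = Trainer(
    accelerator=CUDAAccelerator(),
    strategy=XLAStrategy(),
    devices=2,
)
```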
This is great! :rabbit:
Hello, this is wonderful work! I would like to know when Trainer(accelerator='cuda'|'gpu', strategy='xla') will work normally.
@qipengh We haven't started working on it. The feature is up for grabs if you or anyone from the community has interest in contributing and testing it out.
This should become very easy once we add support for XLA's PJRT runtime: https://github.com/pytorch/xla/blob/master/docs/pjrt.md#gpu
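For reference, a minimal raw torch_xla example of targeting a single GPU through PJRT, based on the torch_xla docs. The accepted values of PJRT_DEVICE have changed across releases, so treat this as a sketch:

```python
import os
# PJRT device selection must be set before torch_xla initializes.
# Assumption: recent torch_xla accepts PJRT_DEVICE=CUDA (older releases used GPU).
os.environ.setdefault("PJRT_DEVICE", "CUDA")
os.environ.setdefault("GPU_NUM_DEVICES", "1")

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()               # resolves to an XLA:GPU device under PJRT
x = torch.randn(4, 4, device=device)
y = (x @ x).sum()
xm.mark_step()                         # flush the lazily traced graph to the GPU
print(y.item())
```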
FYI @Liyang90 has a PR to add PJRT support: https://github.com/Lightning-AI/lightning/pull/17352
In addition, we need to land
Is there an example of how to use XLA with a (single) CUDA GPU? The link above now 404s, and I am struggling to find one anywhere; currently everything I come across is for TPUs only.
Roughly how much work do folks think is still needed in order to implement this FR?