jolibrain / joliGEN

Generative AI Image Toolset with GANs and Diffusion for Real-World Applications
https://www.joligen.com

Using TPUs #554

Open hsleiman1 opened 9 months ago

hsleiman1 commented 9 months ago

Hello,

If we plan to use TPUs instead of GPUs, is this possible with the current configuration, or should we use a different one?

Thanks

beniz commented 9 months ago

Hi, my understanding from the PyTorch/Google TPU docs is that it requires importing XLA and creating a device. So I believe the device would be created along these lines:

# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm

and

device = xm.xla_device()

Then change the device here: https://github.com/jolibrain/joliGEN/blob/master/models/base_model.py#L87. It will also almost certainly be necessary to guard some CUDA-specific calls behind the use_cuda config in train.py and models/base_model.py.
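For illustration only, a minimal sketch of how that device selection could look. The get_device helper and its fallback logic are assumptions, not joliGEN's actual code:

# Sketch (assumed helper, not joliGEN's current code): prefer an XLA (TPU)
# device when torch_xla is installed, otherwise fall back to CUDA/CPU.
import torch

def get_device(use_cuda=True):
    try:
        import torch_xla.core.xla_model as xm
        return xm.xla_device()  # TPU core exposed through XLA
    except ImportError:
        pass
    if use_cuda and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")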

We can look at it, good feature to have!