@will-cromar can you take a look?
To use `torch.distributed` on TPU v2/v3, we have to actually patch `torch` at runtime to make it work with multithreading, which is not done by default. You'll have to add `import torch_xla.experimental.pjrt_backend` to do that. See our PJRT doc and our Kaggle example.
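For context, a minimal sketch of what that looks like on a TPU v2/v3 VM, loosely following the PJRT guide (the collective op and tensor contents here are just illustrative, not from the issue):

```python
# Minimal sketch of torch.distributed under PJRT on TPU v2/v3 (illustrative only).
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.experimental.pjrt_backend  # patches torch for multithreading on v2/v3

def _mp_fn(index):
    # The xla:// init method picks up rank and world size from the PJRT runtime.
    dist.init_process_group('xla', init_method='xla://')
    device = xm.xla_device()
    t = torch.ones(2, device=device) * dist.get_rank()
    dist.all_reduce(t)  # sum across all TPU workers
    xm.mark_step()
    print(t)

if __name__ == '__main__':
    xmp.spawn(_mp_fn)
```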
I discussed this with @JackCaoG offline. This patching is hacky, and we're not totally comfortable doing it by default. On the other hand, `torch.distributed` won't work on TPU v2/v3 without it. So we need to do one of the following:

- Patch `torch.distributed` by default when it's required and log a warning
- Keep requiring the explicit `import torch_xla.experimental.pjrt_backend`
Oh, it was already described in the doc. My apologies for missing that.
With the patch, the code works perfectly fine. Many thanks for your help!
🐛 Bug
I observed that trying to iterate through a DataLoader that wraps an IterDataPipe makes the program freeze. I'm a bit hesitant to call this a bug, because I observed the behavior while using `torchdata==0.7.1` and my own misunderstanding might have caused it, but everything I experimented with involves PyTorch/XLA, so I'm filing the issue here.

A stack trace after `KeyboardInterrupt` shows a bunch of `lock.acquire()` calls in `threading`, and I know XRT and PJRT differ in how TPU cores are mapped to processes and threads. So the root cause might not be PJRT, but I'm not sure.

To Reproduce
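A minimal sketch of the kind of setup described, a DataLoader wrapping an IterDataPipe inside an `xmp.spawn` worker, might look like this (the pipeline contents, batch size, and worker count are assumptions, not the original script):

```python
# Illustrative sketch only; not the original reproduction script.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper  # torchdata==0.7.1

def _mp_fn(index):
    device = xm.xla_device()
    pipe = IterableWrapper(range(64))             # a simple IterDataPipe
    loader = DataLoader(pipe, batch_size=8, num_workers=2)
    for batch in loader:                          # the reported freeze happens while iterating
        batch = batch.to(device)
        xm.mark_step()

if __name__ == '__main__':
    xmp.spawn(_mp_fn)
```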
Running the script above shows that the program never exits.
Expected behavior
The program exits after iterating through the DataLoader.
Environment
I tested it on a Cloud TPU VM v2-8 by running a Docker container on top of it.
Additional context
I also tried XRT (`2.1.0+xrt`), changing the beginning of `_mp_fn()`, and it finished as I wanted.
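The exact change isn't shown here; under XRT, a typical `_mp_fn()` prologue for the `xla` process-group backend looks roughly like the sketch below (the rendezvous address, port, and the rest of the function body are assumptions):

```python
# Rough sketch of an XRT-style _mp_fn() prologue; not the author's exact change.
import os
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' process-group backend

def _mp_fn(index):
    os.environ.setdefault('MASTER_ADDR', 'localhost')  # placeholder rendezvous address
    os.environ.setdefault('MASTER_PORT', '12355')      # placeholder port
    dist.init_process_group('xla',
                            rank=xm.get_ordinal(),
                            world_size=xm.xrt_world_size())
    # ... the rest of the function (DataLoader loop, etc.) stays the same ...
```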