pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

IterDataPipe + DataLoader hangs with PJRT #6815

Closed: w4seunggyu closed this issue 7 months ago

w4seunggyu commented 7 months ago

🐛 Bug

I observed that iterating through a DataLoader that wraps an IterDataPipe makes the program freeze. I'm reporting this with some caution: I hit it while using torchdata==0.7.1, and my own misunderstanding may be the cause. However, everything I tested involves PyTorch/XLA, so I'm filing the issue here.

A stack trace taken after a KeyboardInterrupt shows many lock.acquire() calls inside threading. I know that XRT and PJRT differ in how TPU cores are mapped to processes and threads, so the root cause might not be PJRT itself, but I'm not certain of anything yet.
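A hang like this can also be inspected without interrupting it: Python's standard-library faulthandler module can dump the stack of every Python thread after a timeout while the process keeps running. The following is a minimal, self-contained sketch (not the method used in this report, and independent of PyTorch/XLA) that simulates a worker thread blocked on a lock and returns the resulting dump. The function name demo_hang_dump is hypothetical.

```python
import faulthandler
import tempfile
import threading
import time


def demo_hang_dump(timeout_s: float = 0.5) -> str:
    """Simulate a thread stuck on a lock and return faulthandler's dump of all threads."""
    lock = threading.Lock()
    lock.acquire()  # the main thread holds the lock, so the worker will block

    def worker() -> None:
        # Blocks until the main thread releases the lock, mimicking a thread
        # stuck in lock.acquire() during the hang.
        with lock:
            pass

    t = threading.Thread(target=worker, daemon=True)
    t.start()

    with tempfile.TemporaryFile(mode="w+") as out:
        # After timeout_s seconds, write every Python thread's stack to `out`
        # without terminating the process (exit=False).
        faulthandler.dump_traceback_later(timeout_s, repeat=False, file=out, exit=False)
        time.sleep(timeout_s * 3)  # stay "hung" long enough for the dump to fire
        faulthandler.cancel_dump_traceback_later()

        lock.release()  # unblock the worker so it can finish
        t.join()

        out.seek(0)
        return out.read()


if __name__ == "__main__":
    print(demo_hang_dump())
```

faulthandler prints only Python frames, so a thread blocked in a C-level lock acquire shows up as the Python line that called it (here, the `with lock:` line in worker). Calling dump_traceback_later at the start of _mp_fn, with a timeout longer than a normal run, is one assumed way to apply this to the script below; it was not tried in this issue.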

To Reproduce

# env: PJRT_DEVICE='TPU'

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(rank: int):
    torch.distributed.init_process_group('xla', init_method='xla://')

    dataset = torch.utils.data.datapipes.iter.IterableWrapper(range(20))
    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,
        num_workers=1,
    )
    mp_device_loader = pl.MpDeviceLoader(train_dataloader, xm.xla_device())

    for num_iter, batch in enumerate(mp_device_loader):
        pass

    xm.rendezvous('')

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())

When the script above runs, the program never exits.

Expected behavior

The program should exit after iterating through the DataLoader.

Environment

I tested this on a Cloud TPU VM v2-8, running the script in a Docker container on the VM.

Additional context

I also tried XRT (2.1.0+xrt), changing the beginning of _mp_fn() to:

# env: XRT_TPU_CONFIG='localservice;0;localhost:51011'

import os

def _mp_fn(rank: int):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    torch.distributed.init_process_group(
        'xla',
        init_method='env://',
        rank=xm.get_ordinal(),
        world_size=xm.xrt_world_size(),
    )

    ...

With that change, the program finished as expected.

JackCaoG commented 7 months ago

@will-cromar can you take a look?

will-cromar commented 7 months ago

To use torch.distributed on TPU v2/v3, we have to actually patch torch at runtime to make it work with multithreading, which is not done by default. You'll have to add import torch_xla.experimental.pjrt_backend to do that. See our PJRT doc and our Kaggle example.

I discussed this with @JackCaoG offline. This patching is hacky, and we're not totally comfortable doing that by default. On the other hand, torch.distributed won't work on TPU v2/v3 without it. So we need to do one of the following:

w4seunggyu commented 7 months ago

Oh, this was already described in the docs. My apologies for missing it.

With the patch, the code works perfectly fine. Many thanks for your help!