pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla
Other
2.45k stars 462 forks source link

Thread hanging when running DDP runtime on MNIST example #5524

Closed humzaiqbal closed 1 year ago

humzaiqbal commented 1 year ago

🐛 Bug

While running the MNIST example with DDP, I noticed that it hangs when trying to initialize the process group, likely due to a multiprocessing issue.

To Reproduce

Run this modified version of the test_train_mp_mnist.py script using the following command:

python3 test_train_mp_mnist.py --ddp --fake_data
import args_parse

# Extra boolean command-line switches, added on top of the common options,
# that select how the distributed process group is set up in train_mnist.
MODEL_OPTS = {
    switch: {'action': 'store_true'}
    for switch in ('--ddp', '--pjrt_distributed')
}

# Flags are parsed at import time, before the torch / torch_xla imports that
# follow in this file.
FLAGS = args_parse.parse_common_options(
    datadir='/tmp/mnist-data',
    batch_size=128,
    momentum=0.5,
    lr=0.01,
    target_accuracy=98.0,
    num_epochs=18,
    opts=MODEL_OPTS.items(),
)

import os
import shutil
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch_xla.distributed.xla_backend

class MNIST(nn.Module):
  """Small CNN classifier for 1x28x28 images returning 10-class log-probs.

  Two stages of conv -> 2x2 max-pool -> ReLU -> batch norm, followed by two
  fully connected layers. fc1's 320 inputs correspond to 20 channels of 4x4
  feature maps, which is what a 28x28 input produces after both stages.
  """

  def __init__(self):
    super().__init__()
    # Registration order determines both seeded initialization and the
    # state_dict key order, so keep it stable.
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    """Return log-softmax class scores of shape (N, 10) for a batch x."""
    stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
    for conv, norm in stages:
      x = norm(F.relu(F.max_pool2d(conv(x), 2)))
    hidden = F.relu(self.fc1(x.flatten(start_dim=1)))
    return F.log_softmax(self.fc2(hidden), dim=1)

def _train_update(device, step, loss, tracker, epoch, writer):
  """Report one training-progress line via test_utils.

  Registered with ``xm.add_step_closure`` from the training loop. The loss
  tensor is converted to a Python number, the tracker's current and global
  rates are read, and everything is forwarded to
  ``test_utils.print_training_update`` (with ``writer`` as summary writer,
  which may be None on non-master replicas).
  """
  loss_value = loss.item()
  current_rate = tracker.rate()
  overall_rate = tracker.global_rate()
  test_utils.print_training_update(
      device, step, loss_value, current_rate, overall_rate, epoch,
      summary_writer=writer)

def _get_data_loaders(flags, world_size, ordinal):
  """Build the (train_loader, test_loader) pair for this replica.

  With ``flags.fake_data`` both loaders are ``xu.SampleGenerator`` objects
  that repeatedly yield one all-zero batch, sized so the whole run covers
  60000 / 10000 samples split across ``world_size`` replicas. Otherwise they
  are ``DataLoader`` objects over torchvision MNIST stored under
  ``flags.datadir/<ordinal>``; the training set is sharded with a
  ``DistributedSampler`` when more than one replica is running.
  """
  if flags.fake_data:

    def fake_batch():
      return (torch.zeros(flags.batch_size, 1, 28, 28),
              torch.zeros(flags.batch_size, dtype=torch.int64))

    train_loader = xu.SampleGenerator(
        data=fake_batch(),
        sample_count=60000 // flags.batch_size // world_size)
    test_loader = xu.SampleGenerator(
        data=fake_batch(),
        sample_count=10000 // flags.batch_size // world_size)
    return train_loader, test_loader

  # One download directory per replica (presumably so concurrent downloads
  # from several processes do not collide).
  root = os.path.join(flags.datadir, str(ordinal))
  transform = transforms.Compose(
      [transforms.ToTensor(),
       transforms.Normalize((0.1307,), (0.3081,))])
  train_dataset = datasets.MNIST(
      root, train=True, download=True, transform=transform)
  test_dataset = datasets.MNIST(
      root, train=False, download=True, transform=transform)

  train_sampler = None
  if world_size > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=world_size, rank=ordinal, shuffle=True)
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=flags.batch_size,
      sampler=train_sampler,
      drop_last=flags.drop_last,
      # DataLoader rejects shuffle=True together with a sampler. Test with
      # `is None`: a sampler's truthiness is its len(), which can be 0.
      shuffle=train_sampler is None,
      num_workers=flags.num_workers)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=flags.batch_size,
      drop_last=flags.drop_last,
      shuffle=False,
      num_workers=flags.num_workers)
  return train_loader, test_loader


def train_mnist(flags, **kwargs):
  """Train and evaluate the MNIST model on this replica.

  Args:
    flags: parsed command-line options. ``flags.pjrt_distributed`` initializes
      the ``xla`` process group via ``init_method='pjrt://'``; ``flags.ddp``
      wraps the model in ``DistributedDataParallel``.
    **kwargs: unused; accepted for call-site compatibility.

  Returns:
    The best test accuracy (percent, averaged across replicas with
    ``xm.mesh_reduce``) seen over all epochs.
  """
  if flags.pjrt_distributed:
    # Imported only for its side effects (presumably registering the
    # pjrt:// init method used on the next line).
    import torch_xla.experimental.pjrt_backend
    dist.init_process_group('xla', init_method='pjrt://')
  elif flags.ddp:
    # No init_method: torch.distributed falls back to env:// and reads
    # MASTER_ADDR / MASTER_PORT (set by _mp_fn). NOTE: per the issue
    # discussion this can look like a hang on first run while XLA compiles.
    dist.init_process_group(
        'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())

  torch.manual_seed(1)

  world_size = xm.xrt_world_size()
  ordinal = xm.get_ordinal()
  train_loader, test_loader = _get_data_loaders(flags, world_size, ordinal)

  # Scale learning rate to num cores
  lr = flags.lr * world_size

  device = xm.xla_device()
  model = MNIST().to(device)

  # Initialization is nondeterministic with multiple threads in PjRt, so
  # replicas must be synchronized to rank 0's weights. DDP's constructor
  # broadcasts rank 0's state itself; without DDP the xm.optimizer_step path
  # would otherwise train diverging replicas, so broadcast manually there.
  if flags.pjrt_distributed or not flags.ddp:
    xm.broadcast_master_param(model)

  if flags.ddp:
    model = DDP(model, gradient_as_bucket_view=True)
  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(flags.logdir)
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader, epoch):
    """Run one training epoch over ``loader``."""
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      if flags.ddp:
        # DDP synchronizes gradients during backward(); a plain step suffices.
        optimizer.step()
      else:
        xm.optimizer_step(optimizer)
      tracker.add(flags.batch_size)
      if step % flags.log_steps == 0:
        xm.add_step_closure(
            _train_update,
            args=(device, step, loss, tracker, epoch, writer),
            run_async=flags.async_closures)

  def test_loop_fn(loader):
    """Return this epoch's test accuracy, averaged across replicas."""
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]

    # With an empty loader ``correct`` is still the int 0 (no .item()) and
    # total_samples is 0. mesh_reduce is still reached unconditionally so
    # every replica takes part in the cross-replica reduction.
    if total_samples:
      accuracy = 100.0 * correct.item() / total_samples
    else:
      accuracy = 0.0
    return xm.mesh_reduce('test_accuracy', accuracy, np.mean)

  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  max_accuracy = 0.0
  try:
    for epoch in range(1, flags.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
      train_loop_fn(train_device_loader, epoch)
      xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

      accuracy = test_loop_fn(test_device_loader)
      xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
          epoch, test_utils.now(), accuracy))
      max_accuracy = max(accuracy, max_accuracy)
      test_utils.write_to_summary(
          writer,
          epoch,
          dict_to_write={'Accuracy/test': accuracy},
          write_xla_metrics=True)
      if flags.metrics_debug:
        xm.master_print(met.metrics_report())
  finally:
    # Close the summary writer even if training raises.
    test_utils.close_summary_writer(writer)

  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy

def _mp_fn(index, flags):
  """Per-process entry point launched by ``xmp.spawn``.

  Args:
    index: process index supplied by ``xmp.spawn`` (not used here).
    flags: parsed command-line options.

  Exits the process with status 21 when the best test accuracy is below
  ``flags.target_accuracy``.
  """
  # Rendezvous address for the env:// initialization of the 'xla' process
  # group used by --ddp. setdefault keeps the previous defaults but lets the
  # caller override them through the environment (e.g. a port already in use).
  os.environ.setdefault('MASTER_ADDR', 'localhost')
  os.environ.setdefault('MASTER_PORT', '12355')
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy = train_mnist(flags)
  if flags.tidy and os.path.isdir(flags.datadir):
    # Every spawned replica runs this cleanup on the same shared directory,
    # so one replica may delete entries another is still walking; ignore
    # those errors instead of crashing the replica that loses the race.
    shutil.rmtree(flags.datadir, ignore_errors=True)
  if accuracy < flags.target_accuracy:
    print('Accuracy {} is below target {}'.format(accuracy,
                                                  flags.target_accuracy))
    sys.exit(21)

if __name__ == '__main__':
  # Launch the per-replica training entry point; every process gets FLAGS.
  xmp.spawn(_mp_fn, nprocs=FLAGS.num_cores, args=(FLAGS,))

Expected behavior

The script should be able to initialize the process group and train with DDP without hanging.

Environment

Additional context

After the program hung, I interrupted it and got this stack trace:

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
    replica_results = list(
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 439, in result
    self._condition.wait(timeout)
  File "/usr/lib/python3.8/threading.py", line 302, in wait
    waiter.acquire()
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test_train_mp_mnist.py", line 231, in <module>
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 386, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 365, in spawn
    _run_multiprocess(spawn_fn, start_method=start_method)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
    replica_results = list(
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 644, in __exit__
    self.shutdown(wait=True)
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 686, in shutdown
    self._queue_management_thread.join()
  File "/usr/lib/python3.8/threading.py", line 1011, in join
    self._wait_for_tstate_lock()
  File "/usr/lib/python3.8/threading.py", line 1027, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt
^CError in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 104, in _python_exit
Process SpawnProcess-4:
Process SpawnProcess-2:
Process SpawnProcess-1:
    t.join()
  File "/usr/lib/python3.8/threading.py", line 1011, in join
    self._wait_for_tstate_lock()
  File "/usr/lib/python3.8/threading.py", line 1027, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 97, in get
    res = self._recv_bytes()
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 96, in get
    with self._rlock:
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 96, in get
    with self._rlock:
  File "/usr/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/usr/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt
KeyboardInterrupt
^CError in atexit._run_exitfuncs:
Error in atexit._run_exitfuncs:
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/popen_fork.py", line 27, in poll
Traceback (most recent call last):
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 40, in _python_exit
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 40, in _python_exit
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 40, in _python_exit
    pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt
    t.join()
    t.join()
  File "/usr/lib/python3.8/threading.py", line 1011, in join
  File "/usr/lib/python3.8/threading.py", line 1011, in join
    t.join()
  File "/usr/lib/python3.8/threading.py", line 1011, in join
    self._wait_for_tstate_lock()
    self._wait_for_tstate_lock()
    self._wait_for_tstate_lock()
  File "/usr/lib/python3.8/threading.py", line 1027, in _wait_for_tstate_lock
  File "/usr/lib/python3.8/threading.py", line 1027, in _wait_for_tstate_lock
  File "/usr/lib/python3.8/threading.py", line 1027, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt
    elif lock.acquire(block, timeout):
    elif lock.acquire(block, timeout):
KeyboardInterrupt
KeyboardInterrupt
terminate called without an active exception
https://symbolize.stripped_domain/r/?trace=7f05c690000b,7f05c690008f&map=
*** SIGABRT received by PID 282231 (TID 283863) on cpu 29 from PID 282231; stack trace: ***
PC: @     0x7f05c690000b  (unknown)  raise
    @     0x7f0486e19a1a       1152  (unknown)
    @     0x7f05c6900090  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f05c690000b,7f0486e19a19,7f05c690008f&map=ceee8fa20ddf9c34af43f587221e91de:7f0479ef1000-7f0487030840
E0830 04:14:37.610973  283863 coredump_hook.cc:414] RAW: Remote crash data gathering hook invoked.
E0830 04:14:37.610994  283863 coredump_hook.cc:453] RAW: Skipping coredump since rlimit was 0 at process start.
E0830 04:14:37.611009  283863 client.cc:278] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0830 04:14:37.611016  283863 coredump_hook.cc:512] RAW: Sending fingerprint to remote end.
E0830 04:14:37.611024  283863 coredump_socket.cc:120] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0830 04:14:37.611032  283863 coredump_hook.cc:518] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0830 04:14:37.611040  283863 coredump_hook.cc:580] RAW: Dumping core locally.

humzaiqbal@t1v-n-5fec587b-w-0:~/xla/test$ E0830 04:14:37.805479  283863 process_state.cc:784] RAW: Raising signal 6 with default behavior

which seems to suggest some sort of multiprocessing issue. Based on some print statements, it appears to hang when this line is called:

dist.init_process_group(
        'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())
humzaiqbal commented 1 year ago

It seems this was just a case of XLA needing to compile the computation graph, which takes some time at startup.