🐛 Bug
While running the MNIST example with the DDP runtime, I noticed that it hangs when trying to initialize the process group, likely due to a multiprocessing issue.
To Reproduce
Run this modified version of the test_train_mp_mnist.py script using the command:
python3 test_train_mp_mnist.py --ddp --fake_data
import args_parse

MODEL_OPTS = {
    '--ddp': {
        'action': 'store_true',
    },
    '--pjrt_distributed': {
        'action': 'store_true',
    },
}

FLAGS = args_parse.parse_common_options(
    datadir='/tmp/mnist-data',
    batch_size=128,
    momentum=0.5,
    lr=0.01,
    target_accuracy=98.0,
    num_epochs=18,
    opts=MODEL_OPTS.items(),
)

import os
import shutil
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch_xla.distributed.xla_backend


class MNIST(nn.Module):

  def __init__(self):
    super(MNIST, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def _train_update(device, step, loss, tracker, epoch, writer):
  test_utils.print_training_update(
      device,
      step,
      loss.item(),
      tracker.rate(),
      tracker.global_rate(),
      epoch,
      summary_writer=writer)

def train_mnist(flags, **kwargs):
  if flags.pjrt_distributed:
    import torch_xla.experimental.pjrt_backend
    dist.init_process_group('xla', init_method='pjrt://')
  elif flags.ddp:
    dist.init_process_group(
        'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())
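    # NOTE: with --ddp, the init_process_group call above is where the hang
    # is observed (see Additional context below).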

  torch.manual_seed(1)

  if flags.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(flags.batch_size, 1, 28, 28),
              torch.zeros(flags.batch_size, dtype=torch.int64)),
        sample_count=60000 // flags.batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(flags.batch_size, 1, 28, 28),
              torch.zeros(flags.batch_size, dtype=torch.int64)),
        sample_count=10000 // flags.batch_size // xm.xrt_world_size())
  else:
    train_dataset = datasets.MNIST(
        os.path.join(flags.datadir, str(xm.get_ordinal())),
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    test_dataset = datasets.MNIST(
        os.path.join(flags.datadir, str(xm.get_ordinal())),
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    train_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=flags.batch_size,
        sampler=train_sampler,
        drop_last=flags.drop_last,
        shuffle=False if train_sampler else True,
        num_workers=flags.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=flags.batch_size,
        drop_last=flags.drop_last,
        shuffle=False,
        num_workers=flags.num_workers)

  # Scale learning rate to num cores
  lr = flags.lr * xm.xrt_world_size()

  device = xm.xla_device()
  model = MNIST().to(device)

  # Initialization is nondeterministic with multiple threads in PjRt.
  # Synchronize model parameters across replicas manually.
  if flags.pjrt_distributed:
    xm.broadcast_master_param(model)

  if flags.ddp:
    model = DDP(model, gradient_as_bucket_view=True)

  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(flags.logdir)
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
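      # With DDP, gradients are all-reduced during backward(), so a plain
      # optimizer.step() suffices; otherwise xm.optimizer_step() performs the
      # cross-replica gradient reduction before stepping.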
      if flags.ddp:
        optimizer.step()
      else:
        xm.optimizer_step(optimizer)
      tracker.add(flags.batch_size)
      if step % flags.log_steps == 0:
        xm.add_step_closure(
            _train_update,
            args=(device, step, loss, tracker, epoch, writer),
            run_async=flags.async_closures)

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]
    accuracy = 100.0 * correct.item() / total_samples
    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
    return accuracy

  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
  for epoch in range(1, flags.num_epochs + 1):
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train_loop_fn(train_device_loader, epoch)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
    accuracy = test_loop_fn(test_device_loader)
    xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
        epoch, test_utils.now(), accuracy))
    max_accuracy = max(accuracy, max_accuracy)
    test_utils.write_to_summary(
        writer,
        epoch,
        dict_to_write={'Accuracy/test': accuracy},
        write_xla_metrics=True)
    if flags.metrics_debug:
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy

def _mp_fn(index, flags):
  os.environ['MASTER_ADDR'] = 'localhost'
  os.environ['MASTER_PORT'] = '12355'
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy = train_mnist(flags)
  if flags.tidy and os.path.isdir(flags.datadir):
    shutil.rmtree(flags.datadir)
  if accuracy < flags.target_accuracy:
    print('Accuracy {} is below target {}'.format(accuracy,
                                                  flags.target_accuracy))
    sys.exit(21)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
Expected behavior
The script is able to run MNIST training with DDP to completion instead of hanging at process group initialization.
Environment
Reproducible on XLA backend: TPU v3-8
torch_xla version: 2.0.0+cu117
torch version: 2.0
Additional context
After the program hung, I exited, and this is the stack trace I got:
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
replica_results = list(
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
itertools.chain.from_iterable(
File "/usr/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
for element in iterable:
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 439, in result
self._condition.wait(timeout)
File "/usr/lib/python3.8/threading.py", line 302, in wait
waiter.acquire()
KeyboardInterrupt
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "test_train_mp_mnist.py", line 231, in <module>
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 386, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 365, in spawn
_run_multiprocess(spawn_fn, start_method=start_method)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
replica_results = list(
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 644, in __exit__
self.shutdown(wait=True)
File "/usr/lib/python3.8/concurrent/futures/process.py", line 686, in shutdown
self._queue_management_thread.join()
File "/usr/lib/python3.8/threading.py", line 1011, in join
self._wait_for_tstate_lock()
File "/usr/lib/python3.8/threading.py", line 1027, in _wait_for_tstate_lock
elif lock.acquire(block, timeout):
KeyboardInterrupt
^CError in atexit._run_exitfuncs:
Traceback (most recent call last):
File "/usr/lib/python3.8/concurrent/futures/process.py", line 104, in _python_exit
Process SpawnProcess-4:
Process SpawnProcess-2:
Process SpawnProcess-1:
t.join()
File "/usr/lib/python3.8/threading.py", line 1011, in join
self._wait_for_tstate_lock()
File "/usr/lib/python3.8/threading.py", line 1027, in _wait_for_tstate_lock
elif lock.acquire(block, timeout):
KeyboardInterrupt
Traceback (most recent call last):
File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/usr/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
call_item = call_queue.get(block=True)
File "/usr/lib/python3.8/multiprocessing/queues.py", line 97, in get
res = self._recv_bytes()
File "/usr/lib/python3.8/multiprocessing/connection.py", line 216, in recv_bytes
buf = self._recv_bytes(maxlength)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 414, in _recv_bytes
buf = self._recv(4)
Traceback (most recent call last):
Traceback (most recent call last):
File "/usr/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
chunk = read(handle, remaining)
KeyboardInterrupt
File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/usr/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
call_item = call_queue.get(block=True)
File "/usr/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
call_item = call_queue.get(block=True)
File "/usr/lib/python3.8/multiprocessing/queues.py", line 96, in get
with self._rlock:
File "/usr/lib/python3.8/multiprocessing/queues.py", line 96, in get
with self._rlock:
File "/usr/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
return self._semlock.__enter__()
File "/usr/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
return self._semlock.__enter__()
KeyboardInterrupt
KeyboardInterrupt
^CError in atexit._run_exitfuncs:
Error in atexit._run_exitfuncs:
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
Traceback (most recent call last):
File "/usr/lib/python3.8/multiprocessing/popen_fork.py", line 27, in poll
Traceback (most recent call last):
File "/usr/lib/python3.8/concurrent/futures/thread.py", line 40, in _python_exit
File "/usr/lib/python3.8/concurrent/futures/thread.py", line 40, in _python_exit
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
File "/usr/lib/python3.8/concurrent/futures/thread.py", line 40, in _python_exit
pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt
t.join()
t.join()
File "/usr/lib/python3.8/threading.py", line 1011, in join
File "/usr/lib/python3.8/threading.py", line 1011, in join
t.join()
File "/usr/lib/python3.8/threading.py", line 1011, in join
self._wait_for_tstate_lock()
self._wait_for_tstate_lock()
self._wait_for_tstate_lock()
File "/usr/lib/python3.8/threading.py", line 1027, in _wait_for_tstate_lock
File "/usr/lib/python3.8/threading.py", line 1027, in _wait_for_tstate_lock
File "/usr/lib/python3.8/threading.py", line 1027, in _wait_for_tstate_lock
elif lock.acquire(block, timeout):
KeyboardInterrupt
elif lock.acquire(block, timeout):
elif lock.acquire(block, timeout):
KeyboardInterrupt
KeyboardInterrupt
terminate called without an active exception
https://symbolize.stripped_domain/r/?trace=7f05c690000b,7f05c690008f&map=
*** SIGABRT received by PID 282231 (TID 283863) on cpu 29 from PID 282231; stack trace: ***
PC: @ 0x7f05c690000b (unknown) raise
@ 0x7f0486e19a1a 1152 (unknown)
@ 0x7f05c6900090 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7f05c690000b,7f0486e19a19,7f05c690008f&map=ceee8fa20ddf9c34af43f587221e91de:7f0479ef1000-7f0487030840
E0830 04:14:37.610973 283863 coredump_hook.cc:414] RAW: Remote crash data gathering hook invoked.
E0830 04:14:37.610994 283863 coredump_hook.cc:453] RAW: Skipping coredump since rlimit was 0 at process start.
E0830 04:14:37.611009 283863 client.cc:278] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0830 04:14:37.611016 283863 coredump_hook.cc:512] RAW: Sending fingerprint to remote end.
E0830 04:14:37.611024 283863 coredump_socket.cc:120] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0830 04:14:37.611032 283863 coredump_hook.cc:518] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0830 04:14:37.611040 283863 coredump_hook.cc:580] RAW: Dumping core locally.
humzaiqbal@t1v-n-5fec587b-w-0:~/xla/test$ E0830 04:14:37.805479 283863 process_state.cc:784] RAW: Raising signal 6 with default behavior
which seems to suggest some sort of multiprocessing issue. Based on some print statements, it seems to get stuck when this line is called:

dist.init_process_group(
    'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())
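
To narrow this down, here is a stripped-down sketch (untested, using the same rendezvous settings and imports as the full script above) that does nothing but initialize the process group under xmp.spawn; if dist.init_process_group is the culprit, this minimal version should hang the same way:

import os
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend  # registers the 'xla' backend


def _mp_fn(index):
  # Same rendezvous settings as the full script.
  os.environ['MASTER_ADDR'] = 'localhost'
  os.environ['MASTER_PORT'] = '12355'
  # The call under suspicion: on TPU v3-8 this is where the full script hangs.
  dist.init_process_group(
      'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())
  print('rank {} initialized'.format(xm.get_ordinal()), flush=True)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())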