pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Processes hang when running multiprocess training with xmp.spawn, and one rank waits at rendezvous while another hits assertion #7974

Open jeffhataws opened 2 months ago

jeffhataws commented 2 months ago

🐛 Bug

When I run two-worker multiprocess training with xmp.spawn, where one rank waits at the rendezvous while the other hits an assertion, the run hangs at the assertion instead of exiting:

(pt25dev_env) ubuntu@ip-10-3-206-226:~$ PJRT_DEVICE=CPU CPU_NUM_DEVICES=2 python test_assert_hang.py 
WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.6. is deprecated. Use torch_xla.runtime.global_ordinal instead.
Rank 1 continues to rendezvous.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1725834192.702937  625990 collective_ops_utils.h:310] This thread has been waiting for 5000ms for and may be stuck: participant AllReduceParticipantData{rank=1, element_count=7840, type=F32, rendezvous_key=RendezvousKey{run_id=RunId: 0, global_devices=[0,1], num_local_participants=2, collective_op_kind=cross_replica, op_id=0}} waiting for all participants to arrive at rendezvous RendezvousKey{run_id=RunId: 0, global_devices=[0,1], num_local_participants=2, collective_op_kind=cross_replica, op_id=0}

(hangs forever)

Interrupting with Ctrl-C still leaves two zombie processes.

When running with torchrun instead, the run errors out properly:

(pt25dev_env) ubuntu@ip-10-3-206-226:~$ PJRT_DEVICE=CPU torchrun --nproc_per_node=2 test_assert_hang.py                    
W0908 22:30:28.670000 629534 torch/distributed/run.py:793] 
W0908 22:30:28.670000 629534 torch/distributed/run.py:793] *****************************************
W0908 22:30:28.670000 629534 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0908 22:30:28.670000 629534 torch/distributed/run.py:793] *****************************************
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.6. is deprecated. Use torch_xla.runtime.global_ordinal instead.
[rank0]: Traceback (most recent call last): 
[rank0]:   File "/home/ubuntu/test_assert_hang.py", line 49, in <module>
[rank0]:     _mp_fn(0)              
[rank0]:   File "/home/ubuntu/test_assert_hang.py", line 41, in _mp_fn
[rank0]:     assert (False), "Test rank 0 asserting (hangs with xmp.spawn, exits with error properly with torchrun."
[rank0]: AssertionError: Test rank 0 asserting (hangs with xmp.spawn, exits with error properly with torchrun.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.6. is deprecated. Use torch_xla.runtime.global_ordinal instead.
[rank0]: Traceback (most recent call last): 
[rank0]:   File "/home/ubuntu/test_assert_hang.py", line 49, in <module>
[rank0]:     _mp_fn(0)                                                                                                                                                        
[rank0]:   File "/home/ubuntu/test_assert_hang.py", line 41, in _mp_fn
[rank0]:     assert (False), "Test rank 0 asserting (hangs with xmp.spawn, exits with error properly with torchrun."
[rank0]: AssertionError: Test rank 0 asserting (hangs with xmp.spawn, exits with error properly with torchrun.
W0908 22:30:34.653000 629534 torch/distributed/elastic/multiprocessing/api.py:890] Sending process 629555 closing signal SIGTERM
E0908 22:30:34.684000 629534 torch/distributed/elastic/multiprocessing/api.py:862] failed (exitcode: 1) local_rank: 1 (pid: 629556) of binary: /home/ubuntu/pt25dev_env/bin/python
Traceback (most recent call last):
  File "/home/ubuntu/pt25dev_env/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch', 'console_scripts', 'torchrun')())
  File "/home/ubuntu/pt25dev_env/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "/home/ubuntu/pt25dev_env/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
    run(args)
  File "/home/ubuntu/pt25dev_env/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "/home/ubuntu/pt25dev_env/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/ubuntu/pt25dev_env/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

To Reproduce

Create a file test_assert_hang.py with the following content:

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend  # registers the "xla" backend and the xla:// init method

class MLP(nn.Module):
    def __init__(self, input_size=28 * 28, output_size=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, output_size, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        return F.log_softmax(x, dim=1)

def _mp_fn(index):
    dev = xm.xla_device()

    train_x = torch.rand((1, 28 * 28)).to(dev)
    train_label = torch.tensor([5]).to(dev)

    t_model = MLP().to(dev)
    t_model.train()
    loss_fn = nn.NLLLoss()
    optimizer = optim.SGD(t_model.parameters(), lr=0.01)
    t_model.zero_grad()
    output = t_model(train_x)
    t_loss = loss_fn(output, train_label)
    t_loss.backward()
    xm.optimizer_step(optimizer)
    if xm.get_ordinal() == 0:
        assert (False), "Test rank 0 asserting (hangs with xmp.spawn, exits with error properly with torchrun."
    else:
        print("Rank 1 continues to rendezvous.")
    xm.rendezvous("end")

if __name__ == '__main__':
    if os.environ.get("WORLD_SIZE"):
        dist.init_process_group("xla", init_method="xla://")
        _mp_fn(0)
    else:
        xmp.spawn(_mp_fn)

Run the xmp.spawn hang case with CPU on 2 workers:

PJRT_DEVICE=CPU CPU_NUM_DEVICES=2 python test_assert_hang.py

Run the torchrun no-hang case with CPU on 2 workers:

PJRT_DEVICE=CPU torchrun --nproc_per_node=2 test_assert_hang.py

Expected behavior

The run should exit with the assertion message rather than hang.

Environment

Additional context

JackCaoG commented 2 months ago

Hmm, interesting. It seems like xmp.spawn does not surface the assertion error from the spawned process. Do you know why?

jeffhataws commented 2 months ago

Sorry, I haven't had time to look into it more. The last time I looked, the threads were stuck waiting on a futex.

JackCaoG commented 2 months ago

OK, if you remove the xm.rendezvous("end"), this works. The issue is that the current multiprocess implementation does not handle an exception from a subprocess until all processes have finished. I am not an expert on this, but I asked a chatbot and it suggested updating the code to something like:

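import concurrent.futures
import functools
import itertools
import multiprocessing
import queue
from typing import Callable, Dict

# torch_xla, plugins, runtime, tpu, gpu, neuron, initialize_multiprocess,
# _merge_replica_results and the type variable R are assumed to come from
# the torch_xla module this function lives in.

def run_multiprocess(fn: Callable[..., R],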
                     *args,
                     start_method: str = 'spawn',
                     **kwargs) -> Dict[int, R]:
    """Runs `fn` on all devices available to PjRt.

    Spawns one process per physical device (e.g. TPU chip).

    Args:
      fn: Function to run on all devices
      args: args to pass to `fn`
      start_method: The Python `multiprocessing` process creation method.
        Default: `spawn`
      kwargs: kwargs to pass to `fn`

    Returns:
      Dict of the form {device_ordinal: return_value}, where
      return_value is the result of calling `fn`.

    Raises:
      Exception: If any subprocess raises an exception.
    """
    if torch_xla._XLAC._xla_runtime_is_initialized():
        raise RuntimeError('Runtime is already initialized. Do not use the XLA '
                           'device before calling xmp.spawn.')

    # Determine the number of processes
    if plugins.using_dynamic_plugins():
        num_processes = plugins.default().physical_chip_count()
    elif runtime.device_type() == 'TPU':
        num_processes = tpu.num_local_processes()
    elif runtime.device_type() == 'CUDA':
        num_processes = gpu.num_local_processes()
    elif runtime.device_type() == 'NEURON':
        num_processes = neuron.num_local_processes()
    else:
        num_processes = 1

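    # Create a queue for error reporting. A plain multiprocessing.Queue cannot
    # be pickled into a task submitted to the executor, so use a
    # Manager-backed queue proxy instead.
    manager = multiprocessing.get_context(start_method).Manager()
    error_queue = manager.Queue()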

    with concurrent.futures.ProcessPoolExecutor(
        max_workers=num_processes,
        mp_context=multiprocessing.get_context(start_method)) as executor:

        mp_fn = functools.partial(
            _run_thread_per_device,
            local_world_size=num_processes,
            fn=functools.partial(fn, *args, **kwargs),
            initializer_fn=initialize_multiprocess,
            error_queue=error_queue)

        futures = [executor.submit(mp_fn, i) for i in range(num_processes)]

        results = []
        while futures:
            # Check the error queue
            try:
                error = error_queue.get_nowait()
                # If we get here, an error was reported
                executor.shutdown(wait=False)
                raise error
            except queue.Empty:
                # No error reported, continue processing
                pass

            # Wait for the next future to complete
            done, futures = concurrent.futures.wait(
                futures, timeout=0.1, return_when=concurrent.futures.FIRST_COMPLETED)

            for future in done:
                try:
                    result = future.result()
                    results.append(result)
                except Exception:
                    executor.shutdown(wait=False)
                    raise  # Re-raise the exception to stop the main process

    replica_results = list(itertools.chain.from_iterable(result.items() for result in results))
    return _merge_replica_results(replica_results)

def _run_thread_per_device(device_ordinal: int,
                           local_world_size: int,
                           fn: Callable[[], R],
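                           initializer_fn: Callable[[int, int], None],
                           error_queue: multiprocessing.Queue) -> Dict[int, R]:
    try:
        # Assumes initializer_fn takes (local_rank, local_world_size), like
        # torch_xla's initialize_multiprocess.
        initializer_fn(device_ordinal, local_world_size)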
        result = fn()
        return {device_ordinal: result}
    except Exception as e:
        error_queue.put(e)
        raise  # Re-raise the exception to stop this process

This adds a try/except in _run_thread_per_device that puts the exception on a queue, while the main process keeps polling that queue for new exceptions. At first glance this should work. @will-cromar wdyt?
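One caveat worth checking before adopting the proposal: the Python docs for `concurrent.futures.Executor.shutdown` say that, regardless of `wait`, the program will not exit until all pending futures are done executing, so a rank blocked in `xm.rendezvous("end")` may still keep the parent alive. Below is a minimal sketch of the same fail-fast idea using plain `multiprocessing.Process`, so the parent can terminate the remaining ranks after the first failure. `run_fail_fast` and `_worker` are hypothetical names, not torch_xla APIs; the sketch skips all of the PJRT setup `xmp.spawn` performs and assumes a stuck rank can be stopped with SIGTERM.

```python
import multiprocessing
import queue
import traceback
from typing import Any, Callable, Dict


def _worker(rank: int, fn: Callable[[int], Any], result_queue) -> None:
    """Runs fn(rank) and reports (rank, ok, payload) to the parent."""
    try:
        result_queue.put((rank, True, fn(rank)))
    except BaseException:
        # Send the formatted traceback rather than the exception object,
        # since arbitrary exceptions are not guaranteed to be picklable.
        result_queue.put((rank, False, traceback.format_exc()))


def run_fail_fast(fn: Callable[[int], Any], nprocs: int,
                  start_method: str = "spawn",
                  poll_interval: float = 0.1) -> Dict[int, Any]:
    """Runs fn(rank) in nprocs processes (hypothetical helper).

    Returns {rank: result} if every rank succeeds. On the first reported
    failure, raises RuntimeError and terminates every rank still running.
    """
    ctx = multiprocessing.get_context(start_method)
    result_queue = ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(rank, fn, result_queue))
             for rank in range(nprocs)]
    for p in procs:
        p.start()

    results: Dict[int, Any] = {}
    try:
        while len(results) < nprocs:
            try:
                rank, ok, payload = result_queue.get(timeout=poll_interval)
            except queue.Empty:
                # Detect a rank that died without reporting (e.g. killed by a
                # signal); otherwise the parent would poll forever.
                for r, p in enumerate(procs):
                    if r not in results and p.exitcode not in (None, 0):
                        raise RuntimeError(
                            f"rank {r} exited with code {p.exitcode}")
                continue
            if not ok:
                raise RuntimeError(f"rank {rank} failed:\n{payload}")
            results[rank] = payload
    finally:
        # Stop ranks that are still running, e.g. one blocked at a rendezvous
        # waiting for the rank that just failed, then reap them so they do
        # not linger as zombies.
        for p in procs:
            if p.is_alive():
                p.terminate()
        for p in procs:
            p.join()
    return results
```

The design choice is to terminate the surviving ranks rather than wait for them, which is roughly what torchrun's elastic agent does in the log above when it sends SIGTERM to the other worker. With the default `spawn` start method, `fn` and its return values must be picklable; a result that fails to pickle is dropped by the queue's feeder thread, and this sketch would then keep polling.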