NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

[NCCL] Prevent premature destroy of PGs following PyTorch upstream change #1859

Closed: eqy closed this 1 week ago

eqy commented 1 week ago

This PR explicitly sets the destroy_process_group argument of run_test to False. Otherwise, https://github.com/pytorch/pytorch/pull/140820/ adds that argument to run_test with a default of True, which changes run_test's default behavior to destroy the process group itself. The process group is therefore torn down prematurely, before the subsequent dist.barrier() and dist.destroy_process_group() calls in distributed_test_base.py, leading to error messages like:

Process process 0:
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.12/dist-packages/apex/transformer/testing/distributed_test_base.py", line 71, in _run
    dist.barrier()
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4462, in barrier
    opts.device = torch.device(_get_object_coll_device(group))
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 774, in _get_object_coll_device
    group = group or _get_default_group()
                     ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 1276, in _get_default_group
    raise ValueError(
ValueError: Default process group has not been initialized, please make sure to call init_process_group.

raised by the dist.barrier() that follows run_test in distributed_test_base.py.
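
For reference, a minimal sketch of the fix, assuming a MultiProcessTestCase-style _run like the one in apex/transformer/testing/distributed_test_base.py; the surrounding code here is an approximation, and the destroy_process_group keyword is the one described above (added to run_test by pytorch/pytorch#140820), not taken from the actual diff:

import sys
import torch.distributed as dist
from torch.testing._internal.common_distributed import MultiProcessTestCase

class DistributedTestBase(MultiProcessTestCase):
    @classmethod
    def _run(cls, rank, test_name, file_name, pipe):
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        dist.init_process_group(
            backend="nccl",
            init_method=f"file://{file_name}",
            world_size=int(self.world_size),
            rank=rank,
        )
        # Ask run_test not to destroy the process group on exit (keyword per the
        # upstream change described above). With the new default of True, the PG
        # would already be gone when the barrier below runs, producing the
        # ValueError shown in the traceback.
        self.run_test(test_name, pipe, destroy_process_group=False)
        dist.barrier()
        dist.destroy_process_group()  # apex tears the PG down explicitly
        sys.exit(0)

With destroy_process_group=False, the process group stays alive through the trailing barrier, and apex remains responsible for destroying it, matching the pre-#140820 behavior.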