pytorch / examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc.
https://pytorch.org/examples
BSD 3-Clause "New" or "Revised" License

multi-node Tensor Parallel #1257

Open PieterZanders opened 6 months ago

PieterZanders commented 6 months ago

Hello, could you add a new example of Tensor Parallel + FSDP using a multi-node setup? Is it possible to do multi-node tensor parallelism with PyTorch 2.3? I am trying to use 2 nodes with 4 GPUs each:

05/12/2024 04:32:52 PM Device Mesh created: device_mesh=DeviceMesh([[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=('dp', 'tp'))
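The 2×4 mesh above can be sketched in plain Python (no torch needed). Each row is a "tp" group and each column is a "dp" group, which is the group FSDP all-gathers over. The sketch assumes torchrun's default contiguous rank numbering (ranks 0-3 on the first node, 4-7 on the second) and 4 GPUs per node. `mesh_groups` and `duplicate_gpus` are hypothetical helpers, not PyTorch APIs. The second one flags two ranks of the same group that sit on the same (host, GPU) pair, which is the condition NCCL reports as "Duplicate GPU detected".

```python
# Hypothetical helpers (not PyTorch APIs): lay out a 2-D ("dp", "tp") mesh
# and check whether ranks that share a communicator also share a GPU.
from typing import Dict, List, Tuple

Placement = Tuple[str, int]  # (hostname, local CUDA device index)


def mesh_groups(dp_size: int, tp_size: int) -> Dict[str, List[List[int]]]:
    """Return the rank groups of each dimension of a row-major 2-D mesh."""
    mesh = [[dp * tp_size + tp for tp in range(tp_size)] for dp in range(dp_size)]
    return {
        "tp": mesh,  # each row: ranks that split one layer (Tensor Parallel)
        "dp": [list(col) for col in zip(*mesh)],  # each column: FSDP peers
    }


def duplicate_gpus(groups: List[List[int]], placement: List[Placement]) -> List[Tuple[int, int]]:
    """Return (earlier_rank, later_rank) pairs in one group on the same host and GPU."""
    clashes: List[Tuple[int, int]] = []
    for group in groups:
        seen: Dict[Placement, int] = {}
        for rank in group:
            where = placement[rank]
            if where in seen:
                clashes.append((seen[where], rank))
            else:
                seen[where] = rank
    return clashes


if __name__ == "__main__":
    groups = mesh_groups(dp_size=2, tp_size=4)
    print(groups["tp"])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
    print(groups["dp"])  # [[0, 4], [1, 5], [2, 6], [3, 7]]

    # Intended layout: two hosts with four GPUs each ("node0"/"node1" are placeholders).
    two_nodes = [(f"node{r // 4}", r % 4) for r in range(8)]
    print(duplicate_gpus(groups["dp"], two_nodes))  # []

    # All eight ranks on one host, GPU index = rank % 4.
    one_host = [("as07r1b31", r % 4) for r in range(8)]
    print(duplicate_gpus(groups["dp"], one_host))  # [(0, 4), (1, 5), (2, 6), (3, 7)]
```

In the failure summary below, ranks 0 through 7 all report host as07r1b31, and the process IDs in the NCCL warnings appear to pair rank r with rank r + 4 on the same CUDA device, which matches the second case. Inside each two-rank "dp" communicator NCCL numbers the members 0 and 1, which may explain why every warning names "rank 0 and rank 1".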

When I run the current example on multiple nodes, I get the following errors:

Thank you.


as07r1b31:3011779:3012101 [0] init.cc:871 NCCL WARN Duplicate GPU detected : rank 0 and rank 1 both on CUDA device 1b000
as07r1b31:3011783:3012102 [0] init.cc:871 NCCL WARN Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 1b000
as07r1b31:3011782:3012104 [3] init.cc:871 NCCL WARN Duplicate GPU detected : rank 0 and rank 1 both on CUDA device ad000
as07r1b31:3011786:3012107 [3] init.cc:871 NCCL WARN Duplicate GPU detected : rank 1 and rank 0 both on CUDA device ad000
as07r1b31:3011780:3012106 [1] init.cc:871 NCCL WARN Duplicate GPU detected : rank 0 and rank 1 both on CUDA device 2c000
as07r1b31:3011784:3012108 [1] init.cc:871 NCCL WARN Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 2c000
as07r1b31:3011781:3012110 [2] init.cc:871 NCCL WARN Duplicate GPU detected : rank 0 and rank 1 both on CUDA device 9d000
as07r1b31:3011785:3012111 [2] init.cc:871 NCCL WARN Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 9d000

[rank0]: Traceback (most recent call last):
[rank0]:   File "/gpfs/mn4/AE_tp/tests.py", line 91, in <module>
[rank0]:     _, output = sharded_model(inp)
[rank0]:                 ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 843, in forward
[rank0]:     args, kwargs = _pre_forward(
[rank0]:                    ^^^^^^^^^^^^^
[rank0]:   File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 380, in _pre_forward
[rank0]:     unshard_fn(state, handle)
[rank0]:   File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 415, in _pre_forward_unshard
[rank0]:     _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
[rank0]:   File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 299, in _unshard
[rank0]:     handle.unshard()
[rank0]:   File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 1308, in unshard
[rank0]:     padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
[rank0]:                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 1399, in _all_gather_flat_param
[rank0]:     dist.all_gather_into_tensor(
[rank0]:   File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 75, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 2948, in all_gather_into_tensor
[rank0]:     work = group._allgather_base(output_tensor, input_tensor, opts)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: torch.distributed.DistBackendError: NCCL error in: /opt/conda/conda-bld/pytorch_1712608847532/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1970, invalid usage (run with NCCL_DEBUG=WARN for details), NCCL version 2.20.5
[rank0]: ncclInvalidUsage: This usually reflects invalid usage of NCCL library.
[rank0]: Last error:
[rank0]: Duplicate GPU detected : rank 0 and rank 1 both on CUDA device 1b000
[same on other ranks]

Traceback (most recent call last):
  File "/home/mn4/AE_tp/mdae2.3/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.3.0', 'console_scripts', 'torchrun')())
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/run.py", line 879, in main
    run(args)
  File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mn4/AE_tp/mdae2.3/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
tests.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-05-12_16:33:02
  host      : as07r1b31.bsc.mn
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3011780)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2024-05-12_16:33:02
  host      : as07r1b31.bsc.mn
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 3011781)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time      : 2024-05-12_16:33:02
  host      : as07r1b31.bsc.mn
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 3011782)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[4]:
  time      : 2024-05-12_16:33:02
  host      : as07r1b31.bsc.mn
  rank      : 4 (local_rank: 4)
  exitcode  : 1 (pid: 3011783)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[5]:
  time      : 2024-05-12_16:33:02
  host      : as07r1b31.bsc.mn
  rank      : 5 (local_rank: 5)
  exitcode  : 1 (pid: 3011784)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[6]:
  time      : 2024-05-12_16:33:02
  host      : as07r1b31.bsc.mn
  rank      : 6 (local_rank: 6)
  exitcode  : 1 (pid: 3011785)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[7]:
  time      : 2024-05-12_16:33:02
  host      : as07r1b31.bsc.mn
  rank      : 7 (local_rank: 7)
  exitcode  : 1 (pid: 3011786)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-05-12_16:33:02
  host      : as07r1b31.bsc.mn
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 3011779)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
Lucazzn commented 2 weeks ago

Are you sure you passed `tp_mesh = device_mesh['tp']` to `parallelize_module`? I reported the same error earlier: `parallelize_module` only accepts a 1-D device mesh.