NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

Pytorch distributed runtime check failure when using pipeline parallelism #209

Closed: insujang closed this issue 1 month ago.

insujang commented 2 years ago

I hope this issue will be helpful to others who run into a similar problem. I am trying to run examples/pretrain_gpt_distributed_with_mp.sh, but when pipeline model parallelism is enabled, the following error occurs on every node:

Traceback (most recent call last):
  File "pretrain_gpt.py", line 127, in <module>
    pretrain(train_valid_test_datasets_provider, model_provider,
  File "/data/insujang/Megatron-LM/megatron/training.py", line 147, in pretrain
    iteration = train(forward_step_func,
  File "/data/insujang/Megatron-LM/megatron/training.py", line 695, in train
    train_step(forward_step_func,
  File "/data/insujang/Megatron-LM/megatron/training.py", line 398, in train_step
    losses_reduced = forward_backward_func(
  File "/data/insujang/Megatron-LM/megatron/schedules.py", line 381, in forward_backward_pipelining_with_interleaving
    p2p_communication.send_forward_recv_forward(
  File "/data/insujang/Megatron-LM/megatron/p2p_communication.py", line 270, in send_forward_recv_forward
    input_tensor, _ = _communicate(
  File "/data/insujang/Megatron-LM/megatron/p2p_communication.py", line 124, in _communicate
    send_next_op = torch.distributed.P2POp(
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 827, in __new__
    _check_op(op)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 263, in _check_op
    raise RuntimeError("Invalid ``op``. Expected ``op`` "
RuntimeError: Invalid ``op``. Expected ``op`` to be of type ``torch.distributed.isend`` or ``torch.distributed.irecv``.

(On the other workers, the error is raised while creating the receive op instead, but it is the same error.)

At first glance, I did not understand why the code fails, since all the arguments looked correct:

def _check_op(op):
    """
    Helper to check that the ``op`` is either isend or irecv.
    """
    if op not in [isend, irecv]:
        raise RuntimeError("Invalid ``op``. Expected ``op`` "
                           "to be of type ``torch.distributed.isend`` or "
                           "``torch.distributed.irecv``.")

send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())

Strangely, torch.distributed.P2POp works when I test it manually with the following example:

import torch
tensor = torch.rand(2)
op = torch.distributed.P2POp(torch.distributed.isend, tensor, 1)

which suggests that this is not a PyTorch bug.

When I print op inside _check_op (by adding print(op)), the output differs when running under Megatron:

Manual test:

<function isend at [address]>

Megatron:

<function add_wrapper.<locals>.wrapper_func at [address]>

Because the wrapper is a different function object from isend and irecv, the membership check in _check_op fails and raises the RuntimeError.
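The check fails because `op not in [isend, irecv]` compares function objects, and for plain functions equality is identity. The pure-Python sketch below uses hypothetical stand-ins for isend/irecv and a hypothetical add_wrapper (the real wrapper in this environment is not shown in the report) to illustrate that any wrapper fails the check, even one built with functools.wraps, which only copies metadata such as __name__ and __qualname__.

```python
import functools


# Hypothetical stand-ins for torch.distributed.isend / irecv.
def isend(tensor, dst):
    return ("isend", tensor, dst)


def irecv(tensor, src):
    return ("irecv", tensor, src)


def _check_op(op):
    # Same membership test as the PyTorch helper: compares function objects.
    if op not in [isend, irecv]:
        raise RuntimeError("Invalid ``op``. Expected ``op`` "
                           "to be of type ``torch.distributed.isend`` or "
                           "``torch.distributed.irecv``.")


def add_wrapper(func, preserve_metadata=False):
    # Hypothetical wrapper factory, shaped like the one seen in the repr.
    def wrapper_func(*args, **kwargs):
        return func(*args, **kwargs)

    if preserve_metadata:
        # Copies __name__, __qualname__, etc. and sets __wrapped__.
        wrapper_func = functools.wraps(func)(wrapper_func)
    return wrapper_func


def passes_check(op):
    try:
        _check_op(op)
        return True
    except RuntimeError:
        return False


plain = add_wrapper(isend)
wrapped = add_wrapper(isend, preserve_metadata=True)

print(passes_check(isend))    # True
print(passes_check(plain))    # False
print(passes_check(wrapped))  # False: metadata is copied, identity is not
print(plain.__qualname__)     # add_wrapper.<locals>.wrapper_func
print(wrapped.__qualname__)   # isend
```

Since the repr in the report shows `add_wrapper.<locals>.wrapper_func` rather than `isend`, the wrapper here apparently does not use functools.wraps; either way, the identity check rejects it.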

As a workaround, change the P2POp creation like this:

from torch.distributed import isend, irecv

...

if tensor_send_next is not None:
  send_next_op = torch.distributed.P2POp(isend, tensor_send_next, mpu.get_pipeline_model_parallel_next_rank())
  ops.append(send_next_op)

I am not familiar with this kind of Python issue, so suggestions for a better solution would be appreciated.

github-actions[bot] commented 1 year ago

Marking as stale. No activity in 60 days. Remove stale label or comment or this will be closed in 7 days.