NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[BUG] Bug of expert model parallel #766

Open 1049451037 opened 3 months ago

1049451037 commented 3 months ago

https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py#L503

0: [rank1]:   File "/share/home/lqs/Megatron-LM/megatron/core/parallel_state.py", line 503, in initialize_model_parallel
0: [rank1]:     group = torch.distributed.new_group(
0: [rank1]:   File "/share/real_shared_envs/megatron/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 89, in wrapper
0: [rank1]:     func_return = func(*args, **kwargs)
0: [rank1]:   File "/share/real_shared_envs/megatron/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3806, in new_group
0: [rank1]:     return _new_group_with_tag(ranks, timeout, backend, pg_options, None, use_local_synchronization=use_local_synchronization)
0: [rank1]:   File "/share/real_shared_envs/megatron/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3877, in _new_group_with_tag
0: [rank1]:     pg, pg_store = _new_process_group_helper(
0: [rank1]:   File "/share/real_shared_envs/megatron/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1431, in _new_process_group_helper
0: [rank1]:     pg: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size, base_pg_options)
0: [rank1]: TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
0: [rank1]:     1. torch._C._distributed_c10d.ProcessGroup(arg0: int, arg1: int)
0: [rank1]:     2. torch._C._distributed_c10d.ProcessGroup(arg0: torch._C._distributed_c10d.Store, arg1: int, arg2: int, arg3: c10d::ProcessGroup::Options)

As shown in the code:


    tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size
    num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size
    tensor_and_expert_group_size: int = tensor_model_parallel_size * expert_model_parallel_size
    num_expert_groups: int = data_parallel_size // expert_model_parallel_size
    for i in range(num_tensor_and_data_groups):
        for j in range(num_expert_groups):
            # TPxEP Group
            start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size
            end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size
            ranks = range(start_rank, end_rank)
            group = torch.distributed.new_group(
                ranks, timeout=timeout, pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs)
            )
            if rank in ranks:
                _TENSOR_AND_EXPERT_PARALLEL_GROUP = group
            for k in range(tensor_model_parallel_size * context_parallel_size):
                ranks = range(
                    start_rank + k, end_rank, tensor_model_parallel_size * context_parallel_size
                )
                group = torch.distributed.new_group(
                    ranks, pg_options=get_nccl_options('exp', nccl_comm_cfgs)
                )
                if rank in ranks:
                    _EXPERT_MODEL_PARALLEL_GROUP = group

The length of [start_rank, end_rank) is tp*ep, but the inner loop over k runs for tp*cp iterations with a stride of tp*cp. If cp > ep, then for k >= tp*ep the range range(start_rank + k, end_rank, tp*cp) is empty, so torch.distributed.new_group is called with an empty ranks list, which triggers the error above.
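
For concreteness, here is a minimal sketch of the arithmetic with hypothetical sizes (tp=2, cp=2, ep=1; not taken from the original report) showing how the inner range goes empty:

    # Hypothetical sizes for illustration only: tp=2, cp=2, ep=1.
    tensor_model_parallel_size = 2
    context_parallel_size = 2
    expert_model_parallel_size = 1

    # [start_rank, end_rank) spans tp * ep = 2 ranks for the first TPxEP group.
    start_rank = 0
    end_rank = tensor_model_parallel_size * expert_model_parallel_size  # 2

    # The inner loop iterates tp * cp = 4 times and strides by tp * cp = 4.
    for k in range(tensor_model_parallel_size * context_parallel_size):
        ranks = range(
            start_rank + k, end_rank, tensor_model_parallel_size * context_parallel_size
        )
        print(k, list(ranks))
    # Output: 0 [0], 1 [1], 2 [], 3 []
    # k=2 and k=3 yield empty rank lists, and new_group() is then
    # called with an empty ranks argument.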

yanring commented 3 months ago

Thank you for letting us know! We have a fix, but it's not yet merged. A temporary workaround (WAR) is to replace tensor_model_parallel_size * context_parallel_size with just tensor_model_parallel_size.
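
For reference, a rough sketch of that workaround applied to the inner loop quoted above (this is only the temporary WAR, not the merged fix, and it assumes the surrounding variables from parallel_state.py are in scope):

    # Workaround sketch: iterate and stride by tensor_model_parallel_size only,
    # so start_rank + k stays below end_rank (= start_rank + tp * ep) and the
    # range is never empty; each group then holds ep ranks.
    for k in range(tensor_model_parallel_size):
        ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)
        group = torch.distributed.new_group(
            ranks, pg_options=get_nccl_options('exp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _EXPERT_MODEL_PARALLEL_GROUP = group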

yanring commented 3 months ago

This issue should have been resolved by https://github.com/NVIDIA/Megatron-LM/commit/b5aba3a2f3165da8b4f6b483bf3a6da2a24718e4

github-actions[bot] commented 3 weeks ago

Marking as stale. No activity in 60 days.