microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] Expert parallel hangs at the last MoE layer #5794

Open JessePrince opened 3 months ago

JessePrince commented 3 months ago

Describe the bug

I'm using the DeepSpeed MoE layer to build a multimodal LLM. I use Phi-3 as the base model and replaced its MLP layers with DeepSpeed MoE layers. However, when I enable expert parallelism, communication hangs at the last MoE layer.

To Reproduce

Steps to reproduce the behavior:

  1. Load a base Phi-3 model from Hugging Face.
  2. Replace self.mlp with a MoE layer.
  3. Set expert_num=4 and ep_size=4.
  4. Run training with the ZeRO stage 2 optimizer.
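To make the configuration in these steps concrete, here is a small sketch of how experts map onto ranks. It assumes DeepSpeed's rule that num_experts must be divisible by ep_size, with each rank holding num_experts // ep_size local experts, and it assumes expert-parallel groups are built from consecutive ranks. The helper name expert_layout is hypothetical. With 4 GPUs and ep_size=4, all four ranks form a single expert-parallel group, so every MoE all-to-all needs all four of them to take part.

```python
# Hypothetical helper: sketch of how experts and expert-parallel (EP) groups
# map onto ranks, ASSUMING each EP group is a block of consecutive ranks.
def expert_layout(world_size, num_experts, ep_size):
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size")
    if world_size % ep_size != 0:
        raise ValueError("world_size must be divisible by ep_size")
    local_experts = num_experts // ep_size
    # Each EP group is ep_size consecutive ranks.
    groups = [list(range(start, start + ep_size))
              for start in range(0, world_size, ep_size)]
    # Global expert ids owned by each rank, based on its position in its EP group.
    owned = {rank: list(range((rank % ep_size) * local_experts,
                              (rank % ep_size + 1) * local_experts))
             for rank in range(world_size)}
    return local_experts, groups, owned


local, groups, owned = expert_layout(world_size=4, num_experts=4, ep_size=4)
print(local)   # 1
print(groups)  # [[0, 1, 2, 3]]
print(owned)   # {0: [0], 1: [1], 2: [2], 3: [3]}
```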

Expected behavior

No hangs during training.

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/root/miniconda3/lib/python3.10/site-packages/torch']
torch version .................... 2.1.2+cu121
deepspeed install path ........... ['/root/miniconda3/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.14.4+unknown, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.1
deepspeed wheel compiled w. ...... torch 2.1, cuda 12.1
shared memory (/dev/shm) size .... 94.00 GB

Additional context

I added logging throughout the code base. Oddly, DeepSpeed hangs only at the last MoE layer; for the Phi-3 mini model, that is the 32nd decoder layer. When I press Ctrl+C, the code is stopped at result.wait(), which means something is wrong with communication. While training is hung, GPUs 2 and 3 (of my 4 GPUs) show a very high communication rate (11 GB/s), while GPUs 0 and 1 show only 120 MB/s. I also checked the exact line of code where training stops: it is the MoE module in the last decoder layer.

JessePrince commented 3 months ago

Some extra details: I used logging to find out which process is stuck:

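import os
from deepspeed.moe.layer import MoE

# Inside the decoder layer's forward(): log on every rank before and after
# the MoE block to see which ranks make it through (logger is the module's logger).
rank = os.environ["LOCAL_RANK"]
if isinstance(self.mlp, MoE):
    logger.warning(f"rank {rank} hidden_states before mlp: {hidden_states}")

hidden_states = self.mlp(hidden_states)  # this is the MoE layer
if isinstance(self.mlp, MoE):
    logger.warning(f"rank {rank} hidden_states after mlp: {hidden_states}")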

After the MoE layer, only rank 1 logs its hidden_states; the other ranks never reach the second log statement.

Inside the MoE layer, the hang is caused by the second all-to-all in sharded_moe.py.
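For context on which exchange is "the second all-to-all", here is a pure-Python sketch of the two exchanges a MoE layer performs under expert parallelism, as we read DeepSpeed's MOELayer in sharded_moe.py: the first all-to-all sends each rank's per-expert token buckets to the rank that owns that expert, each expert runs locally, and the second all-to-all sends the results back. The all_to_all function here is a hypothetical stand-in that simulates an equal-split all-to-all, and the "expert" only tags each token with its expert id.

```python
# Simulated equal-split all-to-all: rank dst receives chunk dst from every
# rank src (a transpose of the per-rank send buffers).
def all_to_all(send):
    world = len(send)
    for buf in send:
        if len(buf) != world:
            raise ValueError("every rank must send exactly one chunk per peer")
    return [[send[src][dst] for src in range(world)] for dst in range(world)]


def moe_forward(buckets):
    # buckets[r][e]: tokens on rank r routed to expert e (one expert per rank).
    received = all_to_all(buckets)                      # 1st all-to-all: dispatch
    processed = [[[f"E{rank}({tok})" for tok in chunk]  # local expert compute
                  for chunk in chunks]
                 for rank, chunks in enumerate(received)]
    return all_to_all(processed)                        # 2nd all-to-all: combine


buckets = [[[f"r{r}t{e}"] for e in range(2)] for r in range(2)]
print(moe_forward(buckets))
# [[['E0(r0t0)'], ['E1(r0t1)']], [['E0(r1t0)'], ['E1(r1t1)']]]
```

Unlike this list-based simulation, a real equal-split collective such as torch.distributed.all_to_all_single without split sizes assumes every chunk has the same size on every rank.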

JessePrince commented 3 months ago

More investigation results: if the number of tokens differs across ranks, the all_to_all operation hangs. To reproduce:

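import os
import torch

# Rank 0 gets 100 tokens per sequence, every other rank gets 80.
# local_device and model_engine are assumed to be set up earlier (e.g. via deepspeed.initialize()).
if int(os.environ["LOCAL_RANK"]) == 0:
    input_ids = torch.ones(2, 100, dtype=torch.long, device=local_device)
else:
    input_ids = torch.ones(2, 80, dtype=torch.long, device=local_device)

loss = model_engine(input_ids)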

This mismatch in sequence length makes all_to_all hang. It works fine when the sequence length is the same on all ranks.
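One plausible mechanism, sketched under assumptions: as we read sharded_moe.py, top-1 gating (with the default drop_tokens=True) computes a per-rank expert capacity from the rank's own token count, roughly max(ceil(num_tokens / num_experts * capacity_factor), min_capacity), and the dispatched tensor has shape [num_experts, capacity, hidden_size]. Its _AllToAll then allocates the receive buffer with the same shape as the send buffer and calls all_to_all_single with equal splits. The defaults capacity_factor=1.0 and min_capacity=4 are assumed here, and the hidden size 3072 is only illustrative. Plugging in the repro above shows that the ranks disagree on how much data each peer sends:

```python
import math


# ASSUMED capacity formula (top-1 gating, drop_tokens=True):
# capacity = max(ceil(tokens / num_experts * capacity_factor), min_capacity)
def expert_capacity(num_tokens, num_experts, capacity_factor=1.0, min_capacity=4):
    return max(math.ceil(num_tokens / num_experts * capacity_factor), min_capacity)


def dispatched_numel(batch, seq_len, num_experts, hidden_size):
    # Dispatched tensor is [num_experts, capacity, hidden_size] on each rank.
    capacity = expert_capacity(batch * seq_len, num_experts)
    return num_experts * capacity * hidden_size


hidden_size = 3072  # illustrative value only
sizes = {0: dispatched_numel(2, 100, 4, hidden_size),   # rank 0: 2 x 100 tokens
         1: dispatched_numel(2, 80, 4, hidden_size)}    # other ranks: 2 x 80 tokens
print(sizes)  # {0: 614400, 1: 491520}
# With equal splits over 4 ranks, rank 0 sends 153600 elements to each peer,
# while each peer sizes its receive chunks for only 122880.
print(sizes[0] // 4, sizes[1] // 4)  # 153600 122880
```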

taehyunzzz commented 3 weeks ago

Have you found a fix?

JessePrince commented 3 weeks ago

Still no. Pad sequences to a fixed length, I guess? 😂
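A minimal sketch of the padding workaround, assuming a fixed max_len that every rank agrees on and a tokenizer-specific pad_id (both illustrative; pad_batch is a hypothetical helper). With Hugging Face tokenizers, the same effect is usually obtained by passing padding="max_length", max_length=..., and truncation=True. Padded positions still pass through the MoE layers, which is where the extra cost comes from.

```python
# Hypothetical helper: right-pad (or truncate) every sequence to one fixed
# length shared by all ranks, so each rank feeds the MoE layers the same
# number of tokens.
def pad_batch(batch, max_len, pad_id):
    input_ids, attention_mask = [], []
    for seq in batch:
        seq = list(seq)[:max_len]  # truncate overly long sequences
        pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * pad)
        attention_mask.append([1] * len(seq) + [0] * pad)
    return input_ids, attention_mask


ids, mask = pad_batch([[5, 6, 7], [8]], max_len=4, pad_id=0)
print(ids)   # [[5, 6, 7, 0], [8, 0, 0, 0]]
print(mask)  # [[1, 1, 1, 0], [1, 0, 0, 0]]
```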

taehyunzzz commented 3 weeks ago

Yeah, I found the same cause: the processes on different GPUs are not using the same sequence length. You'll need to modify the _AllToAll function and explicitly set the input/output split sizes in the dist.all_to_all_single calls to the number of tokens exchanged between each pair of GPUs. But again, I agree that padding is the simpler, if less efficient, way.
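The split-size approach can be sketched without GPUs. Suppose send_counts[src][dst] is the number of tokens rank src sends to rank dst. For a variable-size torch.distributed.all_to_all_single(output, input, output_split_sizes=..., input_split_sizes=..., group=...), rank r would pass row send_counts[r] as input_split_sizes and column r as output_split_sizes. In practice, the counts must first be exchanged, for example with an equal-split all-to-all of one integer per peer. The functions below are hypothetical and simulate that bookkeeping; they are not DeepSpeed's _AllToAll code.

```python
# Simulated variable-size all-to-all. send_counts[src][dst] is the number of
# tokens rank src sends to rank dst; ranks may hold different token totals.
def split_sizes(send_counts, rank):
    world = len(send_counts)
    input_split_sizes = list(send_counts[rank])  # row: what this rank sends
    # Column: what this rank receives from each peer.
    output_split_sizes = [send_counts[src][rank] for src in range(world)]
    return input_split_sizes, output_split_sizes


def all_to_all_variable(send_bufs, send_counts):
    world = len(send_bufs)
    for src in range(world):
        if len(send_bufs[src]) != sum(send_counts[src]):
            raise ValueError(f"rank {src}: buffer length does not match its split sizes")
    out = []
    for dst in range(world):
        recv = []
        for src in range(world):
            start = sum(send_counts[src][:dst])  # offset of src's chunk for dst
            recv.extend(send_bufs[src][start:start + send_counts[src][dst]])
        out.append(recv)
    return out


# Rank 0 holds 5 tokens, rank 1 holds only 3.
counts = [[3, 2], [1, 2]]
bufs = [["a0", "a1", "a2", "a3", "a4"], ["b0", "b1", "b2"]]
out = all_to_all_variable(bufs, counts)
print(out)                     # [['a0', 'a1', 'a2', 'b0'], ['a3', 'a4', 'b1', 'b2']]
print(split_sizes(counts, 0))  # ([3, 2], [3, 1])
print(split_sizes(counts, 1))  # ([1, 2], [2, 2])
```

With this scheme, each rank would size its receive buffer as sum(output_split_sizes) tokens instead of reusing the shape of its input.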

JessePrince commented 3 weeks ago

Thanks for sharing! I hope DeepSpeed will support this one day 🙏