Open JessePrince opened 3 months ago
Some extra details: I used a logger to find out which process is stuck
if isinstance(self.mlp, MoE):
    logger.warning("rank " + os.environ["LOCAL_RANK"] + " hidden_states before mlp: " + str(hidden_states))
hidden_states = self.mlp(hidden_states)  # this is the MoE layer
if isinstance(self.mlp, MoE):
    logger.warning("rank " + os.environ["LOCAL_RANK"] + " hidden_states after mlp: " + str(hidden_states))
The "after mlp" log only ever appears on rank 1; the other ranks never get past the MoE layer.
Inside the MoE layer, in sharded_moe.py, the hang is caused by the second all-to-all.
More investigation results: if the number of tokens varies across ranks, the all_to_all operation hangs. To reproduce:
if int(os.environ['LOCAL_RANK']) == 0:
    input_ids = torch.ones(2, 100, device=local_device, dtype=torch.long)
else:
    input_ids = torch.ones(2, 80, device=local_device, dtype=torch.long)
loss = model_engine(input_ids)
This mismatch in sequence length causes all_to_all to hang. Everything works fine when the token length is the same on every rank.
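To see the mechanism in isolation, here is a minimal standalone sketch (my own example with hypothetical sizes and a 2-rank NCCL setup, not code from DeepSpeed): with no split sizes given, dist.all_to_all_single assumes every rank contributes equally sized chunks, so the send/receive sizes posted by the two ranks disagree and the collective never completes.

# Minimal sketch (hypothetical sizes, 2 ranks): with the default split sizes,
# all_to_all_single assumes every rank exchanges equally sized chunks, so
# mismatched token counts make the collective block.
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

hidden = 64                                # hypothetical hidden size
tokens = 100 if rank == 0 else 80          # mismatched token counts per rank
inp = torch.randn(tokens, hidden, device="cuda")
out = torch.empty_like(inp)

# Rank 0 posts 50-token chunks, rank 1 posts 40-token chunks -> the message
# sizes disagree across ranks and this call hangs, just like inside the MoE layer.
dist.all_to_all_single(out, inp)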
Have you found a fix?
Still no… Pad sequences to a fixed length, I guess? 😂
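In case it helps, a rough sketch of that padding workaround (MAX_LEN, pad_id and the helper name are placeholders I made up, not DeepSpeed API): pad every rank's batch to one fixed length and pass an attention mask so the padded positions are ignored, which keeps the token count identical on all ranks.

# Rough sketch of the padding workaround; MAX_LEN, pad_id and the helper are
# hypothetical placeholders.
import torch
import torch.nn.functional as F

MAX_LEN = 128   # fixed sequence length shared by every rank
pad_id = 0      # your tokenizer's pad token id

def pad_to_fixed_length(input_ids):
    """Right-pad (batch, seq) token ids to MAX_LEN and build an attention mask."""
    batch, seq = input_ids.shape
    padded = F.pad(input_ids, (0, MAX_LEN - seq), value=pad_id)
    mask = torch.zeros(batch, MAX_LEN, dtype=torch.long, device=input_ids.device)
    mask[:, :seq] = 1
    return padded, mask

# loss = model_engine(padded_input_ids, attention_mask=mask)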
Yeah, I found the same cause... the multi-GPU processes are not using the same sequence lengths. You'll need to modify the _AllToAll function and manually set the input/output split sizes in the dist.all_to_all_single calls with the number of tokens to exchange between your GPUs. But again, I agree that padding is the inefficient but simpler way~
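For reference, a sketch of that idea (this is not DeepSpeed's actual _AllToAll code, just an illustration with made-up names): first exchange the per-rank token counts with a small all_to_all, then pass those counts as explicit input/output split sizes so the message sizes agree even when ranks hold different numbers of tokens.

# Illustration only -- not DeepSpeed's _AllToAll. Exchange token counts first,
# then use them as explicit split sizes for the real all_to_all_single.
import torch
import torch.distributed as dist

def uneven_all_to_all(tokens, group=None):
    """tokens: (num_local_tokens, hidden); returns the tokens received from all ranks."""
    world_size = dist.get_world_size(group)

    # How many tokens we send to each peer (an even share here; a real MoE router
    # would use per-expert counts). Any remainder is dropped for simplicity.
    send_counts = torch.full((world_size,), tokens.shape[0] // world_size,
                             dtype=torch.long, device=tokens.device)
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts, group=group)

    # Explicit split sizes: now every rank agrees on the message sizes.
    input_splits = send_counts.tolist()
    output_splits = recv_counts.tolist()
    output = tokens.new_empty(sum(output_splits), tokens.shape[1])
    dist.all_to_all_single(output, tokens[:sum(input_splits)].contiguous(),
                           output_split_sizes=output_splits,
                           input_split_sizes=input_splits,
                           group=group)
    return output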
Thanks for sharing! Hope DeepSpeed will support this one day 🙏
Describe the bug
I'm using the DeepSpeed MoE layer to build a multi-modal LLM: the base model is Phi-3, and I replaced its MLP layer with DeepSpeed's MoE layer. However, when I enable expert parallelism, communication hangs at the last MoE layer.
To Reproduce
Steps to reproduce the behavior:
Expected behavior
No hangs during training.
ds_report output
System info (please complete the following information):
Additional context
I used a logger to check the whole code base. The odd thing is that DeepSpeed only hangs at the last MoE layer; for the Phi-3 mini model, it hangs at the 32nd decoder layer. When I hit Ctrl+C, the code stopped at
result.wait()
, which means something is wrong with communication. While training hangs, GPUs 2 and 3 (I have 4 GPUs) show a very high communication rate (11 GB/s) while GPUs 0 and 1 sit at around 120 MB/s. I also checked the specific line of code where training stopped, and it is the MoE module in the last decoder layer.