Open jinga-lala opened 2 years ago
Can you upgrade NCCL and keep its version the same across all environments? Most NCCL timeout issues we have seen came from legacy libnccl bugs or from inconsistent NCCL versions.
You can also run `tutel.examples.helloworld` in the same distributed setting to test whether it hits the same NCCL timeout error.
Thanks @ghostplant for your reply. The example `tutel.examples.helloworld` works perfectly fine in both the single-GPU and the multi-GPU, single-node settings. That's why I think it is not an NCCL version issue.
My code also works fine in the single-GPU setting, but in the multi-GPU, single-node configuration it crashes with this error.
I dug deeper and found that execution freezes at this line in the code: https://github.com/microsoft/tutel/blob/17f4aab9b69cf50dcddd2b985907126379af1568/tutel/impls/moe_layer.py#L292
OK, since `tutel.examples.helloworld` works well, the problem should be related to inequivalent data sources stored on each GPU, which results in different locally planned iteration counts and thus a different number of model forward calls per device. Such a timeout has to be solved on the application side. But you can still try whether enabling both of the following options gets rid of the problem: (1) set `capacity_factor = negative_value` in the `moe_layer` creation inside the transformer's initialization function; (2) always enable `_moe_layer_0.forward(.., inequivalent_tokens=True)` in the transformer's forward function.
If the combination above doesn't work, you have to change the way data is fed on the application side to guarantee that all GPUs always have the same forward counts and execution orders.
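Put together, the two options look roughly like the sketch below. This is a hedged illustration based on Tutel's public `moe_layer` interface; the `gate_type` fields, dimensions, and expert settings here are illustrative placeholders, so check them against your installed Tutel version.

```python
from tutel import moe as tutel_moe

# (1) In the transformer's __init__: a negative capacity_factor asks Tutel
#     to size expert capacity dynamically rather than assuming a fixed
#     token count per device. (Field placement is illustrative.)
self._moe_layer_0 = tutel_moe.moe_layer(
    gate_type={'type': 'top', 'k': 2, 'capacity_factor': -1.0},
    model_dim=768,                       # illustrative dimensions
    experts={'type': 'ffn', 'count_per_node': 2,
             'hidden_size_per_expert': 3072},
)

# (2) In the transformer's forward: tell Tutel that token counts may
#     differ across devices so the all-to-all can be planned accordingly.
y = self._moe_layer_0(x, inequivalent_tokens=True)
```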
It worked when I set `_moe_layer_0.forward(.., inequivalent_tokens=True)` 🎉.
Is it because in object detection the image sizes differ, and so the number of patches per forward pass differs across the MoE model replicas?
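That intuition can be checked with quick arithmetic: under a ViT/Swin-style patch embedding, the token count per image is (H/p)·(W/p), so two ranks that draw differently sized images in the same step feed different numbers of tokens to the MoE layer. The resolutions and patch size below are hypothetical:

```python
def num_patches(height: int, width: int, patch: int = 4) -> int:
    """Token count for one image under a patch embedding of size `patch`
    (assumes height and width are already padded to multiples of `patch`)."""
    return (height // patch) * (width // patch)

# Hypothetical detection inputs of different resolutions:
tokens_a = num_patches(224, 224)   # 56 * 56 = 3136 tokens
tokens_b = num_patches(256, 192)   # 64 * 48 = 3072 tokens

# Per-device token counts diverge across ranks in the same step.
assert tokens_a != tokens_b
```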
Just for clarification: in this DDP setting there are separate copies of the local experts on each GPU, but the data batch is divided among the GPUs, right? Also, the shared (non-expert) architecture is replicated on each GPU and synced after each pass?
Thank you @ghostplant for your help!
Since `inequivalent_tokens=True` works, it means there was no issue with inequivalent forward counts. (See Case-1.)
The option only helps when, within a given iteration, the number of tokens per batch on each device differs from the others. (See Case-2.)
Case-1: where `inequivalent_tokens=True` is NOT helpful

```
[GPU-0]          [GPU-1]          [...]
epoch0-step0     epoch0-step0
epoch0-step1     epoch0-step1
...              ...
epoch0-step100   epoch0-step100
epoch0-step101   epoch1-step0     <--
epoch1-step0     epoch1-step1
...              ...
```
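Case-1 can be reproduced with simple arithmetic: if the per-rank data shards hold unequal sample counts, each rank plans a different number of steps per epoch, so one rank eventually issues a forward (and its collective) that no peer joins, and NCCL times out. A minimal sketch with hypothetical shard sizes:

```python
import math

def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Forward passes one rank plans for its local shard
    (drop_last=False semantics: the remainder forms a final small batch)."""
    return math.ceil(num_samples / batch_size)

# Hypothetical per-rank shard sizes: GPU-0 received one extra batch of data.
shard_sizes = {"GPU-0": 1632, "GPU-1": 1616}
batch_size = 16

plans = {rank: steps_per_epoch(n, batch_size) for rank, n in shard_sizes.items()}
# GPU-0 plans 102 steps while GPU-1 plans 101: GPU-0's last step has no peer.
assert plans["GPU-0"] != plans["GPU-1"]
```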
Case-2: where `inequivalent_tokens=True` is helpful

```
[GPU-0]          [GPU-1]
step-0  (bs=16)  step-0  (bs=16)
step-1  (bs=16)  step-1  (bs=16)
...              ...
step-50 (bs=16)  step-50 (bs=16)
step-51 (bs=3)   step-51 (bs=11)  <--
...              ...
```
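Case-2 corresponds to equal step counts but an unequal final batch, e.g. when a sampler splits a dataset that is not divisible by world_size × batch_size. The shard sizes below are hypothetical, chosen to match the table above (bs=3 vs bs=11 on step-51):

```python
import math

def plan(num_samples: int, batch_size: int):
    """Return (steps per epoch, size of the final batch) for one rank,
    with drop_last=False semantics."""
    steps = math.ceil(num_samples / batch_size)
    rem = num_samples % batch_size
    return steps, (rem if rem else batch_size)

# Hypothetical shard sizes: both ranks plan 52 steps (step-0 .. step-51).
gpu0 = plan(51 * 16 + 3, 16)    # final batch of 3
gpu1 = plan(51 * 16 + 11, 16)   # final batch of 11

assert gpu0[0] == gpu1[0]  # same forward count: no Case-1 hang
assert gpu0[1] != gpu1[1]  # unequal final batch: Case-2, needs inequivalent_tokens=True
```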
Hi, I am using the Tutel library with the MMAction framework to replicate the Swin-V2 MoE performance described in the paper. However, I am facing this error when I try to train the MoE in a DDP setting. Can someone please help me resolve it? Alternatively, could you release the object detection code that was used in the Tutel paper?