microsoft / tutel

Tutel MoE: An Optimized Mixture-of-Experts Implementation

Error when doing deepcopy of the model #177

Open yzxing87 opened 2 years ago

yzxing87 commented 2 years ago

Hi, thanks for this awesome project!

I built my transformer model on top of the MoeMlp layer, and I use EMA (exponential moving average) of the weights for better performance. However, when I try to initialize the EMA model with ema_model = copy.deepcopy(my_transformer_model), I encounter this error:

File "/opt/conda/lib/python3.8/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/opt/conda/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/conda/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/opt/conda/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/opt/conda/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/opt/conda/lib/python3.8/copy.py", line 161, in deepcopy
    rv = reductor(4)
TypeError: cannot pickle 'torch._C._distributed_c10d.ProcessGroupNCCL' object

Could you help me with that? How can I use EMA with Tutel? Thanks!

ghostplant commented 2 years ago

If you pickle the model for a single GPU, everything will be fine, because AllToAll is not included in Tutel's MoE layer in that case. Would that meet your expectation?

PyTorch's NCCL operations (e.g. AllToAll) don't support pickling, so I'm afraid any MoE model that includes AllToAll in its forward pass will hit the same issue. You can either ask the PyTorch community to fix it, implement your EMA with a workaround that doesn't require deepcopy, or run the MoE model in pure data-parallel mode, which has no AllToAll in the forward pass, although its distributed performance will be very poor at large scale.
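
As an illustration of the second option, here is a minimal sketch of an EMA helper that never calls deepcopy: it only clones the floating-point tensors from state_dict() into a shadow dictionary, so the NCCL process group held inside the MoE layer is never pickled. The class name EmaShadow and the decay value are placeholders for this sketch, not part of Tutel's API.

```python
import torch

class EmaShadow:
    """Keeps an exponential moving average of a model's weights without deepcopy."""

    def __init__(self, model, decay=0.999):
        self.decay = decay
        # Clone tensors only; module objects (and their process groups) are untouched.
        self.shadow = {k: v.detach().clone()
                       for k, v in model.state_dict().items()
                       if torch.is_floating_point(v)}

    @torch.no_grad()
    def update(self, model):
        # Call once per optimizer step: shadow = decay * shadow + (1 - decay) * current.
        for k, v in model.state_dict().items():
            if k in self.shadow:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)

    @torch.no_grad()
    def copy_to(self, model):
        # Load the averaged weights for evaluation; strict=False leaves entries
        # that were not tracked (e.g. integer buffers) untouched.
        model.load_state_dict(self.shadow, strict=False)
```

Usage would be ema = EmaShadow(my_transformer_model) once, ema.update(my_transformer_model) after each step, and ema.copy_to(eval_model) before evaluation. Note that with expert parallelism each rank holds different expert weights, so each rank keeps its own shadow copy.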

yzxing87 commented 2 years ago

Thanks for your prompt reply! I use 8 GPUs to train my model with 8 experts. In that case, can I pickle my model for a single GPU? I would also like to know whether the checkpoints have to be saved separately for each rank. Does inference always need 8 GPUs if the model was trained with 8 GPUs?

ghostplant commented 2 years ago

You can go through these examples to convert training checkpoints between the distributed version and the single-device version: https://github.com/microsoft/tutel#how-to-convert-checkpoint-files-that-adapt-to-different-distributed-world-sizes
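
For intuition only, the conversion amounts to treating the expert parameters, which are sharded across ranks, differently from the shared (replicated) parameters. The sketch below is a hypothetical illustration of that idea, not Tutel's actual script; the real key handling lives in the checkpoint tools linked above, and expert_param_keys here is an assumed input.

```python
import torch

def merge_rank_checkpoints(paths, expert_param_keys):
    """Hypothetical merge of per-rank checkpoints into one single-device state dict."""
    state_dicts = [torch.load(p, map_location="cpu") for p in paths]
    # Replicated (data-parallel) weights are identical on every rank; take rank 0's copy.
    merged = dict(state_dicts[0])
    for key in expert_param_keys:
        # Each rank stores only its local experts; concatenate them along the expert dim.
        merged[key] = torch.cat([sd[key] for sd in state_dicts], dim=0)
    return merged
```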

yzxing87 commented 2 years ago

Thanks for the quick update with this feature! I notice you use mpiexec to launch the job and save the checkpoint. If I use torch.distributed.launch to train my MoE instead, is it still valid to use tutel/checkpoint/gather.py to combine my checkpoints?

ghostplant commented 2 years ago


Yes, both are compatible: mpiexec is just an alternative to torch.distributed.launch for launching cross-node processes.
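
For illustration, assuming a hypothetical single-node training script train_moe.py on 8 GPUs, the two launch styles would look roughly like this (the exact environment setup depends on how the script initializes torch.distributed):

```sh
# Launch via MPI
mpiexec -n 8 python3 train_moe.py

# Launch via PyTorch's launcher
python3 -m torch.distributed.launch --nproc_per_node=8 train_moe.py
```

Either way, the per-rank checkpoints can then be combined with the same gather step.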