microsoft / tutel

Tutel MoE: An Optimized Mixture-of-Experts Implementation

bp of shared parameters and experts #161

Open · a157801 opened this issue 2 years ago

a157801 commented 2 years ago

PyTorch DDP cannot distinguish expert parameters from the other, shared parameters, so experts may be updated with the shared (all-reduced) gradient. TutelDistributedOptimizer seems to be an implementation of ZeRO, which does not change the gradients themselves. How does Tutel deal with this problem?

ghostplant commented 2 years ago

Yes, TutelDistributedOptimizer is a replacement for PyTorch DDP in that example (helloworld_ddp_tutel), so that whole-model synchronization is handled transparently.

TutelDistributedOptimizer not only implements the ZeRO optimization, but also leverages a built-in mask (_tutel_expert) to distinguish whether a parameter is shared or was created by tutel.moe.moe_layer.

Note that TutelDistributedOptimizer only treats parameters created by tutel.moe.moe_layer as expert parameters. If the model never uses tutel.moe.moe_layer, there is no difference from PyTorch DDP (except that TutelDistributedOptimizer includes the ZeRO feature).
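
For intuition, here is a minimal sketch of how such a mask can drive gradient synchronization: shared gradients are averaged across ranks while expert gradients stay rank-local. This is not Tutel's actual code (the ZeRO partitioning is omitted), and the `_tutel_expert` attribute check simply follows the naming above.

```python
import torch
import torch.distributed as dist

def sync_shared_grads(model: torch.nn.Module, group=None):
    # Conceptual sketch only, not TutelDistributedOptimizer's implementation:
    # average gradients of shared parameters across ranks, and skip parameters
    # assumed to carry the `_tutel_expert` mask described above.
    world_size = dist.get_world_size(group)
    for p in model.parameters():
        if p.grad is None:
            continue
        if getattr(p, '_tutel_expert', False):
            # Each rank owns its own experts, so these gradients must not be
            # averaged with the gradients of different experts on other ranks.
            continue
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=group)
        p.grad.div_(world_size)
```

Called between loss.backward() and optimizer.step(), a routine like this keeps the shared weights identical on every rank while each rank's experts are updated only from their own gradients.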

a157801 commented 2 years ago

Thank you for your answer. I notice that the _tutel_expert flag is used to split the parameters. But it seems that the gradients of experts carrying _tutel_expert will still be all-reduced by DDP. The _tutel_expert flag indicates that these parameters are experts and will not be partitioned across different GPUs, but it does not control the all-reduce operation.

ghostplant commented 2 years ago

To use TutelDistributedOptimizer, which has parameter synchronization built in, you should no longer wrap the model with DistributedDataParallel.

a157801 commented 2 years ago

I notice the code in the Swin-Transformer repo (https://github.com/microsoft/Swin-Transformer/blob/main/main_moe.py), which uses a plain PyTorch optimizer and DDP to train these MoE models. Maybe there is something wrong there. Thanks a lot.

ghostplant commented 2 years ago

It is a version that manually distinguishes parameter types, following helloworld_ddp.py.

a157801 commented 2 years ago

Does it work by setting skip_allreduce to true in the scan function?

ghostplant commented 2 years ago

To use Tutel MoE with the PyTorch DDP backend, you need to not only set skip_allreduce to true in the MoE scan function, but also collect the parameters carrying that mask and tell DDP to skip synchronizing them, as in https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L92. Otherwise, PyTorch DDP won't know they are expert parameters, so they'll be synchronized unexpectedly.
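
For reference, a condensed sketch of that pattern (launched with torchrun). The scan_expert_func hook, the gate/expert options, and the private _ddp_params_and_buffers_to_ignore attribute follow the linked example as I understand it, so treat the exact names and signatures as assumptions rather than a verified recipe.

```python
import os
import torch
from tutel import moe as tutel_moe

class Net(torch.nn.Module):
    def __init__(self, model_dim=128):
        super().__init__()
        self.shared = torch.nn.Linear(model_dim, model_dim)  # synchronized by DDP
        self.moe = tutel_moe.moe_layer(
            gate_type={'type': 'top', 'k': 2},
            model_dim=model_dim,
            experts={'type': 'ffn', 'count_per_node': 2,
                     'hidden_size_per_expert': 256},
            # The "scan function": mark every expert parameter so DDP can skip it.
            scan_expert_func=lambda name, p: setattr(p, 'skip_allreduce', True),
        )

    def forward(self, x):
        return self.moe(self.shared(x))

torch.distributed.init_process_group('nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = Net().cuda()

# Tell DDP which parameters to leave alone, then wrap the model as usual.
model._ddp_params_and_buffers_to_ignore = [
    name for name, p in model.named_parameters() if hasattr(p, 'skip_allreduce')]
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```

With the expert parameters excluded from DDP's bucketed all-reduce, only the shared parameters are averaged across ranks; each rank's experts keep their own gradients, which is the behavior the thread is asking about.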