microsoft / tutel

Tutel MoE: An Optimized Mixture-of-Experts Implementation

Training with Data and Expert Parallelism #204

Open santurini opened 1 year ago

santurini commented 1 year ago

How should I prepare my code (data loaders, model, etc.) in order to train in a combined Data and Expert Parallel mode? And what is the difference between the "auto", "model", and "data" values of the --parallel_type option?

In my current setup I'm training with DDP, wrapping the model with torch DistributedDataParallel and using the distributed sampler in the data loaders. Now I want to insert an MoE layer with 2 experts into the model (I have 2 GPUs, so 1 local expert per GPU), i.e. using both Data and Expert Parallelism. Some help would be appreciated.

ghostplant commented 1 year ago

Tutel MoE works just like any other layer under DDP as far as data loaders and models are concerned, so you can safely stack the Tutel MoE layer into the forward graph of your original DDP model design. (e.g. https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L79)

There is only one thing you need to pay attention to: none of the expert parameters should be managed by the DDP allreduce. This is how you can achieve that (a sketch is shown after the list):

  1. Set an attribute as a mask on each expert parameter object (follow the example): https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L68
  2. Follow the PyTorch DDP mechanism to add a handler that skips those parameters: https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L83-L87
  3. Right after the model is initialized onto the target device (e.g. model = model.to('cuda:#')), call that handler: https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L92-L95
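
Putting the three steps together, here is a minimal, hedged sketch (not the exact example code): it assumes a torchrun-style launch with LOCAL_RANK set, uses illustrative layer sizes, and replaces the example's custom skip-allreduce handler with PyTorch's private _set_params_and_buffers_to_ignore_for_model helper; the skip_allreduce attribute and the scan_expert_func pattern follow helloworld_ddp.py.

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from tutel import moe as tutel_moe

dist.init_process_group('nccl')
local_rank = int(os.environ['LOCAL_RANK'])
device = torch.device('cuda', local_rank)

class Net(torch.nn.Module):
    def __init__(self, model_dim=128):
        super().__init__()
        self.dense = torch.nn.Linear(model_dim, model_dim)  # a normal, DDP-managed layer
        self.moe = tutel_moe.moe_layer(
            gate_type={'type': 'top', 'k': 1},
            model_dim=model_dim,
            experts={'type': 'ffn', 'count_per_node': 1,
                     'hidden_size_per_expert': 256},
            # Step 1: tag every expert parameter so it can be excluded from allreduce.
            scan_expert_func=lambda name, p: setattr(p, 'skip_allreduce', True),
        )

    def forward(self, x):
        return self.moe(self.dense(x))

model = Net().to(device)

# Steps 2 and 3: collect the tagged parameter names and tell DDP to ignore them,
# so only the non-expert (data-parallel) parameters get allreduced.
skip_names = [name for name, p in model.named_parameters()
              if hasattr(p, 'skip_allreduce')]
DDP._set_params_and_buffers_to_ignore_for_model(model, skip_names)  # private PyTorch API
model = DDP(model, device_ids=[local_rank])
```

After this, the optimizer still sees and updates the expert parameters locally on each rank (that is the expert-parallel part); only the gradient allreduce skips them.
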
santurini commented 1 year ago

Thank you very much, I'll try it out! And could you give any explanation of the different parallel_type arguments and the differences between them?

ghostplant commented 1 year ago

Whichever parallel type you choose, it doesn't change how the MoE layer is used out of the box. The different types only change the internal parallelism the MoE layer uses; those choices are all transparent to users and math-equivalent to each other.

At large or small scales, setting that option smartly can improve the execution time of the Tutel MoE layer, since each kind of parallelism has its own network complexity and local memory consumption.
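
For illustration only, a hedged sketch of where that option is set; the parallel_type kwarg name and its 'auto' / 'data' / 'model' values are inferred from the examples' --parallel_type flag, so treat the exact spelling as an assumption:

```python
from tutel import moe as tutel_moe

# Assumed kwarg 'parallel_type' (mirrors the --parallel_type CLI flag in the examples).
# Whichever value is chosen, the MoE output stays math-equivalent; only the internal
# communication / memory trade-off changes.
moe = tutel_moe.moe_layer(
    gate_type={'type': 'top', 'k': 1},
    model_dim=128,
    experts={'type': 'ffn', 'count_per_node': 1, 'hidden_size_per_expert': 256},
    parallel_type='auto',   # alternatives: 'data', 'model'
)
```
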

ghostplant commented 1 year ago

Additionally, you can also change the parallel option at every iteration, e.g. https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_switch.py#L88

The value of adaptive_r varies from 0 to max(1, [Total GPU Count / Total Expert Count])
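
As a concrete reading of that range for the 2-GPU / 2-expert setup in the question (plain arithmetic only; adaptive_r_upper_bound is a hypothetical helper, not a Tutel API):

```python
# Hypothetical helper illustrating the bound: adaptive_r ranges from 0 to
# max(1, total_gpus // total_experts).
def adaptive_r_upper_bound(total_gpus: int, total_experts: int) -> int:
    return max(1, total_gpus // total_experts)

# 2 GPUs and 2 global experts -> max(1, 2 // 2) = 1, so adaptive_r can be 0 or 1.
print(adaptive_r_upper_bound(total_gpus=2, total_experts=2))  # -> 1
```
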

santurini commented 1 year ago

Wow, thank you very much, I'll try it out!