Open santurini opened 1 year ago
Tutel MoE works the same way as in DDP mode for data loaders and models, so you can safely stack the Tutel MoE layer in your original forward graph design for DDP (e.g. https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L79).
There is only one thing you need to pay attention to: expert parameters must not be managed by DDP's allreduce. This is how you can achieve that:
- Set an attribute as a mask on each expert parameter object (follow the example): https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L68
- Follow PyTorch DDP to add the handler that skips them: https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L83-L87
- Right after moving the model to the target device (e.g. model = model.to('cuda:#')), call the handler: https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L92-L95
Thank you very much, I'll try it out! Could you also explain the different parallel type arguments and the differences between them?
Whatever type of parallelism you choose, it doesn't change how you use the MoE layer out of the box. The different types only change which parallelism the MoE layer uses internally; those choices are all transparent to users and mathematically equivalent to each other.
At large or small scales, setting that option well will improve the execution time of the Tutel MoE layer, since each parallelism has its own network complexity and local memory consumption.
Additionally, you can change the parallel option on every iteration, e.g. https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_switch.py#L88
The value of adaptive_r varies from 0 to max(1, [Total GPU Count / Total Expert Count]).
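To make that range concrete, here is a small sketch of the upper bound. It assumes the square brackets in the formula mean integer (floor) division, and the function name max_adaptive_r is hypothetical, not part of Tutel.

```python
def max_adaptive_r(total_gpus, total_experts):
    """Upper bound of adaptive_r: max(1, total_gpus // total_experts).

    Assumption: the bracketed ratio is rounded down to an integer.
    """
    if total_gpus < 1 or total_experts < 1:
        raise ValueError("GPU and expert counts must be positive")
    return max(1, total_gpus // total_experts)


# 2 GPUs with 2 experts: the ratio is 1, so the upper bound is 1
print(max_adaptive_r(2, 2))
# 8 GPUs with 2 experts: the ratio is 4, so the upper bound is 4
print(max_adaptive_r(8, 2))
```

How a chosen value is passed to the layer on each iteration is shown in the linked helloworld_switch.py example.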
Wow, thank you very much, I'll try it out!
Tutel MoE works the same way as in DDP mode for data loaders and models, so you can safely stack the Tutel MoE layer in your original forward graph design for DDP (e.g. https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L79).
There is only one thing you need to pay attention to: expert parameters must not be managed by DDP's allreduce. This is how you can achieve that:
- Set an attribute as a mask on each expert parameter object (follow the example): https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L68
- Follow PyTorch DDP to add the handler that skips them: https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L83-L87
- Right after moving the model to the target device (e.g. model = model.to('cuda:#')), call the handler: https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L92-L95
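As a rough sketch of the three steps above (not the code from helloworld_ddp.py): PyTorch DDP leaves out of gradient synchronization any parameter whose name is listed in the module attribute _ddp_params_and_buffers_to_ignore. The helper name mark_expert_params_to_skip is hypothetical, and the default flag name '_tutel_expert' is the attribute mentioned later in this thread; substitute whatever attribute your expert parameters actually carry.

```python
def mark_expert_params_to_skip(model, flag="_tutel_expert"):
    """Register every flagged parameter so DDP does not all-reduce it.

    `model` is expected to be a torch.nn.Module; only named_parameters()
    is used here. Returns the names of the flagged parameters.
    """
    skipped = [name for name, param in model.named_parameters()
               if getattr(param, flag, False)]
    # DDP reads this attribute from the module when the model is wrapped.
    existing = list(getattr(model, "_ddp_params_and_buffers_to_ignore", []))
    model._ddp_params_and_buffers_to_ignore = existing + [
        name for name in skipped if name not in existing]
    return skipped


# Assumed call order (local_rank is a placeholder):
#   model = model.to(f"cuda:{local_rank}")
#   mark_expert_params_to_skip(model)
#   model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```

The helper has to run before the model is wrapped, because DDP only consults the ignore list at construction time.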
Hello, I use both DeepSpeed and your Tutel framework, with 4 experts and 4 GPUs, and both Data and Expert Parallelism. To keep the expert parameters from being managed by DeepSpeed's allreduce, I use this code:
deepspeed.initialize(args=self.args, model=model, model_parameters=[param for name, param in model.named_parameters() if not hasattr(param, "skip_allruduce")], config=ds_config)
But it doesn't work: the expert parameters can't be updated correctly. What should I do? Thanks!
Is skip_allruduce a typo? I think you need to set something in DeepSpeed so that all-reduce is bypassed for all model parameters that satisfy getattr(param, '_tutel_expert', False) == True (referenced here).
In other words, below is exactly the parameter list that needs to bypass allreduce:
bypass_list = [x for x in model.parameters() if getattr(x, '_tutel_expert', False) == True]
Next, you need to check DeepSpeed's documentation for how their framework can skip all_reduce for the bypass_list given above.
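A sketch of splitting the parameters by that flag. One assumption worth checking against the DeepSpeed documentation: the model_parameters argument of deepspeed.initialize selects which tensors get optimized, so leaving the experts out of it may stop them from being updated at all, rather than only exempting them from all-reduce. The helper name split_expert_params is hypothetical.

```python
def split_expert_params(model, flag="_tutel_expert"):
    """Return (expert_params, shared_params), split by a boolean flag attribute."""
    expert, shared = [], []
    for param in model.parameters():
        if getattr(param, flag, False):
            expert.append(param)
        else:
            shared.append(param)
    return expert, shared


# bypass_list, shared_list = split_expert_params(model)
# Both lists still need to reach the optimizer; only bypass_list should be
# excluded from gradient all-reduce.
```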
Thanks, but I can't find out how to avoid doing all_reduce.
The action of skipping all-reduce is controlled outside the Tutel MoE layer, by the training framework (e.g. DeepSpeed/Fairseq/Megatron), so your remaining question has to be answered by DeepSpeed.
If I place all experts on a specific GPU just by setting a custom process group when creating moe_layer(), I don't need to do all_reduce, is that right?
Nope, I don't think setting a custom process group to bypass all_reduce is standard usage in PyTorch distributed.
This is PyTorch DDP's standard way to bypass allreduce, but DeepSpeed may maintain its own allreduce in a different way: https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L84-L98
If you are interested, can you share a DeepSpeed training example from your context? Given a reproducible training script, we can help you with the answer.
How should I prepare my code (data loaders, model, etc.) in order to train in both Data and Expert Parallel mode? And what changes between the "auto", "model" and "data" --parallel types?
In my current setup I'm training with DDP, wrapping the model with torch DistributedDataParallel and using the distributed sampler in the loaders. Now I want to insert an MoE layer with 2 experts into the model (I have 2 GPUs, so 1 local expert each), using both Data and Expert Parallelism. Some help would be appreciated.