microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[REQUEST] Expert Choice Routing for MoE #2517

Open clumsy opened 1 year ago

clumsy commented 1 year ago

Is your feature request related to a problem? Please describe. A paper was published proposing a potentially better token-expert routing scheme for MoE that leaves fewer experts under-trained.

Describe the solution you'd like In addition to GShard's top-2 and Switch Transformer's top-1 per-token expert routing, add an expert choice routing option.

Describe alternatives you've considered N/A

Additional context N/A
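
For anyone skimming the thread, here is a rough numpy sketch of the idea, as described in the paper: with expert choice (EC) routing, each expert picks its top-`capacity` tokens, rather than each token picking its top-k experts as in top-1/top-2 gating. All names (`expert_choice_routing`, `capacity`) are made up for illustration and do not reflect DeepSpeed's actual gating code.

```python
# Illustrative sketch of expert choice (EC) routing; not DeepSpeed's API.
import numpy as np

def expert_choice_routing(logits: np.ndarray, capacity: int):
    """logits: [num_tokens, num_experts] raw router scores.

    Returns (gates, token_idx), both shaped [num_experts, capacity]:
    the gate weight and token index for each slot an expert fills.
    """
    # Softmax over experts for each token, then view scores per expert.
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    scores = e / e.sum(axis=-1, keepdims=True)   # [tokens, experts]
    scores_t = scores.T                          # [experts, tokens]
    # Each expert selects the `capacity` tokens it scores highest, so
    # every expert receives exactly `capacity` tokens: none sit idle
    # (under-trained) and none overflow their capacity.
    token_idx = np.argsort(-scores_t, axis=-1)[:, :capacity]
    gates = np.take_along_axis(scores_t, token_idx, axis=-1)
    return gates, token_idx

# Tiny example: 3 tokens, 2 experts, each expert takes 2 tokens.
logits = np.array([[2.0, 0.0],
                   [0.0, 2.0],
                   [1.0, 1.0]])
gates, token_idx = expert_choice_routing(logits, capacity=2)
print(token_idx)  # each row lists the tokens that expert processes
```

The key contrast with token-choice gating is that load balancing is enforced by construction instead of by an auxiliary loss.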

clumsy commented 1 year ago

The authors claim a 2x faster convergence rate with EC routing: https://ai.googleblog.com/2022/11/mixture-of-experts-with-expert-choice.html

I hope this incentivizes implementing it in DeepSpeed.

awan-10 commented 1 year ago

Thank you @clumsy for sharing this paper.

@ykim362, have you seen this paper? Is anyone in your team or any interns interested in implementing this feature?

clumsy commented 1 year ago

In case this helps, TL;DR is in Lilian Weng's blog post.

ykim362 commented 1 year ago

Hi @awan-10. I have an implementation of this paper, but we didn't see the gains mentioned in it. In fact, the accuracy was noticeably worse than with the original top-1 and top-2 gating.

@clumsy have you actually done any experiments with this expert choice gating?

clumsy commented 1 year ago

No @ykim362, but I would like to experiment with it and share the results. Could you share a snippet of the implementation you used?

ykim362 commented 1 year ago

@clumsy you can take a look at this experimental branch. https://github.com/ykim362/DeepSpeed/tree/youki/expc

ilyalasy commented 6 months ago

Hey, Google has an implementation of expert choice routing here: https://github.com/google/flaxformer/blob/main/flaxformer/architectures/moe/routing.py#L647-L717

They include a note that it should not be used in decoder blocks; maybe that was the reason for the poor results in your experiments?