microsoft / tutel

Tutel MoE: An Optimized Mixture-of-Experts Implementation
MIT License

Question regarding the load importance loss calculation #240

Open wangyirui opened 2 months ago

wangyirui commented 2 months ago

Hi, while studying the load importance loss, I noticed that the arguments passed to load_importance_loss are the softmax-normalized scores and the logits with noise (see moe_layer.py L281). Why is the diff computed between the softmax-normalized scores and the raw noisy logits? Why not use the softmax output consistently for both? Thanks!

ghostplant commented 2 months ago

Hi, the standard GShard MoE follows the branch self.is_gshard_loss == True, while the loss option you pointed out is designed and preferred by Swin-Transformer MoE.

According to load_importance_loss, defined in https://github.com/microsoft/tutel/blob/main/tutel/impls/losses.py#L29, normalization has to be performed directly on the score tensor without noise, so that the normalized results are not polluted by the noise. @zeliu98
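
For readers following along, here is a rough pure-Python sketch of how a loss of this shape could be computed. It is not copied from losses.py. It assumes the loss averages an importance term and a load term, each a squared coefficient of variation, and that the load term compares the noise-free softmax scores against the k-th largest noisy logit under a normal CDF. The names load_importance_loss_sketch, cv_squared, and normal_cdf, and the choice of gate_noise / num_experts as the standard deviation, are illustrative assumptions.

```python
import math
from statistics import mean, variance


def softmax(logits):
    # Normalize one token's clean (noise-free) logits into scores.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def normal_cdf(x, std):
    # CDF of a zero-mean normal distribution with standard deviation `std`.
    return 0.5 * (1.0 + math.erf(x / (std * math.sqrt(2.0))))


def cv_squared(values, eps=1e-10):
    # Squared coefficient of variation: sample variance / mean^2.
    m = mean(values)
    return variance(values) / (m * m + eps)


def load_importance_loss_sketch(scores_wo_noise, topk_logits, num_experts, gate_noise):
    # scores_wo_noise: per-token softmax scores computed WITHOUT noise.
    # topk_logits:     per-token top-k NOISY logits, sorted descending.
    # Assumption: the noise scale is gate_noise / num_experts.
    assert gate_noise > 0 and num_experts >= 2
    std = gate_noise / num_experts

    # Importance: total clean score mass each expert receives.
    importance = [sum(col) for col in zip(*scores_wo_noise)]

    # Load: smooth estimate of how often each expert clears the
    # per-token threshold (the smallest of the top-k noisy logits).
    load = [0.0] * num_experts
    for scores, topk in zip(scores_wo_noise, topk_logits):
        threshold = topk[-1]
        for e, s in enumerate(scores):
            load[e] += normal_cdf(s - threshold, std)

    return (cv_squared(importance) + cv_squared(load)) / 2.0
```

In this sketch the noise only enters through the threshold taken from topk_logits; the scores being compared and summed stay noise-free, which is the point made above about keeping the normalized results unpolluted.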