microsoft / tutel

Tutel MoE: An Optimized Mixture-of-Experts Implementation

Error in load_importance_loss #167

Open Luodian opened 2 years ago

Luodian commented 2 years ago

Hi, I ran into the errors below when using load_importance_loss (the code works fine when using gshard_loss). Does anyone have an idea what causes this?

The error log (from one rank/node) is below:

[4]:
  time      : 2022-07-06_11:47:24
  host      : SG-IDC1-10-51-2-36
  rank      : 4 (local_rank: 4)
  exitcode  : 1 (pid: 55010)
  error_file: /tmp/torchelastic_kuhg0qco/none_62gucqgc/attempt_0/4/error.json
  traceback : Traceback (most recent call last):
    File "/mnt/lustre/bli/anaconda3/envs/scale/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
      return forward_call(*input, **kwargs)
    File "/mnt/lustre/bli/projects/Pretraining-DG/mae/models_moe_mae.py", line 75, in forward
      x_temp = self.mlp(self.norm2(x))
    File "/mnt/lustre/bli/anaconda3/envs/scale/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
      return forward_call(*input, **kwargs)
    File "/mnt/lustre/bli/.local/lib/python3.9/site-packages/tutel/impls/moe_layer.py", line 231, in forward
      logits_dtype, (crit, l_aux) = routing()
    File "/mnt/lustre/bli/.local/lib/python3.9/site-packages/tutel/impls/moe_layer.py", line 218, in routing
      return logits.dtype, extract_critical(scores,
    File "/mnt/lustre/bli/.local/lib/python3.9/site-packages/tutel/impls/fast_dispatch.py", line 150, in extract_critical
      l_loss = loss_fn(scores, topk_indices) if loss_fn is not None else None
    File "/mnt/lustre/bli/.local/lib/python3.9/site-packages/tutel/impls/moe_layer.py", line 215, in <lambda>
      _loss_fn = lambda gates, topk_ids: losses.load_importance_loss(
    File "/mnt/lustre/bli/.local/lib/python3.9/site-packages/tutel/impls/losses.py", line 41, in load_importance_loss
      l_load = load_loss(scores_wo_noise, topk_logits, num_global_experts, gate_noise)
    File "/mnt/lustre/bli/.local/lib/python3.9/site-packages/tutel/impls/losses.py", line 23, in load_loss
      normal = Normal(
    File "/mnt/lustre/bli/anaconda3/envs/scale/lib/python3.9/site-packages/torch/distributions/normal.py", line 54, in __init__
      super(Normal, self).__init__(batch_shape, validate_args=validate_args)
    File "/mnt/lustre/bli/anaconda3/envs/scale/lib/python3.9/site-packages/torch/distributions/distribution.py", line 55, in __init__
      raise ValueError(
  ValueError: Expected parameter scale (Tensor of shape (1,)) of distribution Normal(loc: tensor([0.], device='cuda:4'), scale: tensor([0.], device='cuda:4')) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
  tensor([0.], device='cuda:4')
ghostplant commented 2 years ago

@zeliu98

Luodian commented 2 years ago

Hi, I think this happens because the default gate_noise value passed to load_importance_loss is 0.0.

And if we do

normal = Normal(0, 0.0)

it is odd anyway (why would we want a normal distribution with zero variance?), and it returns:

*** ValueError: Expected parameter scale (Tensor of shape ()) of distribution Normal(loc: 0.0, scale: 0.0) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
0.0
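For reference, the same constraint check can be reproduced standalone with torch.distributions, independent of Tutel (a minimal sketch):

import torch
from torch.distributions.normal import Normal

# A strictly positive scale constructs fine.
ok = Normal(torch.tensor([0.0]), torch.tensor([1.0]))

# A zero scale violates the GreaterThan(lower_bound=0.0) constraint on `scale`
# and raises the same ValueError as in the traceback above.
try:
    Normal(torch.tensor([0.0]), torch.tensor([0.0]))
except ValueError as e:
    print(e)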
Luodian commented 2 years ago

If I preset gate_noise to 1.0, the code runs without problems, but I am not sure whether it is numerically correct:

gate_type={'type': 'top', 'k': 2, 'fp32_gate': False, 'gate_noise': 1.0, },
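For context, the gate_type above is passed to the MoE layer constructor; here is a sketch of the full construction, based on the tutel.moe.moe_layer interface shown in the project README (model_dim, num_local_experts, and hidden_size are placeholders, not the actual training configuration):

import torch
from tutel import moe as tutel_moe

moe = tutel_moe.moe_layer(
    gate_type={'type': 'top', 'k': 2, 'fp32_gate': False, 'gate_noise': 1.0},
    model_dim=model_dim,                        # embedding dimension of the transformer block
    experts={
        'type': 'ffn',
        'count_per_node': num_local_experts,    # experts hosted on each GPU
        'hidden_size_per_expert': hidden_size,  # FFN hidden size per expert
    },
)
y = moe(x)  # x: [batch, seq_len, model_dim]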
zeliu98 commented 2 years ago

Hi @Luodian, yes, you need to set gate_noise > 0 for load_importance_loss. You can find the reasoning in Appendix A: Load-Balancing Loss of the original paper (https://arxiv.org/pdf/1701.06538.pdf).
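For the intuition: the load term smooths the hard top-k selection by asking, per token and expert, how likely that expert would still be selected if the gating noise were re-sampled; that probability is a normal CDF whose scale comes from gate_noise, so gate_noise = 0 produces the degenerate zero-scale Normal seen in the traceback. A rough sketch of the idea (not the exact code in tutel/impls/losses.py, whose scaling may differ):

import torch
from torch.distributions.normal import Normal

def load_estimate_sketch(clean_logits, kth_noisy_logit, gate_noise):
    # clean_logits:    [tokens, experts] gate logits without noise
    # kth_noisy_logit: [tokens, 1] the k-th largest noisy logit (selection threshold)
    # P(expert still above the threshold under re-sampled noise), following
    # Appendix A of Shazeer et al., 2017; requires gate_noise > 0 for a valid Normal scale.
    normal = Normal(
        torch.tensor([0.0], device=clean_logits.device),
        torch.tensor([gate_noise], device=clean_logits.device),
    )
    prob_selected = normal.cdf(clean_logits - kth_noisy_logit)
    load = prob_selected.sum(dim=0)                  # expected load per expert
    return load.var() / (load.mean() ** 2 + 1e-10)   # squared coefficient of variation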

ghostplant commented 2 years ago

@zeliu98 We need to add an assertion with a clear message to avoid confusing errors like this.

And thanks for the information, @Luodian!
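Something along these lines should make the failure mode obvious (a sketch of the planned check, not the actual commit):

# Fail fast with an explicit message instead of letting torch.distributions
# raise a constraint error deep inside load_loss():
assert gate_noise > 0, "`gate_noise` must be > 0 when using load_importance_loss"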

Luodian commented 2 years ago

Yep, and I also found an issue when using the cosine projector.

It seems that in cosine_top.py line 31, there should be a .cuda() or .to(device) call to make sure the tensors end up on the same device:

logit_scale = torch.clamp(self.temperature, max=torch.log(torch.tensor(1. / 0.01)).cuda()).exp()
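A device-agnostic variant avoids hard-coding .cuda() (a sketch, assuming self.temperature is the learnable gate temperature tensor referenced on that line):

logit_scale = torch.clamp(
    self.temperature,
    max=torch.log(torch.tensor(1. / 0.01, device=self.temperature.device)),
).exp()
# Alternatively, pass a plain Python float (math.log(1. / 0.01)) as `max`,
# which sidesteps the device question entirely.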
ghostplant commented 2 years ago

We have added the gate_noise assertion and the device cast in the latest commit. Thanks for pointing out this bug.