databricks / megablocks

selective router precision #91

Open 152334H opened 8 months ago

152334H commented 8 months ago

To my understanding -- and please correct me if I am wrong about this -- there is no mechanism to selectively compute the routing logits in fp32, as suggested in e.g. the Switch Transformers paper. Basis:

  1. The only mention of fp32/float computation I see anywhere is for moe_lbl_in_fp32.
  2. the router is initialized with the same dtype as the MLP weights (as configured by Arguments).
  3. There does not seem to be any explicit casting or autocast deactivation in router.py, nor any attempt to do so in dMoE.
  4. Given that the router is implemented as a torch.nn.Linear, and the input to the router is pre-cast to autocast's precision, I can only presume that the computation is done in half precision under normal AMP training (see the quick check below).

Is this correct? If so, have you observed any instabilities in practice during training? Perhaps it is just not necessary...
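
To illustrate point 4, here is a quick standalone check (a minimal sketch, not MegaBlocks code; sizes and names are illustrative) showing that a plain torch.nn.Linear router emits half-precision logits under autocast:

```python
# Minimal repro: under AMP, a plain nn.Linear router produces half-precision logits.
import torch

router = torch.nn.Linear(1024, 8).cuda()   # hidden_size -> num_experts
x = torch.randn(4, 1024, device="cuda")    # a batch of token activations

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = router(x)

print(logits.dtype)  # torch.bfloat16 -- no fp32 routing path is taken
```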

tgale96 commented 8 months ago

Hi! No, we don't support their selective precision. Although it would be quite easy to add if you wanted to try it!

In practice we haven't had any issues with router instability, although that paper trains much larger models (and with different systems/software) than we do. If you're training models with FLOPs equivalent to a dense model of 10B parameters or less, I suspect you will be fine, based on our experience.
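
For anyone who wants to try it, a rough sketch of what Switch-Transformer-style selective precision could look like, assuming a router that wraps torch.nn.Linear (the class and names below are hypothetical, not part of the MegaBlocks API):

```python
import torch

class FP32Router(torch.nn.Module):
    """Router that computes its logits in fp32 even inside an autocast region."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        # Keep the router weights in fp32 (the default nn.Linear dtype).
        self.layer = torch.nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Exit any enclosing autocast region and upcast the input, so the
        # matmul (and any subsequent softmax over experts) runs at full precision.
        with torch.autocast(device_type=x.device.type, enabled=False):
            return self.layer(x.float())
```

The downstream expert dispatch can then cast the routing weights back to the compute dtype if needed; only the logit/softmax path stays in fp32.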