bigscience-workshop / Megatron-DeepSpeed

Ongoing research training transformer language models at scale, including: BERT & GPT-2

[embed norm] switch to apex MixedFusedLayerNorm #262

Closed · stas00 closed this 2 years ago

stas00 commented 2 years ago

As noticed by @thomasw21, this switches the embed layernorm to apex's MixedFusedLayerNorm for consistency with the other layer norms.
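
For context, the swap itself looks roughly like this (an illustrative sketch, not the actual diff; the apex import path is an assumption):

```python
# Illustrative sketch of the swap, not the actual diff; import path assumed.
import torch
from apex.normalization import MixedFusedLayerNorm

hidden_size = 1024  # placeholder value

# before: the embedding norm was a plain PyTorch layer norm
embed_layernorm = torch.nn.LayerNorm(hidden_size)

# after: the same fused apex kernel used by the other layer norms in the model
embed_layernorm = MixedFusedLayerNorm(hidden_size)
```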


Incidentally, this also fixes a bug in how `torch.nn.LayerNorm` was handled until now.

The framework was putting the embedding's LayerNorm into the wrong param group here: https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/dd06ea32e014d8db6cdaf5e6839071d6523ca83c/megatron/optimizer/__init__.py#L31-L45

It should have gone into `no_weight_decay_params` but ended up in `weight_decay_params`, because in that module `LayerNorm` is an alias for `MixedFusedLayerNorm`, so `isinstance(module_, LayerNorm)` was `False` for the plain `torch.nn.LayerNorm`, as sketched below.
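
Roughly, that grouping logic does the following (a paraphrased sketch, not the verbatim source; the function name is illustrative):

```python
from megatron.model import LayerNorm  # in this module, an alias for apex's MixedFusedLayerNorm

def split_params_by_weight_decay(model):
    # paraphrased sketch of the linked code, not the verbatim source
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in model.modules():
        if isinstance(module_, LayerNorm):
            # matches only the fused alias; a plain torch.nn.LayerNorm is not an
            # instance of it, so the embedding's norm fell through to the else branch
            no_weight_decay_params['params'].extend(
                p for p in module_._parameters.values() if p is not None)
        else:
            # non-bias params get weight decay -- including, wrongly, the
            # embedding's torch.nn.LayerNorm weight
            weight_decay_params['params'].extend(
                p for n, p in module_._parameters.items()
                if p is not None and n != 'bias')
            # biases never get weight decay
            no_weight_decay_params['params'].extend(
                p for n, p in module_._parameters.items()
                if p is not None and n == 'bias')
    return weight_decay_params, no_weight_decay_params
```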

So if we want to keep using `torch.nn.LayerNorm`, the code above has to be changed to additionally check for it, i.e. add `or isinstance(module_, torch.nn.LayerNorm)`.
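
A minimal sketch of that extra check (the helper name is hypothetical, not part of the codebase):

```python
import torch
from megatron.model import LayerNorm  # alias for apex's MixedFusedLayerNorm

def is_any_layer_norm(module_):
    # hypothetical helper: matches both the fused apex alias and the plain
    # torch.nn.LayerNorm, so either one lands in no_weight_decay_params
    return isinstance(module_, (LayerNorm, torch.nn.LayerNorm))
```

With that, the `if isinstance(module_, LayerNorm)` test in the grouping code would become `if is_any_layer_norm(module_)`, or simply gain an `or isinstance(module_, torch.nn.LayerNorm)` clause inline.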