as noticed by @thomasw21 - switching the embed layernorm to use `MixedFusedLayerNorm` for consistency with the other layer norms.

Incidentally, this also fixes a bug with how `torch.nn.LayerNorm` was used until now: the framework was putting `torch.nn.LayerNorm` parameters into the wrong param group here: https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/dd06ea32e014d8db6cdaf5e6839071d6523ca83c/megatron/optimizer/__init__.py#L31-L45

They should have been in `no_weight_decay_params` but ended up in `weight_decay_params`, because in this module `LayerNorm` is an alias for `MixedFusedLayerNorm`, so `isinstance(module_, LayerNorm)` was `False` for them.

So if we want to use `torch.nn.LayerNorm` we have to change the code above to additionally check for `or isinstance(module_, torch.nn.LayerNorm)`.
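
To make the issue concrete, here is a minimal sketch of the kind of param-group split the linked code performs, with the extra `torch.nn.LayerNorm` check added. This is not the actual Megatron-DeepSpeed source: the function name and signature are illustrative, and `MixedFusedLayerNorm` below is just a plain `nn.LayerNorm` subclass standing in for the Apex fused implementation so the snippet runs without Apex installed.

```python
import torch
import torch.nn as nn


# Stand-in for Apex's MixedFusedLayerNorm (aliased as `LayerNorm` inside
# Megatron-DeepSpeed) so this sketch runs without Apex installed.
class MixedFusedLayerNorm(nn.LayerNorm):
    pass


def get_params_for_weight_decay_optimization(module):
    """Illustrative split of parameters into weight-decay / no-weight-decay groups.

    Layer norm parameters and all biases are kept out of weight decay.
    """
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in module.modules():
        # The fix discussed above: check for both the fused LayerNorm alias
        # and torch.nn.LayerNorm, otherwise plain torch.nn.LayerNorm modules
        # fall through to the weight-decay branch.
        if isinstance(module_, (MixedFusedLayerNorm, torch.nn.LayerNorm)):
            no_weight_decay_params['params'].extend(
                p for p in module_._parameters.values() if p is not None)
        else:
            weight_decay_params['params'].extend(
                p for n, p in module_._parameters.items()
                if p is not None and n != 'bias')
            no_weight_decay_params['params'].extend(
                p for n, p in module_._parameters.items()
                if p is not None and n == 'bias')
    return weight_decay_params, no_weight_decay_params


# Example: an embed layernorm built with torch.nn.LayerNorm now lands in the
# no-weight-decay group alongside the fused layer norms.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), MixedFusedLayerNorm(8))
wd, no_wd = get_params_for_weight_decay_optimization(model)
optimizer = torch.optim.AdamW([wd, no_wd], lr=1e-4, weight_decay=0.01)
```

With only the single `isinstance(module_, MixedFusedLayerNorm)` check, the `nn.LayerNorm(8)` weight in the example above would silently receive weight decay; the tuple check is what this PR's reasoning calls for when mixing both layer norm implementations.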