Open ofivite opened 12 months ago
Thanks for pointing this out! Your analysis seems correct.
A simple fix is to add `self._has_rescaled_params = True` to the constructor of `MuSharedReadout`, so we don't trigger rescaling. I'll do that after making sure it doesn't have unintended consequences.
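A minimal sketch of how that flag would short-circuit the rescaling, using stand-in classes rather than mup's real `MuReadout`/`MuSharedReadout` (the attribute names come from the fix above; everything else is an assumption for illustration):

```python
# Stand-ins for mup's classes, just to illustrate the proposed flag.
class MuReadout:
    def __init__(self):
        self._has_rescaled_params = False  # checked before rescaling
        self.init_std = 1.0                # pretend N(0, 1) init

    def _rescale_parameters(self, width_mult):
        # Conceptually what happens on set_base_shapes(): scale once.
        if self._has_rescaled_params:
            return  # already rescaled, do nothing
        self.init_std *= width_mult ** 0.5
        self._has_rescaled_params = True


class MuSharedReadout(MuReadout):
    def __init__(self):
        super().__init__()
        # Proposed fix: mark as already rescaled, so the tied embedding
        # weight keeps its width-independent N(0, 1) init.
        self._has_rescaled_params = True


shared = MuSharedReadout()
shared._rescale_parameters(width_mult=4.0)
print(shared.init_std)  # 1.0 -- the tied weight was never rescaled
```

With the flag set in the constructor, the later rescaling pass becomes a no-op for the shared readout while still applying to an ordinary readout.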
The vanishing preactivation in the last row should be the final logits. The GP behavior at init follows the CLT, yet the scaling accounts for the LLN. A way to get rid of it is to initialize the shared embedding layer to zero, which is okay as long as the embedded input is not always zero (e.g., through a non-zero positional embedding). You should get flat curves that way.
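A toy, stdlib-only sketch of that setup (the sizes and the 0.02 positional std are arbitrary assumptions, not values from mup or Megatron-LM):

```python
import random

random.seed(0)
d_model, vocab, seq_len = 8, 50, 6  # toy sizes, chosen arbitrarily

# Shared token embedding initialised to zero, as suggested above.
tok_emb = [[0.0] * d_model for _ in range(vocab)]

# A non-zero positional embedding keeps the embedded input from being
# identically zero even though the token embedding starts at zero.
pos_emb = [[random.gauss(0.0, 0.02) for _ in range(d_model)]
           for _ in range(seq_len)]

token_ids = [random.randrange(vocab) for _ in range(seq_len)]
x = [[t + p for t, p in zip(tok_emb[tid], pos_emb[pos])]
     for pos, tid in enumerate(token_ids)]

print(any(v != 0.0 for row in x for v in row))  # True
```

The point is only the last line: the sum of a zero token embedding and a non-zero positional embedding is non-zero, so the forward signal (and hence the gradient to the tied weight) does not vanish at init.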
Thank you @edwardjhu for your answer and suggestions! Initialising the shared embeddings to zero is a good idea; I will try that out in the Megatron setup and see if the curves look flat there :)
Hi! I've been looking into integrating `muP` into the `Megatron-LM` setup, and I was wondering about the `_rescale_parameters()` method of `MuReadout` in the case of shared (tied) input/output embeddings. Specifically, in the Transformer example I am not sure it is in line with the embedding initialisation suggested in the paper (i.e., constant in width).

Currently, in the example:
- `encoder` is initialised from `N(0, 1)` <- the default `nn.Embedding` init: https://github.com/microsoft/mup/blob/a33ea802bcef1d7744057e34ff00d1a5d7e3d7c4/examples/Transformer/model.py#L93
- `decoder` is first initialised within `MuSharedReadout` from `U(-1/sqrt(fan_in), 1/sqrt(fan_in))` <- the default `nn.Linear` init: https://github.com/microsoft/mup/blob/a33ea802bcef1d7744057e34ff00d1a5d7e3d7c4/mup/layer.py#L67, and then its weight is tied to `encoder` on the next line (68), so both end up with the `N(0, 1)` init
- once `set_base_shapes()` is called, both encoder and decoder weights are rescaled within `_rescale_parameters()` by `*= self.width_mult()**0.5`, which makes them effectively initialised from `N(0, sqrt(d/d_0))`, i.e. scaling with width

However, in the muP paper it is suggested to initialise them as constants (in width) to be muP-compatible. It should also be mentioned that in the untied case the output embeddings are set to 0, so `_rescale_parameters()` has no effect and things remain consistent with the paper.

Below I also attach the coordinate-check plots for the Transformer example for untied, tied + rescaling (the current implementation), and tied + no rescaling (`_rescale_parameters()` disabled), respectively. One can see that for untied the norms are nicely flat, for tied + rescaling some layers have growing activations, and for tied + no rescaling one layer shows a vanishing trend.

So I was wondering: should `_rescale_parameters()` be disabled in the tied-embedding scenario to keep the init constant in width, given the `N(0, 1)` initialisation inherited from `nn.Embedding()`?
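For concreteness, the effect of the rescaling described above can be simulated numerically without any mup dependency (the widths `d0` and `d` below are arbitrary assumptions, standing in for the base and target model widths):

```python
import math
import random

random.seed(0)
d0, d = 256, 1024          # hypothetical base width and target width
width_mult = d / d0        # what self.width_mult() would return

# Default nn.Embedding init: entries ~ N(0, 1).
w = [random.gauss(0.0, 1.0) for _ in range(100_000)]

# What _rescale_parameters() effectively does: w *= width_mult ** 0.5.
w_rescaled = [x * width_mult ** 0.5 for x in w]

def std(xs):
    mean = sum(xs) / len(xs)
    return math.sqrt(sum((x - mean) ** 2 for x in xs) / len(xs))

print(round(std(w), 1))           # ~1.0: constant in width
print(round(std(w_rescaled), 1))  # ~2.0 = sqrt(d / d0): grows with width
```

This is exactly the discrepancy in question: without rescaling the tied weight's standard deviation stays at 1 regardless of width, while after rescaling it grows as `sqrt(d / d0)`.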