microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

_rescale_parameters() inconsistent with the paper for the tied embedding scenario? #55

Open ofivite opened 12 months ago

ofivite commented 12 months ago

Hi! I've been looking into integrating muP into the Megatron-LM setup, and I was wondering about the _rescale_parameters() method of MuReadout in the case of shared (tied) input/output embeddings. Specifically, in the Transformer example I'm not sure it is in line with the embedding initialisation suggested in the paper (i.e., constant, width-independent).

Currently, in the example, the output layer shares its weight with the input nn.Embedding via MuSharedReadout, and _rescale_parameters() is applied to that shared weight.

However, the muP paper suggests initialising the tied embeddings with a constant (width-independent) scale to be muP-compatible. It should also be mentioned that in the untied case the output embeddings are initialised to zero, so _rescale_parameters() has no effect and things stay consistent with the paper.
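For reference, a minimal sketch of the tied setup I'm referring to (names and sizes are illustrative, and the comments reflect my reading of the package rather than a verbatim quote of the example):

```python
import torch.nn as nn
from mup import MuSharedReadout

ntoken, ninp = 10000, 256  # illustrative sizes

# Input embedding: PyTorch's default init is N(0, 1), i.e. width-independent,
# which is what the paper prescribes for (tied) embeddings.
encoder = nn.Embedding(ntoken, ninp)

# Output readout sharing the same weight tensor as the input embedding.
decoder = MuSharedReadout(encoder.weight, bias=False)

# When base shapes are set with parameter rescaling enabled (as in the example),
# _rescale_parameters() is applied to this shared tensor, so the input
# embedding's init scale ends up depending on width.
```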

Below I also attach the coordinate check plots for the Transformer example for untied, tied+rescaling (current implementation) and tied+no rescaling (_rescale_parameters() disabled), respectively. One can see that for untied the norms are nicely flat, for tied+rescaling some layers have growing activations, and for tied+no rescaling one layer has a vanishing trend.

So I was wondering whether _rescale_parameters() should be disabled in the tied embedding scenario to keep the init constant, given that the shared weight inherits the N(0, 1) initialisation of nn.Embedding()?

[Coordinate check plots: μp_trsfmr_adam_coord (untied), μp_trsfmr_adam_coord_tied (tied + rescaling), μp_trsfmr_adam_coord_tied_fix (tied + no rescaling)]

edwardjhu commented 11 months ago

Thanks for pointing this out! Your analysis seems correct.

A simple fix is to add self._has_rescaled_params = True to the constructor of MuSharedReadout, so we don't trigger rescaling. I'll do that after making sure it doesn't have unintended consequences.
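In code, the idea is roughly the following sketch, written as a small subclass so it can be tried without editing the package; treat it as an illustration of the fix rather than the final patch (the exact constructor signature and how the flag is consumed may differ between versions):

```python
from mup import MuSharedReadout


class NoRescaleSharedReadout(MuSharedReadout):
    """MuSharedReadout that opts out of rescaling the tied embedding weight."""

    def __init__(self, weight, bias=True, **kwargs):
        super().__init__(weight, bias=bias, **kwargs)
        # Mark the parameters as already rescaled so that the rescaling step
        # during base-shape setup leaves the shared N(0, 1) embedding init alone.
        self._has_rescaled_params = True
```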

The vanishing preactivation in the last row should be the final logits. The GP behavior at init follows the CLT, whereas the μP scaling accounts for the LLN, so the randomly initialized logits shrink as width grows. A way to get rid of it is to initialize the shared embedding layer to zero, which is okay as long as the embedded input is not always zero (e.g., through a non-zero positional embedding). You should get flat curves that way.
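For example, a minimal sketch of the zero-init workaround, assuming the model adds a non-zero positional embedding to the token embedding (names and sizes are illustrative):

```python
import torch.nn as nn

ntoken, ninp, max_len = 10000, 256, 512  # illustrative sizes

tok_emb = nn.Embedding(ntoken, ninp)   # weight shared with the readout layer
pos_emb = nn.Embedding(max_len, ninp)  # kept at its default non-zero init

# Zero-init the shared token-embedding / readout weight as suggested above;
# the encoder input stays non-zero because the positional embedding is added.
nn.init.zeros_(tok_emb.weight)
```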

ofivite commented 11 months ago

Thank you @edwardjhu for your answer and suggestions! Initialising shared embeddings to zero is a good idea, I will try that out in the Megatron setup and see if the curves look flat there :)