microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

Multiple nn.Linear layers #4

Closed windspirit95 closed 2 years ago

windspirit95 commented 2 years ago

Hi, your project is really interesting, so I am learning how to apply it to some specific models. For example, when a model has multiple nn.Linear layers, as in wav2vec 2.0 (self.post_extract_proj, self.project_q, self.project_inp, self.target_glu, self.final_proj), should I replace all of these layers with MuReadout?

class Wav2Vec2Model(BaseFairseqModel):
    def __init__(self, cfg: Wav2Vec2Config):
        super().__init__()
        self.cfg = cfg

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        self.n_negatives = cfg.num_negatives
        self.cross_sample_negatives = cfg.cross_sample_negatives
        self.codebook_negatives = cfg.codebook_negatives
        self.negatives_from_everywhere = cfg.negatives_from_everywhere

        self.logit_temp = cfg.logit_temp

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        if cfg.quantize_targets:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
                weight_proj_depth=cfg.quantizer_depth,
                weight_proj_factor=cfg.quantizer_factor,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if cfg.quantize_input:
            if cfg.same_quantizer and self.quantizer is not None:
                vq_dim = final_dim
                self.input_quantizer = self.quantizer
            else:
                vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=cfg.latent_vars,
                    temp=cfg.latent_temp,
                    groups=cfg.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                    weight_proj_depth=cfg.quantizer_depth,
                    weight_proj_factor=cfg.quantizer_factor,
                )
            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

Thank you! ^^

thegregyang commented 2 years ago

Hi @windspirit95, thanks for your interest!

As a general rule of thumb, you only need to replace the nn.Linear layers that go from the model's hidden representation space to the "label" space. These are typically the last layers you apply before applying the loss function, and usually this is easy to tell from the model's forward function. Here, however, from a cursory read of the model's __init__, it seems that the nn.Linear layers that map from *_embed_dim to final_dim are the ones that need to be replaced by mup.MuReadout: self.final_proj and self.project_q. That said, there are a few lines that seem to mix the hidden dimension with the output dimension, e.g.

final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

and it's not clear to me why this is done.

windspirit95 commented 2 years ago

Hi @thegregyang, thanks for your reply. I took the example from https://github.com/pytorch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py, starting at line 294. This model looks a bit like the transformer example, doesn't it? But it is quite difficult to apply muP to this model :)

thegregyang commented 2 years ago

Looking at the source code, it seems that my initial assessment is correct. You just need to replace self.final_proj and self.project_q. You can double-check this by running a coord check and ensuring that the curves are flat as width increases.

thegregyang commented 2 years ago

I'll close this issue for now since it seems like we have arrived at an answer. But feel free to re-open if there are further problems.