allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

Re-initialize some layers of a PretrainedTransformerEmbedder #5491

Closed JohnGiorgi closed 2 years ago

JohnGiorgi commented 2 years ago

Is your feature request related to a problem? Please describe.

The paper Revisiting Few-sample BERT Fine-tuning (published at ICLR 2021) demonstrated that re-initializing the last few layers of a pretrained transformer before fine-tuning can reduce the variance between re-runs, speed up convergence, and improve final task performance, as nicely summarized in their figures:

(Figures from the paper showing reduced variance across runs and faster convergence when the final layers are re-initialized.)

The intuition is that some of the final layers may be over-specialized to the pretraining objective(s), so the pretrained weights can provide a bad initialization for downstream tasks.

It would be nice if re-initializing the weights of certain layers of a pretrained transformer model were easy to do with AllenNLP.

Describe the solution you'd like

Ideally, you could easily specify which layers to re-initialize in a PretrainedTransformerEmbedder, something like:

from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Re-initialize the last 2 layers
embedder = PretrainedTransformerEmbedder(model_name="bert-base-uncased", reinit_layers=2)
# AND/OR, provide your own layer indices
embedder = PretrainedTransformerEmbedder(model_name="bert-base-uncased", reinit_layers=[10, 11])

The __init__ of PretrainedTransformerEmbedder would take care of correctly re-initializing the specified layers for the given model_name.

Describe alternatives you've considered

You could achieve this right now with the AllenNLP initializers, but this would require:

  1. Writing regexes to target each layer, which gets messy if you want to initialize some weights differently than others (like the weights/biases of LayerNorm vs. the feed-forward layers).
  2. Knowing how the model was initialized in the first place. E.g., BERT initializes parameters from a truncated normal distribution with mean=0 and std=0.02. Ideally, the user wouldn't have to know/specify this. (Both points are illustrated in the sketch below.)
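
For concreteness, here is a rough sketch (plain PyTorch/HuggingFace, not the AllenNLP initializer API itself) of what this alternative boils down to for the last two layers of bert-base-uncased. The regex and the init scheme are both things the user would have to get right themselves:

import re

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")

# Point 1: the regexes get messy, because each parameter family needs its own rule.
last_two_layers = re.compile(r"encoder\.layer\.(10|11)\.")
for name, param in model.named_parameters():
    if not last_two_layers.search(name):
        continue
    if "LayerNorm.weight" in name:
        torch.nn.init.ones_(param)
    elif "bias" in name:
        torch.nn.init.zeros_(param)
    else:
        # Point 2: you have to know BERT's own scheme, a truncated normal with
        # mean=0.0 and std=0.02 (the initializer_range in the config).
        torch.nn.init.trunc_normal_(param, mean=0.0, std=0.02)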

Additional context

I've drafted a solution that works (but requires more testing). Essentially, we add a new parameter to PretrainedTransformerEmbedder, reinit_layers, which can be an integer or list of integers. In __init__, we re-initialize as follows:

self._reinit_layers = reinit_layers
if self._reinit_layers and load_weights:
    num_layers = len(self.transformer_model.encoder.layer)
    # An integer n means "re-initialize the last n layers".
    if isinstance(reinit_layers, int):
        self._reinit_layers = list(range(num_layers - self._reinit_layers, num_layers))
    # Valid layer indices run from 0 to num_layers - 1.
    if any(layer_idx >= num_layers for layer_idx in self._reinit_layers):
        raise ValueError(
            f"A layer index in reinit_layers ({self._reinit_layers}) is larger than the"
            f" maximum layer index {num_layers - 1}."
        )
    for layer_idx in self._reinit_layers:
        # Re-use the HuggingFace model's own _init_weights so each module is
        # re-initialized exactly as it would be for a freshly created model.
        self.transformer_model.encoder.layer[layer_idx].apply(
            self.transformer_model._init_weights
        )

I sanity-checked it by testing that the weights of the specified layers are indeed re-initialized. I also trained a model with re-initialized layers on my own task and got a non-negligible performance boost.
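
Roughly, that sanity check looks like the following sketch (assuming the patched embedder with the proposed reinit_layers parameter; the reference weights come straight from HuggingFace):

import torch
from transformers import AutoModel
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Patched embedder that re-initializes layers 10 and 11 (proposed reinit_layers parameter).
embedder = PretrainedTransformerEmbedder("bert-base-uncased", reinit_layers=[10, 11])
# An untouched pretrained copy to compare against.
reference = AutoModel.from_pretrained("bert-base-uncased")

for idx, (patched, original) in enumerate(
    zip(embedder.transformer_model.encoder.layer, reference.encoder.layer)
):
    identical = all(
        torch.equal(p, o) for p, o in zip(patched.parameters(), original.parameters())
    )
    # The re-initialized layers should differ from the pretrained weights; all others should still match.
    assert identical != (idx in {10, 11})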

If the AllenNLP maintainers think this would be a good addition I would be happy to open a PR!

epwalsh commented 2 years ago

Hey @JohnGiorgi, I do think this would be a good addition. Feel free to ping me when you start the PR!

github-actions[bot] commented 2 years ago

This issue is being closed due to lack of activity. If you think it still needs to be addressed, please comment on this thread 👇

JohnGiorgi commented 2 years ago

Oops, still working on #5505 so I think it makes sense to keep this open!

epwalsh commented 2 years ago

Unfortunately there's no easy way to check via the GitHub API whether an issue has an open linked pull request, which should be a sufficient condition to keep the issue open 😕