flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/

SequenceTagger.relearn_embeddings meaning (embedding2nn layer) #1433

Closed: falcaopetri closed this issue 4 years ago

falcaopetri commented 4 years ago

The flag SequenceTagger.relearn_embeddings is always set to True and is used to add a Linear layer called embedding2nn: https://github.com/flairNLP/flair/blob/4ce32c774b4dc5a8bfc0559441a1c8da4f06131d/flair/models/sequence_tagger_model.py#L154-L157

The only info I found about it is at the comment https://github.com/flairNLP/flair/commit/0056b2613ee9c169cb9c23e5e84fbcca180dde77#r31462803, where @alanakbik said:

Hi, this is a map on top of the embedding layer. This is to fine-tune traditional word embeddings: our implementation does not initialize a one-hot embedding layer, but rather retrieves word embeddings from gensim KeyedVectors, so the only way to fine-tune them before they are passed to other layers is to do this remapping.

I'm not sure I understand its purpose. Anyway, what if I do not want to fine-tune the pre-trained embeddings?

To contextualize: I stumbled upon this because I noticed that the embedding2nn layer adds a large number of trainable parameters (on the order of embedding_length * embedding_length).
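For a sense of scale, here is a back-of-the-envelope sketch (the embedding length is hypothetical, not taken from any particular Flair embedding):

import torch

# A square Linear layer over the stacked embedding adds
# embedding_length ** 2 weights plus embedding_length biases.
embedding_length = 4096  # hypothetical
embedding2nn = torch.nn.Linear(embedding_length, embedding_length)
print(sum(p.numel() for p in embedding2nn.parameters()))  # 16,781,312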

I have two use cases:

Do I really need to fine-tune the ELMoEmbeddings?

alanakbik commented 4 years ago

@falcaopetri it's possible to disable the reprojection in the current implementation, like this:

# create tagger with your normal parameters
tagger = SequenceTagger([.....])

# then set relearn embeddings to False
tagger.relearn_embeddings = False

But we should probably do a PR to add this parameter to the constructor. We haven't tested leaving it out much, so we would be interested to hear whether it makes a difference in your case.

falcaopetri commented 4 years ago

Thanks for the interim workaround, @alanakbik. Unfortunately, I might not manage to test it (I'm working on a project that will end soon).

My first thought was that the way the embedding2nn layer is used probably makes it harder to investigate the number of trainable parameters in a given model (see #1302): even if I set tagger.relearn_embeddings = False, embedding2nn's parameters will still count as trainable parameters.
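For example (a sketch continuing from the snippet in the previous comment; embedding2nn is the attribute from the source linked above):

# Flipping the flag only skips the forward pass through embedding2nn;
# the layer and its parameters still exist on the model, so naive
# counts of trainable parameters stay inflated.
tagger.relearn_embeddings = False

total = sum(p.numel() for p in tagger.parameters() if p.requires_grad)
reproj = sum(p.numel() for p in tagger.embedding2nn.parameters())
print(total, reproj)  # reproj is still included in total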

Anyway, as I understand it, this reprojection layer fine-tunes whatever embeddings are being used. For embeddings that are already trainable (e.g., OneHotEmbeddings), this seems unnecessary.
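Concretely, my mental model of the reprojection is something like this (a sketch with made-up sizes, not Flair's actual code):

import torch

# The pretrained vectors stay frozen; the trainable linear remap on top
# of them is the only part of the embedding that gets fine-tuned.
pretrained = torch.nn.Embedding.from_pretrained(
    torch.randn(10_000, 300), freeze=True  # stands in for gensim vectors
)
remap = torch.nn.Linear(300, 300)  # plays the role of embedding2nn

token_ids = torch.tensor([[1, 42, 7]])
embedded = remap(pretrained(token_ids))  # gradients flow only into remap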

For embeddings like ELMo (and maybe Flair?), which are feature-based, this probably improves the metrics, but it seems counterproductive, since their premise is to train the embeddings once and reuse them directly in a downstream model (am I missing something?).

I think ELMo's intended usage in this case would be to learn the weights to combine its layers for a given downstream task (#1264). allenai/allennlp#2904 discusses fine-tuning an ELMo model, but, as I expected, it is about unsupervised fine-tuning on a new, specific corpus; the embeddings are then used in, e.g., a classification model without being retrained.
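(For reference, the layer combination I mean is roughly the scalar mix from the ELMo paper; this is a sketch of the idea, not Flair's or AllenNLP's code:)

import torch

# Softmax-normalized weights over the frozen layers plus a global scale;
# only these few parameters are trained for the downstream task.
class ScalarMix(torch.nn.Module):
    def __init__(self, num_layers: int = 3):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.zeros(num_layers))
        self.gamma = torch.nn.Parameter(torch.ones(()))

    def forward(self, layers: torch.Tensor) -> torch.Tensor:
        # layers: (num_layers, seq_len, dim) of frozen ELMo activations
        w = torch.softmax(self.weights, dim=0)
        return self.gamma * torch.einsum("l,lsd->sd", w, layers)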

BERT's approach, on the other hand, is to be fine-tuned. As I understand from https://github.com/flairNLP/flair/issues/1302#issuecomment-570911062, BERTEmbeddings are not being trained. I'm not sure why it was implemented like that here in Flair (a problem with using an external library's implementation?). I guess this reprojection layer is doing a kind of fine-tuning, but it's probably far from fine-tuning the whole BERT model (am I missing something again?).

(just to inform you, I'm no expert in any of this :P )

stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

alanakbik commented 4 years ago

We recently merged a PR that lets you disable the reprojection layer by setting reproject_embeddings=False:

tagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,          # your (stacked) embeddings
    tag_dictionary=tag_dictionary,  # label dictionary built from the corpus
    tag_type=label_type,
    reproject_embeddings=False,     # do not add the embedding2nn layer
)
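With reproject_embeddings=False, the reprojection layer is not created at all, so its embedding_length * embedding_length parameters should no longer appear among the model's trainable parameters.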