flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/

[Bug]: Shared layers in multi-task model are no longer shared after loading the model from a checkpoint #3446

Open chelseagzr opened 2 months ago

chelseagzr commented 2 months ago

Describe the bug

Thank you for developing and maintaining this invaluable module!

We would like to train a multi-task model on two NER tasks that share a single transformer word embedding. We fine-tuned the model for several epochs and saved a checkpoint every epoch by passing save_model_each_k_epochs=1 to fine_tune. Now, suppose we want to continue fine-tuning from one of these saved checkpoints. We loaded the model from a checkpoint with MultitaskModel.load. However, after loading, the transformer word embedding is no longer shared between the two tasks: each task gets its own copy.
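The loss of sharing can be reproduced without flair. The following framework-free sketch assumes, as the report suggests, that the checkpoint stores each task's state separately and that loading rebuilds every task from its own state. `Embedding`, `Tagger`, `save_per_task` and `load_per_task` are hypothetical stand-ins, not flair classes or functions.

```python
class Embedding:
    """Hypothetical stand-in for a shared embedding layer (not a flair class)."""

    def __init__(self, weights):
        self.weights = list(weights)


class Tagger:
    """Hypothetical stand-in for a task head holding a reference to an embedding."""

    def __init__(self, embeddings, head):
        self.embeddings = embeddings
        self.head = head


def save_per_task(tasks):
    # Assumption: one independent state per task, so the shared
    # embedding's values are written once for every task.
    return {
        task_id: {"embeddings": list(t.embeddings.weights), "head": t.head}
        for task_id, t in tasks.items()
    }


def load_per_task(states):
    # Each task is rebuilt from its own state, so each gets a new Embedding.
    return {
        task_id: Tagger(Embedding(state["embeddings"]), state["head"])
        for task_id, state in states.items()
    }


shared = Embedding([0.1, 0.2])
tasks = {"Task_0": Tagger(shared, "head_0"), "Task_1": Tagger(shared, "head_1")}
restored = load_per_task(save_per_task(tasks))

print(tasks["Task_0"].embeddings is tasks["Task_1"].embeddings)  # True
print(restored["Task_0"].embeddings is restored["Task_1"].embeddings)  # False
print(restored["Task_0"].embeddings.weights == restored["Task_1"].embeddings.weights)  # True
```

Equal values but distinct objects is the state described in the report: once training updates one copy, the other no longer follows.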

To Reproduce

%pip install scipy==1.10.1 transformers torch==2.0 flair==0.13.1 

from flair.datasets import NER_CHINESE_WEIBO, NER_ENGLISH_PERSON
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger, MultitaskModel 
from flair.trainers import ModelTrainer
from flair.nn.multitask import make_multitask_model_and_corpus

# 1. get the corpus
corpus_1 = NER_CHINESE_WEIBO()
print(corpus_1)
corpus_2 = NER_ENGLISH_PERSON()
print(corpus_2)

# 2. what label do we want to predict?
label_type = 'ner'

# 3. make the label dictionary from the corpus
label_dict_1 = corpus_1.make_label_dictionary(label_type=label_type, add_unk=False)
print(label_dict_1)
label_dict_2 = corpus_2.make_label_dictionary(label_type=label_type, add_unk=False)
print(label_dict_2)

# 4. initialize fine-tuneable transformer embeddings WITH document context
shared_embeddings = TransformerWordEmbeddings(
    model='xlm-roberta-base',
    layers="-1",
    subtoken_pooling="first",
    fine_tune=True,
    use_context=True,
)

# 5. initialize bare-bones sequence tagger (no CRF, no RNN, no reprojection)
tagger_1 = SequenceTagger(
    hidden_size=256,
    embeddings=shared_embeddings,
    tag_dictionary=label_dict_1,
    tag_type=label_type,
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
)
tagger_2 = SequenceTagger(
    hidden_size=256,
    embeddings=shared_embeddings,
    tag_dictionary=label_dict_2,
    tag_type=label_type,
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
)

# 6. initialize trainer
multitask_model, multicorpus = make_multitask_model_and_corpus(
    [
        (tagger_1, corpus_1),
        (tagger_2, corpus_2),
    ]
)
# tagger_1 and tagger_2 now share the embedding layer (a single copy)
trainer = ModelTrainer(multitask_model, multicorpus)

# 7. run fine-tuning
trainer.fine_tune('resources/taggers/sota-ner-flert',
                  learning_rate=5.0e-6,
                  max_epochs=1,
                  mini_batch_size=4,
                  save_model_each_k_epochs=1
)

# 8. load from saved checkpoint
multitask_model = MultitaskModel.load('resources/taggers/sota-ner-flert/model_epoch_1.pt')
# The embedding layers of tagger_1 and tagger_2 are NOT shared anymore (there are now two copies).
# The two copies hold the same values right after loading, but they will diverge if we continue fine-tuning.

# 9. continue fine-tuning
trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune('resources/taggers/sota-ner-flert',
                  learning_rate=5.0e-6,
                  epoch=1,
                  max_epochs=2,
                  mini_batch_size=4,
                  save_model_each_k_epochs=1
)

Expected behavior

Shared layers between tasks are still shared after loading from a checkpoint.
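One way to check this expectation programmatically is to group tasks by the identity (not the values) of their embedding object. A minimal sketch, assuming `multitask_model.tasks` is a dict keyed by task id (the report indexes it as `multitask_model.tasks['Task_1']`); `shared_groups` is a hypothetical helper, not part of flair.

```python
from types import SimpleNamespace


def shared_groups(tasks, attr="embeddings"):
    """Group task ids by the identity of getattr(task, attr)."""
    groups = {}
    for task_id, model in tasks.items():
        # id() is safe here because every object stays referenced by `tasks`.
        groups.setdefault(id(getattr(model, attr)), []).append(task_id)
    return list(groups.values())


# Stand-in tasks; with flair one would presumably pass multitask_model.tasks.
emb = object()
shared_case = {"Task_0": SimpleNamespace(embeddings=emb), "Task_1": SimpleNamespace(embeddings=emb)}
split_case = {"Task_0": SimpleNamespace(embeddings=object()), "Task_1": SimpleNamespace(embeddings=object())}
print(shared_groups(shared_case))  # [['Task_0', 'Task_1']]
print(shared_groups(split_case))  # [['Task_0'], ['Task_1']]
```

If sharing survived the load, the result for the loaded model would be a single group containing every task id.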

Environment

Versions:

Flair: 0.13.1
PyTorch: 2.0.0+cu117
Transformers: 4.40.0
GPU: True

chelseagzr commented 1 month ago

For this specific example, I think either of the following two methods works. (Please let me know if you see any problem with them.) I was wondering whether this bug could be fixed inside MultitaskModel.load so that it works for any multitask model. Thank you!

Method 1: assign the embedding layers of one task to the other tasks

# 8. load from saved checkpoint
multitask_model = MultitaskModel.load('resources/taggers/sota-ner-flert/model_epoch_1.pt')
# The embedding layers of tagger_1 and tagger_2 are NOT shared anymore (there are now two copies).
# The two copies hold the same values right after loading, but they will diverge if we continue fine-tuning.

# 9. make Task_1 reuse the embedding layer of Task_0
multitask_model.tasks['Task_1'].embeddings = multitask_model.tasks['Task_0'].embeddings

# 10. continue fine-tuning
trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune('resources/taggers/sota-ner-flert',
                  learning_rate=5.0e-6,
                  epoch=1,
                  max_epochs=2,
                  mini_batch_size=4,
                  save_model_each_k_epochs=1
)
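Method 1 can be generalized to any number of tasks by pointing every task at the layer held by one source task. A minimal sketch; `retie_attribute` is a hypothetical helper, not part of flair. It assumes the per-task copies are identical when it runs (true right after loading, as noted in step 8) and that it is called before fine_tune, so the optimizer is built over the shared parameters.

```python
from types import SimpleNamespace


def retie_attribute(tasks, attr="embeddings", source=None):
    """Make every task's `attr` point at the object held by the `source` task.

    Only safe when the per-task copies are identical, e.g. right after loading
    a checkpoint that was saved while the layer was still shared.
    """
    if not tasks:
        raise ValueError("no tasks given")
    source = next(iter(tasks)) if source is None else source
    shared = getattr(tasks[source], attr)
    for model in tasks.values():
        # For torch.nn.Module objects, assigning a Module attribute registers
        # it as a submodule, replacing the task's private copy.
        setattr(model, attr, shared)
    return shared


# Stand-in tasks; with flair this would be retie_attribute(multitask_model.tasks).
tasks = {
    "Task_0": SimpleNamespace(embeddings="copy_0"),
    "Task_1": SimpleNamespace(embeddings="copy_1"),
    "Task_2": SimpleNamespace(embeddings="copy_2"),
}
retie_attribute(tasks)
print([t.embeddings for t in tasks.values()])  # ['copy_0', 'copy_0', 'copy_0']
```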

Method 2: create each component the same way it was created initially, and load the state dicts separately.

import torch
import flair
loaded_model = torch.load('resources/taggers/sota-ner-flert/model_epoch_1.pt', map_location=flair.device)
model_states = loaded_model["model_states"]

# 8. create and load shared embedding
shared_embeddings = TransformerWordEmbeddings(
    model='xlm-roberta-base',
    layers="-1",
    subtoken_pooling="first",
    fine_tune=True,
    use_context=True,
)
embedding_state_dict = {}
prefix_of_embedding_layers = "embeddings."
for key in model_states["Task_0"]["state_dict"]:
    if key.startswith(prefix_of_embedding_layers):
        new_key = key[len(prefix_of_embedding_layers):]  # drop the prefix so the keys match the embedding module's own state dict
        embedding_state_dict[new_key] = model_states["Task_0"]["state_dict"][key]
shared_embeddings.load_state_dict(state_dict=embedding_state_dict, strict=True)

# 9. create and load SequenceTagger for "Task_0" (excluding shared embedding)
tagger_1 = SequenceTagger(
    hidden_size=256,
    embeddings=shared_embeddings,
    tag_dictionary=label_dict_1,
    tag_type=label_type,
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
)
state_dict = {}
for key in model_states["Task_0"]["state_dict"]:
    if not key.startswith(prefix_of_embedding_layers):
        state_dict[key] = model_states["Task_0"]["state_dict"][key]
tagger_1.load_state_dict(state_dict=state_dict, strict=False)

# 10. create and load SequenceTagger for "Task_1" (excluding shared embedding)
tagger_2 = SequenceTagger(
    hidden_size=256,
    embeddings=shared_embeddings,
    tag_dictionary=label_dict_2,
    tag_type=label_type,
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
)
state_dict = {}
for key in model_states["Task_1"]["state_dict"]:
    if not key.startswith(prefix_of_embedding_layers):
        state_dict[key] = model_states["Task_1"]["state_dict"][key]
tagger_2.load_state_dict(state_dict=state_dict, strict=False)

# 11. initialize trainer
multitask_model, multicorpus = make_multitask_model_and_corpus(
    [
        (tagger_1, corpus_1),
        (tagger_2, corpus_2),
    ]
)
# tagger_1 and tagger_2 now share the embedding layer (a single copy)
trainer = ModelTrainer(multitask_model, multicorpus)

# 12. run fine-tuning
trainer.fine_tune('resources/taggers/sota-ner-flert',
                  learning_rate=5.0e-6,
                  epoch=1,
                  max_epochs=2,
                  mini_batch_size=4,
                  save_model_each_k_epochs=1
)
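The two prefix loops in Method 2 can be folded into one helper that returns both halves of a task's state dict in a single pass. A minimal sketch; `split_by_prefix` is a hypothetical name, and the keys in the toy example are made up rather than real flair parameter names.

```python
def split_by_prefix(state_dict, prefix="embeddings."):
    """Return (prefixed entries with the prefix removed, all other entries)."""
    matched, rest = {}, {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            matched[key[len(prefix):]] = value
        else:
            rest[key] = value
    return matched, rest


# Toy state dict with placeholder values instead of tensors.
toy = {"embeddings.model.weight": 1, "linear.weight": 2, "linear.bias": 3}
emb_sd, head_sd = split_by_prefix(toy)
print(emb_sd)  # {'model.weight': 1}
print(head_sd)  # {'linear.weight': 2, 'linear.bias': 3}
```

Applied to the checkpoint, `split_by_prefix(model_states["Task_0"]["state_dict"])` would give the dict for `shared_embeddings.load_state_dict(..., strict=True)` and the dict for `tagger_1.load_state_dict(..., strict=False)`; `strict=False` is still needed there because the tagger's own embedding keys are deliberately left out.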