flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/

How do I get flair to recognize my custom embeddings class during model loading? #3452

Closed DonaldFeuz closed 4 weeks ago

DonaldFeuz commented 1 month ago

Question

Hi everyone. I have created a custom embeddings class for training my model with flair. The training itself runs fine, but at the end the model fails to load, raising the same error every time.

My code:

```python
from flair.embeddings import TokenEmbeddings, TransformerWordEmbeddings
from flair.data import Sentence
from torch import nn
from typing import List


class BioBERTWithReprojection(TokenEmbeddings):
    embeddings_name = 'biobert_with_reprojection'

    def __init__(self, biobert_model_name: str, reprojection_size: int):
        super().__init__()

        self.name = 'biobert_with_reprojection'
        self.static_embeddings = False

        # Initialize the BioBERT embeddings
        self.biobert_embeddings = TransformerWordEmbeddings(
            model=biobert_model_name,
            layers="-1",
            subtoken_pooling="first",
            fine_tune=True,
            use_context=True
        )

        # Create a reprojection layer to adjust the embedding dimensions
        self.reprojection_layer = nn.Linear(self.biobert_embeddings.embedding_length, reprojection_size)
        self._embedding_length = reprojection_size

    @property
    def embedding_length(self) -> int:
        return self._embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        # First apply the BioBERT embeddings
        self.biobert_embeddings._add_embeddings_internal(sentences)

        # Then apply the reprojection layer to every token in every sentence
        for sentence in sentences:
            for token in sentence:
                bio_embedding = token.get_embedding()
                reprojected_embedding = self.reprojection_layer(bio_embedding.unsqueeze(0))
                token.set_embedding(self.name, reprojected_embedding.squeeze(0))

    def to_params(self):
        # Store the parameters needed to rebuild this embeddings object
        return {
            'biobert_model_name': self.biobert_embeddings.model,
            'reprojection_size': self.reprojection_layer.out_features
        }

    @classmethod
    def from_params(cls, params):
        # Rebuild the embeddings object from the stored parameters
        return cls(params['biobert_model_name'], params['reprojection_size'])

    def __str__(self):
        return self.name

    def extra_repr(self):
        return f"biobert_model_name={self.biobert_embeddings.model}, reprojection_size={self.reprojection_layer.out_features}"
```

The error:

```
2024-05-04 17:41:19,455 ----------------------------------------------------------------------------------------------------
2024-05-04 17:41:19,462 Loading model from best epoch ...

KeyError                                  Traceback (most recent call last)
in <cell line: 3>()
      1 trainer = ModelTrainer(tagger, combined_corpus)
      2
----> 3 trainer.train('resources/taggers/model-hunflair-donald',
      4               learning_rate=0.1,
      5               mini_batch_size=32,

7 frames
/usr/local/lib/python3.10/dist-packages/flair/trainers/trainer.py in train(self, base_path, anneal_factor, patience, min_learning_rate, initial_extra_patience, anneal_with_restarts, learning_rate, decoder_learning_rate, mini_batch_size, eval_batch_size, mini_batch_chunk_size, max_epochs, optimizer, train_with_dev, train_with_test, reduce_transformer_vocab, main_evaluation_metric, monitor_test, monitor_train_sample, use_final_model_for_eval, gold_label_dictionary_for_eval, exclude_labels, sampler, shuffle, shuffle_first_epoch, embeddings_storage_mode, epoch, save_final_model, save_optimizer_state, save_model_each_k_epochs, create_file_logs, create_loss_file, write_weights, plugins, attach_default_scheduler, kwargs)
    198         ]:
    199             local_variables.pop(var)
--> 200         return self.train_custom(local_variables, **kwargs)
    201
    202     def fine_tune(

/usr/local/lib/python3.10/dist-packages/flair/trainers/trainer.py in train_custom(self, base_path, learning_rate, decoder_learning_rate, mini_batch_size, eval_batch_size, mini_batch_chunk_size, max_epochs, optimizer, train_with_dev, train_with_test, max_grad_norm, reduce_transformer_vocab, main_evaluation_metric, monitor_test, monitor_train_sample, use_final_model_for_eval, gold_label_dictionary_for_eval, exclude_labels, sampler, shuffle, shuffle_first_epoch, embeddings_storage_mode, epoch, save_final_model, save_optimizer_state, save_model_each_k_epochs, create_file_logs, create_loss_file, write_weights, use_amp, plugins, **kwargs)
    779         if (base_path / "best-model.pt").exists():
    780             log.info("Loading model from best epoch ...")
--> 781             self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict())
    782         else:
    783             log.info("Testing using last state of model ...")

/usr/local/lib/python3.10/dist-packages/flair/models/sequence_tagger_model.py in load(cls, model_path)
   1034         from typing import cast
   1035
-> 1036         return cast("SequenceTagger", super().load(model_path=model_path))

/usr/local/lib/python3.10/dist-packages/flair/nn/model.py in load(cls, model_path)
    553         from typing import cast
    554
--> 555         return cast("Classifier", super().load(model_path=model_path))
    556
    557

/usr/local/lib/python3.10/dist-packages/flair/nn/model.py in load(cls, model_path)
    184         state.pop("cls")
    185
--> 186         model = cls._init_model_with_state_dict(state)
    187
    188         if "model_card" in state:

/usr/local/lib/python3.10/dist-packages/flair/models/sequence_tagger_model.py in _init_model_with_state_dict(cls, state, **kwargs)
    620             del state["state_dict"]["transitions"]
    621
--> 622         return super()._init_model_with_state_dict(
    623             state,
    624             embeddings=state.get("embeddings"),

/usr/local/lib/python3.10/dist-packages/flair/nn/model.py in _init_model_with_state_dict(cls, state, **kwargs)
     97             embeddings = kwargs.pop("embeddings")
     98             if isinstance(embeddings, dict):
---> 99                 embeddings = load_embeddings(embeddings)
    100             kwargs["embeddings"] = embeddings
    101

/usr/local/lib/python3.10/dist-packages/flair/embeddings/base.py in load_embeddings(params)
    230 def load_embeddings(params: Dict[str, Any]) -> Embeddings:
    231     cls_name = params.pop("cls")
--> 232     cls = EMBEDDING_CLASSES[cls_name]
    233     return cls.load_embedding(params)

KeyError: 'biobert_with_reprojection'
```

This problem has been haunting me for a long time; please help. Thank you.

helpmefindaname commented 1 month ago

Hi @DonaldFeuz

I see you already figured out the `from_params` and `to_params` methods. What is missing is the `@register_embeddings` decorator:

```python
from flair.embeddings.base import register_embeddings

@register_embeddings
class MyEmbeddings(TokenEmbeddings):
    ...
```

That way, the lookup will also consider this embeddings class when loading embeddings.
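
Applied to the class from the question, that would look roughly like this (a minimal sketch; only the extra import and the decorator line are new, the class body stays exactly as posted):

```python
from flair.embeddings import TokenEmbeddings
from flair.embeddings.base import register_embeddings

@register_embeddings  # adds the class to flair's EMBEDDING_CLASSES registry
class BioBERTWithReprojection(TokenEmbeddings):
    embeddings_name = 'biobert_with_reprojection'
    # ... class body unchanged from the question ...
```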

Keep in mind that the module that defines the embeddings class needs to be loaded beforehand, i.e. you need to import it. I usually use the `__init__.py` to do that:


```python
import my_project.my_embeddings

__all__ = ["my_embeddings"]
```

Then I can be sure that any file in `my_project` has the embeddings already registered and can load the model correctly.
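
To illustrate with the setup from the question (a hedged sketch; `my_project.my_embeddings` is an assumed module path, and the model path comes from the traceback above):

```python
import my_project.my_embeddings  # assumed module; importing it runs @register_embeddings

from flair.models import SequenceTagger

# Loading now resolves 'biobert_with_reprojection' in the embeddings registry
tagger = SequenceTagger.load('resources/taggers/model-hunflair-donald/best-model.pt')
```
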
DonaldFeuz commented 1 month ago

Thank you very much @helpmefindaname
My problem has been successfully solved thanks to your solution.