Closed DonaldFeuz closed 4 weeks ago
Hi @DonaldFeuz
I see you already figured out the `from_params` and `to_params` methods. What is missing is the `@register_embeddings` decorator.
from flair.embeddings.base import register_embeddings
@register_embeddings
class MyEmbeddings(TokenEmbeddings):
....
That way, the lookup will know to also consider this embeddings class when loading embeddings.
Keep in mind that the module that defines the embedding class needs to be loaded beforehand, i.e. you need to import it. I usually use the `__init__.py` to do that:
import my_project.my_embeddings
__all__ = ["my_embeddings"]
Then I can be sure that any file in `my_project` has the embeddings already registered and can load the model the correct way.
Thank you very much @helpmefindaname
My problem has been successfully solved thanks to your solution.
Question
Hi everyone. I have created a custom embeddings class for training my model with flair. When I run the training, it executes fine, but at the end the model fails to load, raising an error every time.
my code: `from flair.embeddings import TokenEmbeddings, TransformerWordEmbeddings from flair.data import Sentence from torch import nn from typing import List
class BioBERTWithReprojection(TokenEmbeddings): embeddings_name = 'biobert_with_reprojection' def init(self, biobert_model_name: str, reprojection_size: int): super().init()
the error:
2024-05-04 17:41:19,455 ---------------------------------------------------------------------------------------------------- 2024-05-04 17:41:19,462 Loading model from best epoch ...
KeyError Traceback (most recent call last) in <cell line: 3>()
1 trainer = ModelTrainer(tagger, combined_corpus)
2
----> 3 trainer.train('resources/taggers/model-hunflair-donald',
4 learning_rate=0.1,
5 mini_batch_size=32,
7 frames /usr/local/lib/python3.10/dist-packages/flair/trainers/trainer.py in train(self, base_path, anneal_factor, patience, min_learning_rate, initial_extra_patience, anneal_with_restarts, learning_rate, decoder_learning_rate, mini_batch_size, eval_batch_size, mini_batch_chunk_size, max_epochs, optimizer, train_with_dev, train_with_test, reduce_transformer_vocab, main_evaluation_metric, monitor_test, monitor_train_sample, use_final_model_for_eval, gold_label_dictionary_for_eval, exclude_labels, sampler, shuffle, shuffle_first_epoch, embeddings_storage_mode, epoch, save_final_model, save_optimizer_state, save_model_each_k_epochs, create_file_logs, create_loss_file, write_weights, plugins, attach_default_scheduler, kwargs) 198 ]: 199 local_variables.pop(var) --> 200 return self.train_custom(local_variables, **kwargs) 201 202 def fine_tune(
/usr/local/lib/python3.10/dist-packages/flair/trainers/trainer.py in train_custom(self, base_path, learning_rate, decoder_learning_rate, mini_batch_size, eval_batch_size, mini_batch_chunk_size, max_epochs, optimizer, train_with_dev, train_with_test, max_grad_norm, reduce_transformer_vocab, main_evaluation_metric, monitor_test, monitor_train_sample, use_final_model_for_eval, gold_label_dictionary_for_eval, exclude_labels, sampler, shuffle, shuffle_first_epoch, embeddings_storage_mode, epoch, save_final_model, save_optimizer_state, save_model_each_k_epochs, create_file_logs, create_loss_file, write_weights, use_amp, plugins, **kwargs) 779 if (base_path / "best-model.pt").exists(): 780 log.info("Loading model from best epoch ...") --> 781 self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) 782 else: 783 log.info("Testing using last state of model ...")
/usr/local/lib/python3.10/dist-packages/flair/models/sequence_tagger_model.py in load(cls, model_path) 1034 from typing import cast 1035 -> 1036 return cast("SequenceTagger", super().load(model_path=model_path))
/usr/local/lib/python3.10/dist-packages/flair/nn/model.py in load(cls, model_path) 553 from typing import cast 554 --> 555 return cast("Classifier", super().load(model_path=model_path)) 556 557
/usr/local/lib/python3.10/dist-packages/flair/nn/model.py in load(cls, model_path) 184 state.pop("cls") 185 --> 186 model = cls._init_model_with_state_dict(state) 187 188 if "model_card" in state:
/usr/local/lib/python3.10/dist-packages/flair/models/sequence_tagger_model.py in _init_model_with_state_dict(cls, state, **kwargs) 620 del state["state_dict"]["transitions"] 621 --> 622 return super()._init_model_with_state_dict( 623 state, 624 embeddings=state.get("embeddings"),
/usr/local/lib/python3.10/dist-packages/flair/nn/model.py in _init_model_with_state_dict(cls, state, **kwargs) 97 embeddings = kwargs.pop("embeddings") 98 if isinstance(embeddings, dict): ---> 99 embeddings = load_embeddings(embeddings) 100 kwargs["embeddings"] = embeddings 101
/usr/local/lib/python3.10/dist-packages/flair/embeddings/base.py in load_embeddings(params) 230 def load_embeddings(params: Dict[str, Any]) -> Embeddings: 231 cls_name = params.pop("cls") --> 232 cls = EMBEDDING_CLASSES[cls_name] 233 return cls.load_embedding(params)
KeyError: 'biobert_with_reprojection'
This problem has been haunting me for a long time; please help. Thank you.