XLM embeddings are now available in version 0.4.3 of Flair 🤗
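For anyone landing here, a minimal usage sketch; the class name XLMEmbeddings and the checkpoint name "xlm-mlm-en-2048" are assumptions based on that release, so check your Flair version:

from flair.data import Sentence
from flair.embeddings import XLMEmbeddings  # class name assumed from the 0.4.3 release

# "xlm-mlm-en-2048" is an example XLM checkpoint; other XLM checkpoints should work too
embedding = XLMEmbeddings("xlm-mlm-en-2048")

sentence = Sentence("Hello world !")
embedding.embed(sentence)

# each token now carries its XLM embedding
for token in sentence:
    print(token, len(token.embedding))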
Hi, could you add support for XLM-RoBERTa? I tried to implement the class myself by modifying the existing transformer embedding class, but it fails because the model cannot be saved with torch.save.
from typing import List

import flair
from flair.data import Sentence, Token
# _get_transformer_sentence_embeddings is a private helper defined in flair/embeddings.py;
# this class is meant to live next to the other transformer embeddings in that module
from flair.embeddings import TokenEmbeddings, _get_transformer_sentence_embeddings
from transformers import XLMRobertaModel, XLMRobertaTokenizer


class XLMRoBERTaEmbeddings(TokenEmbeddings):
    def __init__(
        self,
        pretrained_model_name_or_path: str = "xlm-roberta-large",
        layers: str = "-1",
        pooling_operation: str = "first",
        use_scalar_mix: bool = False,
    ):
        super().__init__()

        self.tokenizer = XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path)
        self.model = XLMRobertaModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            output_hidden_states=True,
        )
        self.name = pretrained_model_name_or_path
        self.layers: List[int] = [int(layer) for layer in layers.split(",")]
        self.pooling_operation = pooling_operation
        self.use_scalar_mix = use_scalar_mix
        self.static_embeddings = True

        # embed a dummy sentence once to determine the embedding length
        dummy_sentence: Sentence = Sentence()
        dummy_sentence.add_token(Token("hello"))
        embedded_dummy = self.embed(dummy_sentence)
        self.__embedding_length: int = len(
            embedded_dummy[0].get_token(1).get_embedding()
        )

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        self.model.to(flair.device)
        self.model.eval()

        sentences = _get_transformer_sentence_embeddings(
            sentences=sentences,
            tokenizer=self.tokenizer,
            model=self.model,
            name=self.name,
            layers=self.layers,
            pooling_operation=self.pooling_operation,
            use_scalar_mix=self.use_scalar_mix,
            bos_token="<s>",
            eos_token="</s>",
        )
        return sentences

    def extra_repr(self):
        return "model={}".format(self.name)

    def __str__(self):
        return self.name
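For context, here is how I call the class; a minimal sketch (the sentence text is just an example). Embedding itself works at runtime, only saving the model fails:

from flair.data import Sentence

embeddings = XLMRoBERTaEmbeddings(pretrained_model_name_or_path="xlm-roberta-large")

sentence = Sentence("Hello world !")
embeddings.embed(sentence)  # embedding works; only torch.save on the trained model fails

for token in sentence:
    print(token, len(token.embedding))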
The error:
2019-12-21 13:59:54,554 EPOCH 1 done: loss 80.1238 - lr 0.1000
2019-12-21 13:59:54,887 DEV : loss 13.243016242980957 - score 0.0
2019-12-21 13:59:54,888 BAD EPOCHS (no improvement): 0
Traceback (most recent call last):
File "train_with_teacher.py", line 108, in <module>
getattr(trainer,'train')(**train_config)
File "/home/wangxy/workspace/flair/flair/trainers/distillation_trainer.py", line 606, in train
self.model.save(base_path / "best-model.pt")
File "/home/wangxy/workspace/flair/flair/nn.py", line 67, in save
torch.save(model_state, str(model_file), pickle_protocol=4)
File "/home/wangxy/anaconda2/envs/parser/lib/python3.6/site-packages/torch/serialization.py", line 224, in save
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
File "/home/wangxy/anaconda2/envs/parser/lib/python3.6/site-packages/torch/serialization.py", line 149, in _with_file_like
return body(f)
File "/home/wangxy/anaconda2/envs/parser/lib/python3.6/site-packages/torch/serialization.py", line 224, in <lambda>
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
File "/home/wangxy/anaconda2/envs/parser/lib/python3.6/site-packages/torch/serialization.py", line 296, in _save
pickler.dump(obj)
TypeError: can't pickle SwigPyObject objects
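For what it's worth, the traceback points at the tokenizer rather than the model: torch.save pickles the whole embeddings object, and XLMRobertaTokenizer wraps a SentencePiece processor, which is a SWIG object that pickle cannot handle. A sketch that should reproduce the error in isolation (assuming transformers 2.2.x, where the tokenizer did not yet implement custom pickling):

import pickle
from transformers import XLMRobertaTokenizer

tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
pickle.dumps(tokenizer)  # raises: TypeError: can't pickle SwigPyObject objects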
@wangxinyu0922 Did you handle this problem? I ran into the same issue when I tried to customize Flair to accept a trained XLM-R embedding from fairseq: TypeError: can't pickle SwigPyObject objects. I can save the model with torch.save(model.state_dict()) but I don't think that is the best option.
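Saving only the state_dict sidesteps the problem because the tokenizer is never pickled. Another possible workaround (a sketch, not the official fix) is to exclude the tokenizer from pickling and rebuild it from the stored model name on load:

class XLMRoBERTaEmbeddings(TokenEmbeddings):
    # ... __init__ and the other methods as above ...

    def __getstate__(self):
        # drop the SentencePiece-backed tokenizer, which pickle cannot handle
        state = self.__dict__.copy()
        state["tokenizer"] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # rebuild the tokenizer from the model name stored in self.name
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(self.name)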
A fix and an update to Transformers 2.3 are coming soon 😅
Hi,
Recently, cross-lingual language model pretraining was proposed and released in the XLM library:
https://github.com/facebookresearch/XLM
pytorch-pretrained-BERT will implement this XLM model (many thanks to @thomwolf) 😍 Whenever this implementation is fully merged into pytorch-pretrained-BERT, we can add these new XLMEmbeddings into Flair. Meanwhile, it is really worth having a look at the XLM paper here 😄