flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/

Add support for XLM #840

Closed stefan-it closed 5 years ago

stefan-it commented 5 years ago

Hi,

recently, cross-lingual language model pretraining was proposed and released in the XLM library:

https://github.com/facebookresearch/XLM

pytorch-pretrained-BERT will implement this XLM model (many thanks to @thomwolf) 😍

Whenever this implementation is fully merged into pytorch-pretrained-BERT, we can add these new XLMEmbeddings into flair.

Meanwhile, the XLM paper is really worth a look 😄

stefan-it commented 5 years ago

XLM embeddings are now available in version 0.4.3 of Flair 🤗

wangxinyu0922 commented 4 years ago

Hi, could you add support for XLMRoberta? I tried to implement the class myself by modifying an existing one, but it failed because the model cannot be saved with torch.save.

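from typing import List

import flair
from flair.data import Sentence, Token
# Assumes the flair 0.4.x layout, where this helper lives in flair.embeddings.
from flair.embeddings import TokenEmbeddings, _get_transformer_sentence_embeddings
from transformers import XLMRobertaModel, XLMRobertaTokenizer


class XLMRoBERTaEmbeddings(TokenEmbeddings):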
    def __init__(
        self,
        pretrained_model_name_or_path: str = "xlm-roberta-large",
        layers: str = "-1",
        pooling_operation: str = "first",
        use_scalar_mix: bool = False,
    ):

        super().__init__()
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path)
        self.model = XLMRobertaModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            output_hidden_states=True,
        )
        self.name = pretrained_model_name_or_path
        self.layers: List[int] = [int(layer) for layer in layers.split(",")]
        self.pooling_operation = pooling_operation
        self.use_scalar_mix = use_scalar_mix
        self.static_embeddings = True

        dummy_sentence: Sentence = Sentence()
        dummy_sentence.add_token(Token("hello"))
        embedded_dummy = self.embed(dummy_sentence)
        self.__embedding_length: int = len(
            embedded_dummy[0].get_token(1).get_embedding()
        )

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        self.model.to(flair.device)
        self.model.eval()

        sentences = _get_transformer_sentence_embeddings(
            sentences=sentences,
            tokenizer=self.tokenizer,
            model=self.model,
            name=self.name,
            layers=self.layers,
            pooling_operation=self.pooling_operation,
            use_scalar_mix=self.use_scalar_mix,
            bos_token="<s>",
            eos_token="</s>",
        )

        return sentences

    def extra_repr(self):
        return "model={}".format(self.name)

    def __str__(self):
        return self.name

The error:

2019-12-21 13:59:54,554 EPOCH 1 done: loss 80.1238 - lr 0.1000
2019-12-21 13:59:54,887 DEV : loss 13.243016242980957 - score 0.0
2019-12-21 13:59:54,888 BAD EPOCHS (no improvement): 0
Traceback (most recent call last):
  File "train_with_teacher.py", line 108, in <module>
    getattr(trainer,'train')(**train_config)
  File "/home/wangxy/workspace/flair/flair/trainers/distillation_trainer.py", line 606, in train
    self.model.save(base_path / "best-model.pt")
  File "/home/wangxy/workspace/flair/flair/nn.py", line 67, in save
    torch.save(model_state, str(model_file), pickle_protocol=4)
  File "/home/wangxy/anaconda2/envs/parser/lib/python3.6/site-packages/torch/serialization.py", line 224, in save
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/home/wangxy/anaconda2/envs/parser/lib/python3.6/site-packages/torch/serialization.py", line 149, in _with_file_like
    return body(f)
  File "/home/wangxy/anaconda2/envs/parser/lib/python3.6/site-packages/torch/serialization.py", line 224, in <lambda>
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/home/wangxy/anaconda2/envs/parser/lib/python3.6/site-packages/torch/serialization.py", line 296, in _save
    pickler.dump(obj)
TypeError: can't pickle SwigPyObject objects

nguyenvulebinh commented 4 years ago

@wangxinyu0922 Did you solve this problem? I ran into the same error when I tried to customize flair to accept trained XLM-R embeddings from fairseq: TypeError: can't pickle SwigPyObject objects. I can save the model by using torch.save(model.state_dict()), but I don't think that's the best option.
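
For readers hitting the same error, here is a minimal sketch of one common way around it. It assumes the unpicklable object is a native handle stored as an attribute (for example the tokenizer's SentencePiece processor, which is a guess and not something the traceback confirms). All names below (FakeNativeHandle, NaiveTokenizer, PicklableTokenizer, unpicklable_attributes) are hypothetical stand-ins for illustration and are not part of flair or transformers. The idea is to first find the attribute that refuses to pickle, then drop it in `__getstate__` and rebuild it in `__setstate__`.

```python
import pickle


class FakeNativeHandle:
    """Hypothetical stand-in for a SWIG-wrapped object that cannot be pickled."""

    def __init__(self, model_file: str):
        self.model_file = model_file

    def __reduce__(self):
        # Mimic the behaviour seen in the traceback above.
        raise TypeError("can't pickle FakeNativeHandle objects")


class NaiveTokenizer:
    """Keeps the native handle as a plain attribute, so pickling fails."""

    def __init__(self, model_file: str):
        self.model_file = model_file
        self.handle = FakeNativeHandle(model_file)


class PicklableTokenizer(NaiveTokenizer):
    """Drops the handle when pickled and rebuilds it when unpickled."""

    def __getstate__(self):
        state = self.__dict__.copy()
        state["handle"] = None  # leave the native object out of the pickle
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Recreate the handle from the information that was kept.
        self.handle = FakeNativeHandle(self.model_file)


def unpicklable_attributes(obj) -> list:
    """Return the names of instance attributes that pickle cannot serialize."""
    bad = []
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value, protocol=4)
        except Exception:
            bad.append(name)
    return bad


if __name__ == "__main__":
    print(unpicklable_attributes(NaiveTokenizer("spm.model")))
    restored = pickle.loads(pickle.dumps(PicklableTokenizer("spm.model"), protocol=4))
    print(type(restored.handle).__name__, restored.model_file)
```

The same pattern would apply to an embeddings class that holds such a tokenizer: keep only what is needed to recreate it (such as the model name or file path) in the saved state. Whether this matches the fix that eventually landed upstream is not stated in this thread.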

stefan-it commented 4 years ago

A fix and an update to Transformers 2.3 are coming soon 😅