flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/

PyTorch serialization error when training an NER tagger with de-fasttext word embeddings #171

Closed: mhham closed this issue 5 years ago

mhham commented 5 years ago

When training an NER sequence tagger with WordEmbeddings('de-fasttext'), I get a PyTorch serialization error right after the first epoch, when the trainer saves the best model.

Code:

from flair.data_fetcher import NLPTask
from flair.embeddings import CharacterEmbeddings, CharLMEmbeddings, TokenEmbeddings, WordEmbeddings, StackedEmbeddings
from typing import List

# 1. get the corpus
corpus = downsampled_corpus  # corpus built in an earlier notebook cell (not shown)

# 2. what tag do we want to predict?
tag_type = 'ner'

# 3. make the tag dictionary from the corpus
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
print(tag_dictionary.idx2item)

# 4. initialize embeddings
embedding_types: List[TokenEmbeddings] = [

    WordEmbeddings('de-fasttext'),  # German fastText embeddings, an ID defined in the WordEmbeddings class

    # uncomment this line to use character embeddings
    # CharacterEmbeddings(),

    # uncomment these lines to use contextual string embeddings
    # CharLMEmbeddings('news-forward'),
    # CharLMEmbeddings('news-backward'),
]

embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)

# 5. initialize sequence tagger
from flair.models import SequenceTagger

tagger: SequenceTagger = SequenceTagger(hidden_size=256,
                                        embeddings=embeddings,
                                        tag_dictionary=tag_dictionary,
                                        tag_type=tag_type,
                                        use_crf=True)

# 6. initialize trainer
from flair.trainers import SequenceTaggerTrainer

trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus)

# 7. start training
trainer.train('resources/taggers/example-ner',
              learning_rate=0.1,
              mini_batch_size=32,
              max_epochs=10)

This gives the following error:

---------------------------------------------------------------------------
OSError                                   Traceback (most recent call last)
<ipython-input-6-de6d08c3eb5d> in <module>()
     47               learning_rate=0.1,
     48               mini_batch_size=32,
---> 49               max_epochs=10)

.../lib/python3.7/site-packages/flair/trainers/sequence_tagger_trainer.py in train(self, base_path, learning_rate, mini_batch_size, max_epochs, anneal_factor, patience, train_with_dev, embeddings_in_memory, checkpoint, save_final_model, anneal_with_restarts)
    161                 # if we use dev data, remember best model based on dev evaluation score
    162                 if not train_with_dev and dev_score == scheduler.best:
--> 163                     self.model.save(base_path + "/best-model.pt")
    164 
    165             # if we do not use dev data for model selection, save final model

.../lib/python3.7/site-packages/flair/models/sequence_tagger_model.py in save(self, model_file)
    149             'rnn_layers': self.rnn_layers,
    150         }
--> 151         torch.save(model_state, model_file, pickle_protocol=4)
    152 
    153     @classmethod

.../lib/python3.7/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol)
    207         >>> torch.save(x, buffer)
    208     """
--> 209     return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
    210 
    211 

.../lib/python3.7/site-packages/torch/serialization.py in _with_file_like(f, mode, body)
    132         f = open(f, mode)
    133     try:
--> 134         return body(f)
    135     finally:
    136         if new_fd:

.../lib/python3.7/site-packages/torch/serialization.py in <lambda>(f)
    207         >>> torch.save(x, buffer)
    208     """
--> 209     return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
    210 
    211 

.../lib/python3.7/site-packages/torch/serialization.py in _save(obj, f, pickle_module, pickle_protocol)
    280     pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    281     pickler.persistent_id = persistent_id
--> 282     pickler.dump(obj)
    283 
    284     serialized_storage_keys = sorted(serialized_storages.keys())

OSError: [Errno 22] Invalid argument
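The traceback shows the `OSError` being raised inside `pickler.dump(obj)`, that is, while `torch.save` streams the pickled model state into `best-model.pt`. To see how large that stream is, and how large the biggest single `write()` call on the file gets, one can pickle the object into a sink that only counts bytes. This is a minimal standard-library sketch; `ByteCounter` and `pickled_size` are hypothetical helpers, not part of Flair or PyTorch, and because `torch.save` writes tensor storages separately from the pickle stream, the result is only an approximation of the saved file.

```python
import pickle


class ByteCounter:
    """File-like sink that discards data and only counts the bytes written."""

    def __init__(self) -> None:
        self.total = 0          # sum of all bytes passed to write()
        self.largest_write = 0  # size of the biggest single write() call

    def write(self, data) -> int:
        n = len(data)  # bytes, bytearray or a byte-format memoryview
        self.total += n
        self.largest_write = max(self.largest_write, n)
        return n


def pickled_size(obj, protocol: int = 4):
    """Return (total bytes, largest single write) for pickling obj."""
    sink = ByteCounter()
    # pickle.Pickler only needs an object with a write() method
    pickle.Pickler(sink, protocol=protocol).dump(obj)
    return sink.total, sink.largest_write


# Tiny demo payload; a real model state would be much larger.
total, largest = pickled_size({"weights": b"\x00" * 10_000_000})
print(f"{total / 2**20:.1f} MiB pickled, largest write {largest / 2**20:.1f} MiB")
```

Calling `pickled_size` on the objects being saved (assuming they pickle with the standard `pickle` module, as the traceback suggests) would show whether the embeddings push a single write into the multi-gigabyte range.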
alanakbik commented 5 years ago

Hm, interesting. A few questions: what OS are you on, and which version of Flair are you using?

Also, could you try a setup with 'glove' embeddings instead of fastText to see whether the error still occurs?
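The OS and version questions can be answered in one go with a small helper. This is a hypothetical, standard-library-only sketch (`environment_report` is not a Flair function); it relies on `importlib.metadata`, which only ships with Python 3.8 and later, so on Python 3.7 the package versions fall back to "unknown" and `pip show flair torch` is the simpler route.

```python
import platform

try:  # importlib.metadata is part of the standard library from Python 3.8 on
    from importlib.metadata import PackageNotFoundError, version
except ImportError:  # older interpreters: package versions reported as "unknown"
    PackageNotFoundError = Exception
    version = None


def environment_report(packages=("flair", "torch")):
    """Collect OS, Python and package versions for a bug report."""
    report = {
        "os": platform.platform(),            # OS name and release string
        "python": platform.python_version(),  # e.g. "3.7.0"
    }
    for name in packages:
        if version is None:
            report[name] = "unknown"
            continue
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


for key, value in environment_report().items():
    print(f"{key}: {value}")
```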

mhham commented 5 years ago

With glove embeddings, the error does not occur.

I am on macOS 10.14 with Python 3.7.0 and Flair 0.3.1.

stefan-it commented 5 years ago

Hm, we have seen this kind of problem with macOS in some other issues here...

Does this error still occur with a more recent PyTorch version?

mhham commented 5 years ago

Do you mean the nightly version? I am currently using torch 0.4.1.

alanakbik commented 5 years ago

We believe this has something to do with pickle on macOS: pickle seems to be unable to store and load very large objects there. That would explain why training and loading models with 'glove' (relatively small word embeddings) works, while it fails with fastText (the embedding files are about 3 GB). We really need to find a solution, but we mainly develop on Ubuntu, so we don't have a setup to reproduce this error.

But given the number of people who use macOS, I think we need to make this issue an immediate priority. Any help from the Mac users in the community is greatly appreciated!
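One candidate explanation (an assumption, not confirmed in this thread) is a macOS limitation where a single `write()` call of 2 GiB or more fails with `OSError: [Errno 22] Invalid argument`, which matches the error in the traceback and would fit the roughly 3 GB fastText case failing while glove works. If that is the cause, a workaround is to split every write into smaller chunks. The sketch below wraps a binary file object; `ChunkedWriter` and its 1 GiB default chunk size are hypothetical choices, not part of Flair or PyTorch.

```python
import io


class ChunkedWriter:
    """Wrap a binary file so that no single write() exceeds chunk_size bytes."""

    def __init__(self, f, chunk_size: int = 1 << 30):  # 1 GiB, well below 2 GiB
        self._f = f
        self._chunk_size = chunk_size

    def write(self, data) -> int:
        view = memoryview(data).cast("B")  # flat byte view, no copy
        for start in range(0, len(view), self._chunk_size):
            self._f.write(view[start:start + self._chunk_size])
        return len(view)  # report the full length, like a normal file

    def flush(self) -> None:
        self._f.flush()


# Demo with a tiny chunk size so the splitting actually happens.
buf = io.BytesIO()
ChunkedWriter(buf, chunk_size=3).write(b"0123456789")
print(buf.getvalue())
```

Assuming the installed `torch.save` accepts a file-like object in place of a path (the `_with_file_like` frame in the traceback only opens `f` itself when it is given a path), usage might look like `with open(path, 'wb') as f: torch.save(model_state, ChunkedWriter(f), pickle_protocol=4)`. This is not verified against every PyTorch version, and loading such a large file may need a matching fix on the read side.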

alanakbik commented 5 years ago

Opened #174 for this bug.