Hironsan / anago

Bidirectional LSTM-CRF and ELMo for Named-Entity Recognition, Part-of-Speech Tagging and so on.
https://anago.herokuapp.com/
MIT License

AttributeError: 'BiLSTMCRF' object has no attribute 'fit_generator' #79

Closed · willaaam closed this 5 years ago

willaaam commented 5 years ago

Describe the problem

I think there is a bug in anago 1.0.7 (installed via pip): since upgrading I get the traceback below, which does not occur with anago 1.0.6.

C:\temp\anago>python trainbilstmmodel.py
Using TensorFlow backend.
Loading dataset...
Transforming datasets...
Building a model.
Training the model...
Traceback (most recent call last):
  File "trainbilstmmodel.py", line 69, in <module>
    main(args)
  File "trainbilstmmodel.py", line 37, in main
    trainer.train(x_train, y_train, x_valid, y_valid)
  File "C:\Program Files\Python36\lib\site-packages\anago\trainer.py", line 47, in train
    self._model.fit_generator(generator=train_seq,
AttributeError: 'BiLSTMCRF' object has no attribute 'fit_generator'
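
The failing call is in anago/trainer.py: Trainer.train invokes self._model.fit_generator(...) directly on the object passed to Trainer, and in 1.0.7 that object is the BiLSTMCRF wrapper, which does not define fit_generator. A minimal, self-contained sketch of this failure mode (the class names are hypothetical stand-ins for illustration, not anago's actual classes):

class KerasModelStandIn:
    # Plays the role of keras.models.Model, which does define fit_generator.
    def fit_generator(self, generator=None, **kwargs):
        print('training...')

class WrapperStandIn:
    # Plays the role of BiLSTMCRF: it holds a Keras model internally but
    # does not expose or proxy fit_generator itself.
    def build(self):
        self.model = KerasModelStandIn()

wrapper = WrapperStandIn()
wrapper.build()
wrapper.model.fit_generator()  # works: the inner model defines the method
wrapper.fit_generator()        # AttributeError, matching the traceback above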

The traceback comes from the following script (a lightly adjusted version of the bundled training example):

"""
Example from training to saving.
"""
import argparse
import os

from anago.utils import load_data_and_labels
from anago.models import BiLSTMCRF
from anago.preprocessing import IndexTransformer
from anago.trainer import Trainer
from sklearn.model_selection import train_test_split

def main(args):
    print('Loading dataset...')
    X, y = load_data_and_labels(args.train_data)
    x_train, x_valid, y_train, y_valid = train_test_split(X, y, test_size=0.1, random_state=42)

    print('Transforming datasets...')
    p = IndexTransformer(use_char=args.no_char_feature)
    p.fit(x_train, y_train)

    print('Building a model.')
    model = BiLSTMCRF(char_embedding_dim=args.char_emb_size,
                      word_embedding_dim=args.word_emb_size,
                      char_lstm_size=args.char_lstm_units,
                      word_lstm_size=args.word_lstm_units,
                      char_vocab_size=p.char_vocab_size,
                      word_vocab_size=p.word_vocab_size,
                      num_labels=p.label_size,
                      dropout=args.dropout,
                      use_char=args.no_char_feature,
                      use_crf=args.no_use_crf)
    model.build()

    print('Training the model...')
    trainer = Trainer(model, preprocessor=p)
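    # anago 1.0.7 raises the AttributeError here: trainer.py (line 47)
    # calls self._model.fit_generator(...) on this BiLSTMCRF instance.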
    trainer.train(x_train, y_train, x_valid, y_valid)

    print('Saving the model...')
    model.save(args.weights_file, args.params_file)
    p.save(args.preprocessor_file)

if __name__ == '__main__':
    DATA_DIR = os.path.join(os.path.dirname(__file__), './')
    parser = argparse.ArgumentParser(description='Training a model')
    parser.add_argument('--train_data', default=os.path.join(DATA_DIR, 'train.txt'), help='training data')
    parser.add_argument('--valid_data', default=os.path.join(DATA_DIR, 'valid.txt'), help='validation data')
    parser.add_argument('--weights_file', default='weights.h5', help='weights file')
    parser.add_argument('--params_file', default='params.json', help='parameter file')
    # Training parameters
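    # NOTE: the training flags below are parsed but never referenced in main() above.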
    parser.add_argument('--loss', default='categorical_crossentropy', help='loss')
    parser.add_argument('--optimizer', default='adam', help='optimizer')
    parser.add_argument('--max_epoch', type=int, default=15, help='max epoch')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--checkpoint_path', default=None, help='checkpoint path')
    parser.add_argument('--log_dir', default=None, help='log directory')
    parser.add_argument('--early_stopping', action='store_true', help='early stopping')
    # Model parameters
    parser.add_argument('--char_emb_size', type=int, default=25, help='character embedding size')
    parser.add_argument('--word_emb_size', type=int, default=100, help='word embedding size')
    parser.add_argument('--char_lstm_units', type=int, default=25, help='num of character lstm units')
    parser.add_argument('--word_lstm_units', type=int, default=100, help='num of word lstm units')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
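    # NOTE: store_false means these default to True, so char features and the
    # CRF layer are enabled unless --no_char_feature / --no_use_crf is passed.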
    parser.add_argument('--no_char_feature', action='store_false', help='use char feature')
    parser.add_argument('--no_use_crf', action='store_false', help='use crf layer')

    args = parser.parse_args()
    main(args)
Hironsan commented 5 years ago

Thank you for letting me know about the problem. It has been fixed.
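
For later readers: willaaam's report indicates 1.0.6 is unaffected, so pinning the install (pip install anago==1.0.6) is a stopgap on an affected 1.0.7 environment. Once on a fixed release, the high-level API from the project README is the simplest way to train, since it wires up the model, preprocessor, and training loop internally. A minimal sketch (the file names are placeholders for your own CoNLL-formatted data):

import anago
from anago.utils import load_data_and_labels

# Load word and label sequences from CoNLL-style files (placeholder paths).
x_train, y_train = load_data_and_labels('train.txt')
x_valid, y_valid = load_data_and_labels('valid.txt')

# anago's documented high-level model: builds and trains a BiLSTM-CRF.
model = anago.Sequence()
model.fit(x_train, y_train, x_valid, y_valid, epochs=15)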