minimaxir / textgenrnn

Easily train your own text-generating neural network of any size and complexity on any text dataset with a few lines of code.

Keras while_loop() error #49

Open Fraser-Paine opened 6 years ago

Fraser-Paine commented 6 years ago

I've trained this model successfully on Windows 10 and Ubuntu with no issues; however, when attempting to train on an Azure notebook server, I get the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-4441bc1684a3> in <module>()
----> 1 textgen = textgenrnn("textgenrnn_weights.hdf5")
      2 textgen.generate()

~/anaconda3_501/lib/python3.6/site-packages/textgenrnn/textgenrnn.py in __init__(self, weights_path, vocab_path, config_path, name)
     63         self.model = textgenrnn_model(self.num_classes,
     64                                       cfg=self.config,
---> 65                                       weights_path=weights_path)
     66         self.indices_char = dict((self.vocab[c], c) for c in self.vocab)
     67 

~/anaconda3_501/lib/python3.6/site-packages/textgenrnn/model.py in textgenrnn_model(num_classes, cfg, context_size, weights_path, dropout, optimizer)
     27     for i in range(cfg['rnn_layers']):
     28         prev_layer = embedded if i is 0 else rnn_layer_list[-1]
---> 29         rnn_layer_list.append(new_rnn(cfg, i+1)(prev_layer))
     30 
     31     seq_concat = concatenate([embedded] + rnn_layer_list, name='rnn_concat')

~/anaconda3_501/lib/python3.6/site-packages/keras/layers/recurrent.py in __call__(self, inputs, initial_state, constants, **kwargs)
    498             additional_inputs += constants
    499             self.constants_spec = [InputSpec(shape=K.int_shape(constant))
--> 500                                    for constant in constants]
    501             self._num_constants = len(constants)
    502             additional_specs += self.constants_spec

~/anaconda3_501/lib/python3.6/site-packages/keras/engine/base_layer.py in __call__(self, inputs, **kwargs)

~/anaconda3_501/lib/python3.6/site-packages/keras/layers/recurrent.py in call(self, inputs, mask, training, initial_state)
   2110                   'recurrent_dropout': self.recurrent_dropout,
   2111                   'implementation': self.implementation}
-> 2112         base_config = super(LSTM, self).get_config()
   2113         del base_config['cell']
   2114         return dict(list(base_config.items()) + list(config.items()))

~/anaconda3_501/lib/python3.6/site-packages/keras/layers/recurrent.py in call(self, inputs, mask, training, initial_state, constants)
    607                 states = [states]
    608             else:
--> 609                 states = list(states)
    610             return [output] + states
    611         else:

~/anaconda3_501/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in rnn(step_function, inputs, initial_states, go_backwards, mask, constants, unroll, input_length)
   2955 
   2956 def sigmoid(x):
-> 2957     """Element-wise sigmoid.
   2958 
   2959     # Arguments

TypeError: while_loop() got an unexpected keyword argument 'maximum_iterations'

I found this issue, which appeared to be similar, and attempted the suggested fix. However, I then got the following warning: `textgenrnn 1.3.1 has requirement keras>=2.1.5, but you'll have keras 2.1.2 which is incompatible.` The same error as above also persists. I've tried several other versions of TensorFlow and Keras, including the newest of each, but none work.

The code I'm attempting is the absolute minimal code for this library, and the issue still persists:

```python
from textgenrnn import textgenrnn
import pandas as pd
import numpy as np

textgen = textgenrnn()
textgen.generate()
```
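Since the pip warning above is exactly a `keras>=2.1.5` requirement check failing, one way to sanity-check a candidate environment before retrying is to compare version strings numerically rather than lexically (note that `1.11.0rc2` must sort *before* `1.11.0`). This is only a stdlib diagnostic sketch, not part of textgenrnn; the helper names are made up for illustration:

```python
import re

def parse_version(v):
    """Split a version string like '1.11.0rc2' into a comparable tuple.

    Numeric components become ints; a pre-release suffix (e.g. 'rc2')
    sorts before the corresponding final release.
    """
    match = re.match(r"(\d+(?:\.\d+)*)(.*)", v)
    nums = tuple(int(p) for p in match.group(1).split("."))
    suffix = match.group(2)
    # Pre-releases get marker 0, final releases marker 1,
    # so '1.11.0rc2' < '1.11.0'.
    return nums + ((0, suffix) if suffix else (1, ""))

def satisfies_min(installed, minimum):
    """True if the installed version meets a 'pkg>=minimum' requirement."""
    return parse_version(installed) >= parse_version(minimum)
```

For example, `satisfies_min("2.1.2", "2.1.5")` is False, which matches the incompatibility pip reported for keras 2.1.2 against the `keras>=2.1.5` requirement.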

Tianmaru commented 5 years ago

Hey Fraser-Paine, I also had some trouble when I tried to load a trained model - this error message was among many others. It worked for me with:

Keras: 2.1.6
tensorflow: 1.11.0rc2
textgenrnn: 1.4

I hope this helps :)
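Since hosted kernels like Azure notebooks often resolve different package versions than a local machine, it can help to print what the environment actually has before retraining. A small stdlib-only sketch (the reported-working numbers below come from this thread, not from official requirements, and the helper name is made up):

```python
import importlib

# Version combination reported to work in this thread (Tianmaru's comment);
# these are anecdotal, not official textgenrnn requirements.
REPORTED_WORKING = {
    "keras": "2.1.6",
    "tensorflow": "1.11.0rc2",
    "textgenrnn": "1.4",
}

def installed_versions(packages):
    """Return {package: version string} for each importable package.

    Missing packages map to 'not installed'; packages that do not
    expose __version__ map to 'unknown'.
    """
    result = {}
    for pkg in packages:
        try:
            mod = importlib.import_module(pkg)
            result[pkg] = getattr(mod, "__version__", "unknown")
        except ImportError:
            result[pkg] = "not installed"
    return result

# Example usage: compare the kernel's packages against the combo above.
# for pkg, wanted in REPORTED_WORKING.items():
#     have = installed_versions([pkg])[pkg]
#     print(f"{pkg}: installed {have}, reported working {wanted}")
```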