sherjilozair / char-rnn-tensorflow

Multi-layer Recurrent Neural Networks (LSTM, RNN) for character-level language models in Python using Tensorflow
MIT License
2.64k stars, 960 forks

Unable to test with GRU or RNN cells #85

pisiiki closed this 6 years ago

pisiiki commented 7 years ago

I am on TensorFlow 1.0, but it failed on 0.12 too.

Traceback (most recent call last):
  File "train.py", line 114, in <module>
    main()
  File "train.py", line 48, in main
    train(args)
  File "train.py", line 98, in train
    for i, (c, h) in enumerate(model.initial_state):
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 516, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

Regards.

ubergarm commented 7 years ago

I can get various cells working, using the TensorFlow 1.0.0 GPU build. I did have to tweak the code for the most recent version, since the way multiple layers are configured seems to have changed a bit. Check out PR #89. I also added the latest NASCell. Hope you get it working!

tgy commented 7 years ago

I have the same problem with TF 1.0.0 CPU. The problem is that GRUCell's state is a single tensor, not a (c, h) tuple like LSTMCell's. Thus, when train.py does enumerate(model.initial_state) and unpacks each entry into (c, h), it fails.
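A minimal pure-Python sketch of the failure described above, so it can be reproduced without TensorFlow. FakeTensor and the local LSTMStateTuple are hypothetical stand-ins: in TF 1.x, LSTMStateTuple is a namedtuple with fields c and h, and a graph-mode Tensor refuses iteration, as the traceback shows. unpack_like_train_py is a hypothetical helper that mimics the (c, h) loop in train.py.

```python
from collections import namedtuple

# Hypothetical stand-in for TF 1.x's LSTMStateTuple, a namedtuple with fields (c, h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ("c", "h"))


class FakeTensor(object):
    """Hypothetical stand-in for a graph-mode tf.Tensor, which cannot be iterated."""

    def __init__(self, name):
        self.name = name

    def __iter__(self):
        raise TypeError("'Tensor' object is not iterable.")


# With state_is_tuple=True, MultiRNNCell's initial_state holds one entry per layer:
# an LSTMStateTuple for lstm, but a single tensor for rnn and gru.
lstm_initial_state = tuple(
    LSTMStateTuple(FakeTensor("c%d" % i), FakeTensor("h%d" % i)) for i in range(2)
)
gru_initial_state = tuple(FakeTensor("state%d" % i) for i in range(2))


def unpack_like_train_py(initial_state):
    """Mimics train.py's loop, which assumes every layer state is a (c, h) pair."""
    names = []
    for i, (c, h) in enumerate(initial_state):
        names.append((i, c.name, h.name))
    return names


print(unpack_like_train_py(lstm_initial_state))  # works for lstm
try:
    unpack_like_train_py(gru_initial_state)
except TypeError as err:
    print("gru:", err)  # unpacking a single tensor into (c, h) fails
```

The same TypeError would be raised for a plain rnn cell, since its per-layer state is also a single tensor.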

ubergarm commented 7 years ago

Thanks for the report, @tgy. I gave it another look and confirmed it is broken for both the rnn and gru models.

Filing as a bug.

Feel free to submit a PR if you have a fix that works across all four cell types, including lstm and nas.
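One possible direction for a fix that does not care about the cell type: instead of assuming a (c, h) pair per layer, flatten both the state placeholders and the fetched state values, then zip them into the feed dict. This is a sketch under assumptions: flatten, build_state_feed, FakeTensor, and the local LSTMStateTuple are hypothetical names (TensorFlow has its own nest utilities, which are not used here), and state is assumed to be the previous step's nested state value as returned by sess.run.

```python
from collections import namedtuple

# Hypothetical stand-in for TF 1.x's LSTMStateTuple (a namedtuple of c and h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ("c", "h"))


class FakeTensor(object):
    """Hypothetical stand-in for a tf.Tensor placeholder: hashable, not iterable."""

    def __init__(self, name):
        self.name = name

    def __iter__(self):
        raise TypeError("'Tensor' object is not iterable.")


def flatten(structure):
    """Flatten nested tuples (including namedtuples) into a list of leaves."""
    if isinstance(structure, tuple):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]


def build_state_feed(initial_state, state):
    """Map every state placeholder to its value, whatever the cell type."""
    keys = flatten(initial_state)
    values = flatten(state)
    if len(keys) != len(values):
        raise ValueError("state structure does not match initial_state")
    return dict(zip(keys, values))


# lstm: one LSTMStateTuple per layer; gru/rnn: one tensor per layer.
lstm_ph = (LSTMStateTuple(FakeTensor("c0"), FakeTensor("h0")),)
gru_ph = (FakeTensor("s0"), FakeTensor("s1"))
print(len(build_state_feed(lstm_ph, (LSTMStateTuple("c-value", "h-value"),))))
print(len(build_state_feed(gru_ph, ("s0-value", "s1-value"))))
```

Because a lone tensor flattens to a one-element list, the same helper would also cover a single concatenated state (state_is_tuple=False), so train.py would not need a separate branch per cell type.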

Traceback (most recent call last):
  File "train.py", line 140, in <module>
    main()
  File "train.py", line 55, in main
    train(args)
  File "train.py", line 115, in train
    for i, (c, h) in enumerate(model.initial_state):
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 502, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

pisiiki commented 7 years ago

@ubergarm, a workaround is to avoid state_is_tuple=True, although TF will complain about it. I also have some code that supports several of the RNN cells, and I used that workaround there, but I think I will end up writing code for the two cases. Also, NASCell has been updated recently and I think it now uses the tuple state (on the GitHub master branch), so it should work if you rebuild TF.
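For the state_is_tuple=False case, it helps to know how the single state tensor is laid out. In TF 1.x, a MultiRNNCell with state_is_tuple=False concatenates the layer states along the last axis, and an LSTM cell with state_is_tuple=False stores its state as [c | h], so the width is num_layers * 2 * rnn_size for lstm and num_layers * rnn_size for rnn/gru. The pure-Python sketch below assumes the flag is False on both the MultiRNNCell and the LSTM cell; nas is deliberately left out because its state format was still changing per the comment above. The names concat_state_width, concat_row, and split_row are hypothetical.

```python
# Per-layer state width multiplier with state_is_tuple=False (assumption: lstm stores [c | h]).
STATE_MULTIPLIER = {"rnn": 1, "gru": 1, "lstm": 2}


def concat_state_width(model, rnn_size, num_layers):
    """Width of the single state tensor a MultiRNNCell uses with state_is_tuple=False."""
    if model not in STATE_MULTIPLIER:
        raise ValueError("unsupported cell type in this sketch: %r" % model)
    return num_layers * STATE_MULTIPLIER[model] * rnn_size


def concat_row(layer_states):
    """Concatenate one batch row of per-layer states, layer 0 first (c before h for lstm)."""
    row = []
    for layer in layer_states:
        parts = layer if isinstance(layer, tuple) else (layer,)
        for part in parts:
            row.extend(part)
    return row


def split_row(row, per_layer_width):
    """Slice a concatenated row back into per-layer chunks."""
    return [row[i:i + per_layer_width] for i in range(0, len(row), per_layer_width)]


# Two lstm layers with rnn_size=2: each layer contributes [c | h], i.e. 4 values.
row = concat_row([([1, 2], [3, 4]), ([5, 6], [7, 8])])
print(row, split_row(row, 4))
```

With a single concatenated tensor there is nothing to unpack per layer, which is why this workaround sidesteps the (c, h) loop; the cost is that train.py then has to handle the tuple and non-tuple layouts as two separate cases.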

tgy commented 7 years ago

Google open-sourced a seq2seq implementation a few days ago. The code is very generic and might be a good source of inspiration for making the char-rnn model's cell type generic.

ubergarm commented 7 years ago

Thanks, @pisiiki, I'll give it a try. I did remove that state_is_tuple=True, which probably broke two of the cells. If nas supports it now, I can put it back and hopefully all will be well again.

Thanks, @tgy, I want to check that out. It's hard to come into an existing repo with many features and figure out the basics; this reference should help.