farizrahman4u / recurrentshop

Framework for building complex recurrent neural networks with Keras
MIT License
765 stars, 218 forks

Using teacher_force fails when TensorFlow is the backend #81

Open · datum-gravitydata opened this issue 7 years ago

datum-gravitydata commented 7 years ago

I am using teacher_force with the TensorFlow backend, but training fails as soon as fit() is called. With Theano as the backend, the same code works. My environment is as follows: Keras 2.0.6, TensorFlow 1.2.1, Theano 0.10.0b1.

My minimal code is below.

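# Imports this snippet relies on (Keras 2.x functional API, recurrentshop installed)
from keras.layers import Input, Dense
from keras.models import Model
import recurrentshop as rs

# MAX_SL, MAX_TL and corpus are defined elsewhere in the original script
input_length = MAX_SL
input_dim = len(corpus)
output_length = MAX_TL
output_dim = len(corpus)
LSTM_hidden_dim = 50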

# Make the recurrent (decoder) model
input_node = Input((input_length, input_dim), name="main_input_decoder")
readout_tm = Input((output_dim,), name="readout_input")       # previous output fed back in
lstm_st1 = Input((LSTM_hidden_dim,), name="lstm_hidden_1")    # first LSTM state
lstm_st2 = Input((LSTM_hidden_dim,), name="lstm_hidden_2")    # second LSTM state
s, lstm_st_t1, lstm_st_t2 = rs.LSTMCell(LSTM_hidden_dim)([readout_tm, lstm_st1, lstm_st2])
output = Dense(output_dim)(s)
output_model = rs.RecurrentModel(input=input_node, output=output,
                                 readout_input=readout_tm, return_sequences=True,
                                 initial_states=[lstm_st1, lstm_st2],
                                 final_states=[lstm_st_t1, lstm_st_t2],
                                 decode=True, output_length=MAX_TL, unroll=False,
                                 teacher_force=True,
                                 stateful=False, name="decoder_model")

# Use the recurrent model above with teacher forcing
input_node3 = Input((input_length, input_dim))
ground_truth = Input((output_length, output_dim), name="ground_truth")
decoded = output_model(input_node3, ground_truth=ground_truth)

model = Model([input_node3, ground_truth], decoded)
model.compile(optimizer='adam', loss='binary_crossentropy')

# inputs / outputs are the training arrays from the original script;
# the targets are also passed in as the ground truth used for teacher forcing
model.fit([inputs, outputs], outputs, epochs=1000, verbose=1)
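For context on what the model above is meant to do: with a readout_input and teacher_force=True, the decoder should presumably receive the ground-truth output of the previous step as its readout, rather than its own previous prediction. The plain-Python sketch below (no Keras or TensorFlow involved) contrasts the two decoding loops. It is not recurrentshop's implementation; the names decode and toy_step are hypothetical, and starting from a zero readout is an assumption.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]
State = Tuple[Vector, ...]


def decode(
    step: Callable[[Vector, State], Tuple[Vector, State]],
    initial_readout: Vector,
    initial_states: State,
    output_length: int,
    ground_truth: Optional[Sequence[Vector]] = None,
) -> List[Vector]:
    """Hypothetical decoder loop run for output_length steps.

    Free-running (ground_truth is None): the readout fed into step t is
    the model's own output from step t-1.
    Teacher-forced: the readout fed into step t is ground_truth[t-1].
    """
    outputs: List[Vector] = []
    readout = initial_readout  # assumption: first readout is all zeros
    states = initial_states
    for t in range(output_length):
        y, states = step(readout, states)
        outputs.append(y)
        if ground_truth is not None:
            readout = ground_truth[t]  # feed the true previous target
        else:
            readout = y  # feed back the model's own prediction
    return outputs


def toy_step(readout: Vector, states: State) -> Tuple[Vector, State]:
    """Hypothetical stand-in for the LSTMCell + Dense step: adds 1 to the readout."""
    return [r + 1.0 for r in readout], states


if __name__ == "__main__":
    print(decode(toy_step, [0.0], (), 3))
    print(decode(toy_step, [0.0], (), 3, ground_truth=[[10.0], [20.0], [30.0]]))
```

With the toy step, the free-running loop compounds its own outputs ([1.0], [2.0], [3.0]), while the teacher-forced loop restarts from each ground-truth value ([1.0], [11.0], [21.0]). This is why the Keras model needs ground_truth as a second input during training.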

The error is:

InvalidArgumentError: The node 'decoder_model/while_1/Variable_1/Assign' has inputs from different frames. The input 'decoder_model/while_1/Const_1' is in frame 'decoder_model/while_1/decoder_model/while_1/'. The input 'decoder_model/while_1/Variable_1' is in frame ''.

Is this a bug, or is there something wrong with my code?

alijkhalil commented 7 years ago

Hey, I got the same error (with TF as the backend) using the exact example provided in the teacher_force documentation. Please fix this ASAP, @farizrahman4u. Thanks!