Open liuhuang31 opened 5 years ago
Hi,
Could you provide more details about which function of which file you are changing?
Hi, the changed file is 'modules_tf.py'. The code I added is below.
output = encoder_decoder_archi_gan(inputs, is_train)
output = tf.tanh(tf.layers.batch_normalization(tf.layers.dense(output, config.output_features, name = "Fu_F", kernel_initializer=tf.random_normal_initializer(stddev=0.02)), training = is_train, name = "bn_fu_out"))
output = tf.squeeze(output)
# add lstm layer
lstm_cell = tf.contrib.rnn.BasicLSTMCell(64, forget_bias=1.0, state_is_tuple=True)
init_state = lstm_cell.zero_state(batch_size=config.batch_size, dtype=tf.float32)
lstm_out, final_state = tf.nn.dynamic_rnn(lstm_cell, output, initial_state=init_state, time_major=False)
lstm_out = tf.tanh(tf.layers.batch_normalization(tf.layers.dense(lstm_out, config.output_features, name = "T_GRU", kernel_initializer=tf.random_normal_initializer(stddev=0.02)), training=is_train, name = "gru_out"))
return lstm_out
I would define an LSTM function as follows:
def bi_static_stacked_RNN(x, scope='RNN', lstm_size = config.lstm_size):
    """Run a single bidirectional static LSTM over a batch-major sequence.

    Args:
        x: batch-major tensor of shape (config.batch_size, config.max_phr_len, features)
           — assumed; confirm against the caller.
        scope: variable scope name for the layer's weights.
        lstm_size: number of units in each directional LSTM cell.

    Returns:
        Batch-major tensor of shape (batch, config.max_phr_len, 2 * lstm_size);
        forward and backward outputs are concatenated on the feature axis.
    """
    with tf.variable_scope(scope):
        # static_bidirectional_rnn takes a time-major Python list of
        # (batch, features) tensors, so unstack the time axis first.
        x = tf.unstack(x, config.max_phr_len, 1)
        output = x
        lstm_fw = tf.nn.rnn_cell.LSTMCell(lstm_size, state_is_tuple=True)
        lstm_bw = tf.nn.rnn_cell.LSTMCell(lstm_size, state_is_tuple=True)
        _initial_state_fw = lstm_fw.zero_state(config.batch_size, tf.float32)
        _initial_state_bw = lstm_bw.zero_state(config.batch_size, tf.float32)
        # Returns a time-major list of (batch, 2 * lstm_size) tensors; the
        # forward and backward outputs are already concatenated per step.
        output, _state1, _state2 = tf.contrib.rnn.static_bidirectional_rnn(
            lstm_fw, lstm_bw, output,
            initial_state_fw=_initial_state_fw,
            initial_state_bw=_initial_state_bw,
            scope='BLSTM_' + scope)
        # Restack to (time, batch, 2 * lstm_size), then transpose back to
        # batch-major. NOTE(review): the original's `output_fw = output[0]` /
        # `output_bw = output[1]` sliced time steps 0 and 1 — not the two
        # directions — and were unused, so they are removed here, along with
        # the dead `num_layer = 2` left over from a commented-out loop.
        output = tf.stack(output)
        output = tf.transpose(output, [1, 0, 2])
        return output
And then pass the output to it as: output = bi_static_stacked_RNN(output, scope = "scope")
Hi, thanks for sharing. When I add an LSTM layer, I get a "Segmentation fault".
I added a simple LSTM layer in the full_network; the code is below:
output = tf.squeeze(output) // origin code, doesn't change. // lstm, the code i add. lstm_cell = tf.contrib.rnn.BasicLSTMCell(64, forget_bias=1.0, state_is_tuple=True) init_state = lstm_cell.zero_state(batch_size=config.batch_size, dtype=tf.float32) lstm_out, final_state = tf.nn.dynamic_rnn(lstm_cell, output, initial_state=init_state, time_major=False) print("-----------lstm_out.shape------------", lstm_out.shape) // (30, 128, 64) return lstm_out
When training, I get a "Segmentation fault" with no other error information. It is puzzling, because the code I added is simple and appears correct.