suriyadeepan / augmented_seq2seq

enhance seq2seq model for open ended dialog generation
GNU General Public License v3.0

How to get encoder_final_outputs? #1

Open · 0b01 opened this issue 7 years ago

0b01 commented 7 years ago

Hello @suriyadeepan, nice repo you've got here. I am trying to resolve this bug: https://github.com/tensorflow/tensorflow/issues/10862. I find your functional scan method refreshing. My only question: how do I get the outputs from the RNN?
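
(To make the "functional scan" pattern concrete, here is a minimal sketch of it; the cell, shapes, and names are my own illustration, not the repo's actual code. tf.scan threads the previous (output, state) pair through time and stacks both along the time axis, which is where the per-step outputs come from.)

    import tensorflow as tf

    # minimal sketch: drive a single LSTM cell over time-major inputs with tf.scan
    time_steps, batch_size, input_dim, hidden_dim = 10, 32, 16, 64
    enc_inp_t = tf.placeholder(tf.float32, [time_steps, batch_size, input_dim])

    cell = tf.contrib.rnn.LSTMCell(hidden_dim)
    init_state = cell.zero_state(batch_size, tf.float32)

    # fn(acc, x): acc is the previous (output, state) pair, x the current input;
    # tf.scan stacks each slot of the result over time, so `outputs` ends up
    # [time, batch, hidden_dim] and `states` holds the state at every step
    outputs, states = tf.scan(
        lambda acc, x: cell(x, acc[1]),
        enc_inp_t,
        initializer=(tf.zeros([batch_size, hidden_dim]), init_state))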

For a typical bidirectional_dynamic_rnn, the return values are the outputs and the final states:

import numpy as np
import tensorflow as tf

with tf.variable_scope("ENCODE"):
    # stack `encoder_depth` LSTM layers for the forward direction
    enc_cells_fw = []
    for i in range(encoder_depth):
        with tf.variable_scope('enc_RNN_{}'.format(i)):
            cell = tf.contrib.rnn.LSTMCell(hidden_dim)
            cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1.0-dropout)
            enc_cells_fw.append(cell)
    enc_cell_fw = tf.contrib.rnn.MultiRNNCell(enc_cells_fw)
    # and the same stack for the backward direction
    enc_cells_bw = []
    for i in range(encoder_depth):
        with tf.variable_scope('enc_RNN_{}'.format(i)):
            cell = tf.contrib.rnn.LSTMCell(hidden_dim)
            cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1.0-dropout)
            enc_cells_bw.append(cell)
    enc_cell_bw = tf.contrib.rnn.MultiRNNCell(enc_cells_bw)

    enc_inp_len = np.array([seq_length_in for _ in range(batch_size)])

    ((encoder_fw_outputs,
      encoder_bw_outputs),
     (encoder_fw_final_state,
      encoder_bw_final_state)) = (
        tf.nn.bidirectional_dynamic_rnn(cell_fw=enc_cell_fw,
                                        cell_bw=enc_cell_bw,
                                        inputs=enc_inp,
                                        sequence_length=enc_inp_len,
                                        dtype=tf.float32)
        )
    encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)

    # take the top layer's final state from each direction and concatenate
    encoder_final_state_c = tf.concat(
        (encoder_fw_final_state[-1].c, encoder_bw_final_state[-1].c), 1)

    encoder_final_state_h = tf.concat(
        (encoder_fw_final_state[-1].h, encoder_bw_final_state[-1].h), 1)

    encoder_final_state = tf.contrib.rnn.LSTMStateTuple(
        c=encoder_final_state_c,
        h=encoder_final_state_h
    )
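
For reference (my annotation, not part of the original comment), the shapes this baseline produces, assuming enc_inp is [batch_size, seq_length_in, input_dim]:

    # encoder_fw_outputs / encoder_bw_outputs : [batch, time, hidden_dim]
    # encoder_outputs                         : [batch, time, hidden_dim*2]
    # encoder_final_state.c / .h              : [batch, hidden_dim*2]
    print(encoder_outputs.get_shape())        # (batch_size, seq_length_in, 2*hidden_dim)
    print(encoder_final_state.c.get_shape())  # (batch_size, 2*hidden_dim)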
0b01 commented 7 years ago

I have figured it out:

    # forward stack
    enc_cells_fw = []
    for i in range(encoder_depth):
        with tf.variable_scope('enc_RNN_{}'.format(i)):
            cell = tf.contrib.rnn.LSTMCell(hidden_dim)
            cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1.0-dropout)
            enc_cells_fw.append(cell)
    enc_cell_fw = tf.contrib.rnn.MultiRNNCell(enc_cells_fw, state_is_tuple=True)
    # backward stack
    enc_cells_bw = []
    for i in range(encoder_depth):
        with tf.variable_scope('enc_RNN_{}'.format(i)):
            cell = tf.contrib.rnn.LSTMCell(hidden_dim)
            cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1.0-dropout)
            enc_cells_bw.append(cell)
    enc_cell_bw = tf.contrib.rnn.MultiRNNCell(enc_cells_bw, state_is_tuple=True)

    init_state = enc_cell_fw.zero_state(batch_size=batch_size, dtype=tf.float32)

    # transpose encoder inputs to time-major: [time, batch, input_dim]
    enc_inp_t = tf.transpose(enc_inp, [1,0,2])

    # the bidirectional encoder, driven by tf.scan
    with tf.variable_scope('encoder-fw') as scope: # forward sequence
        # acc is the previous (output, state) pair; feed the state back in
        enc_output_fw, enc_states_fw = tf.scan(
                lambda acc, x: enc_cell_fw(x, acc[1]),
                enc_inp_t,
                initializer=(tf.zeros(shape=[batch_size, hidden_dim]), init_state))

    with tf.variable_scope('encoder-bw') as scope: # backward sequence
        enc_output_bw, enc_states_bw = tf.scan(
                lambda acc, x: enc_cell_bw(x, acc[1]),
                tf.reverse(enc_inp_t, axis=[0]), # <- reverse inputs in time
                initializer=(tf.zeros(shape=[batch_size, hidden_dim]), init_state))

    # re-reverse the backward outputs in time so step t of fw and bw align
    enc_output_bw = tf.reverse(enc_output_bw, axis=[0])

    # back to batch-major: [batch, time, hidden_dim] each
    enc_output_fw = tf.transpose(enc_output_fw, [1,0,2])
    enc_output_bw = tf.transpose(enc_output_bw, [1,0,2])
    encoder_outputs = tf.concat([enc_output_fw, enc_output_bw], 2)

    # projection weights for the context: one [hidden_dim*2, hidden_dim*2]
    # matrix per (c/h, layer) pair
    Wc = tf.get_variable('Wc', shape=[2, encoder_depth, hidden_dim*2, hidden_dim*2],
                        initializer=tf.contrib.layers.xavier_initializer())

    # extract context: take each layer's final (last-step) state, concatenate
    # fw and bw, project c and h (dimension stays hidden_dim*2), list -> tuple
    encoder_final_state = []
    for layer in range(encoder_depth):
        enc_c = tf.concat( (enc_states_fw[layer].c[-1], enc_states_bw[layer].c[-1]), 1)
        enc_c = tf.matmul(enc_c, Wc[0][layer])
        enc_h = tf.concat( (enc_states_fw[layer].h[-1], enc_states_bw[layer].h[-1]), 1)
        enc_h = tf.matmul(enc_h, Wc[1][layer])
        encoder_final_state.append(tf.contrib.rnn.LSTMStateTuple(c = enc_c, h = enc_h))
    # convert list to tuple - eww!
    encoder_final_state = tuple(encoder_final_state)
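
To close the loop (my sketch, not code from this thread): since Wc keeps the projected dimension at hidden_dim*2, a decoder that consumes encoder_final_state as its initial state needs hidden_dim*2 units per layer. The inputs dec_inp and lengths dec_inp_len below are assumed placeholders:

    # hypothetical decoder hookup; dec_inp and dec_inp_len are assumptions
    dec_cells = []
    for i in range(encoder_depth):
        with tf.variable_scope('dec_RNN_{}'.format(i)):
            # hidden_dim*2 so each layer matches the projected encoder state
            dec_cells.append(tf.contrib.rnn.LSTMCell(hidden_dim*2))
    dec_cell = tf.contrib.rnn.MultiRNNCell(dec_cells, state_is_tuple=True)

    dec_outputs, dec_final_state = tf.nn.dynamic_rnn(
        dec_cell,
        inputs=dec_inp,                     # [batch, time, input_dim] (assumed)
        sequence_length=dec_inp_len,        # assumed per-example lengths
        initial_state=encoder_final_state,  # tuple of LSTMStateTuples from above
        dtype=tf.float32)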