sherjilozair / char-rnn-tensorflow

Multi-layer Recurrent Neural Networks (LSTM, RNN) for character-level language models in Python using Tensorflow
MIT License

Add attention to the model #62

Open · ckcz123 opened this issue 7 years ago

ckcz123 commented 7 years ago

How would I use attention_decoder instead of rnn_decoder?

I'm not sure how to modify this line:

outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')

In particular, what should attention_states be?
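For reference, the legacy seq2seq attention_decoder documents attention_states as a 3-D tensor of shape [batch_size, attn_length, attn_size], usually the encoder outputs stacked along the time axis. A char-rnn language model has no separate encoder, so something else would have to fill that role. The NumPy sketch below only illustrates the expected shape and what one attention read over it computes. build_attention_states and attention_read are hypothetical helpers, and plain dot-product scoring stands in for the learned additive scoring that TensorFlow's attention_decoder actually uses.

```python
import numpy as np

def build_attention_states(step_outputs):
    """Hypothetical helper: stack a list of per-step outputs, each of shape
    [batch_size, attn_size], into one [batch_size, attn_length, attn_size] array."""
    return np.stack(step_outputs, axis=1)

def attention_read(query, attention_states):
    """Hypothetical helper: dot-product attention (a simplification).

    query:            [batch_size, attn_size]
    attention_states: [batch_size, attn_length, attn_size]
    Returns (context [batch_size, attn_size], weights [batch_size, attn_length]).
    """
    # Score every stored state against the query: [batch_size, attn_length]
    scores = np.einsum("bls,bs->bl", attention_states, query)
    # Numerically stable softmax over the attn_length axis
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    # Weighted sum of the stored states: [batch_size, attn_size]
    context = np.einsum("bl,bls->bs", weights, attention_states)
    return context, weights

batch_size, attn_length, attn_size = 2, 5, 4
rng = np.random.default_rng(0)
outputs = [rng.standard_normal((batch_size, attn_size)) for _ in range(attn_length)]
states = build_attention_states(outputs)
context, weights = attention_read(outputs[-1], states)
print(states.shape, context.shape, weights.shape)  # (2, 5, 4) (2, 4) (2, 5)
```

The point to take away is the layout: the middle axis of attention_states is what the decoder attends over, and each read returns one context vector per batch element.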

john-parton commented 7 years ago

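I believe you can add the following line, where attn_length (the size of the attention window) is a value you define yourself: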

self.cell = cell = rnn.AttentionCellWrapper(cell, attn_length, state_is_tuple=True)

after

self.cell = cell = rnn.MultiRNNCell(cells, state_is_tuple=True)
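Conceptually, this wrapper keeps a sliding window of the cell's last attn_length outputs and attends over that window at every step, so the language model does not need an encoder to supply attention_states. The toy NumPy sketch below illustrates that idea only. WindowAttentionCell and rnn_step are hypothetical names, the scoring is plain dot-product, and the learned input/output projections of TensorFlow's AttentionCellWrapper are omitted, so the output here is simply the cell output concatenated with the attention context.

```python
import numpy as np

class WindowAttentionCell:
    """Hypothetical toy wrapper: runs an inner step function, keeps a sliding
    window of the last `attn_length` outputs, and attends over that window."""

    def __init__(self, step_fn, attn_length, num_units):
        self.step_fn = step_fn          # (inputs, state) -> (output, new_state)
        self.attn_length = attn_length
        self.num_units = num_units

    def zero_state(self, batch_size):
        cell_state = np.zeros((batch_size, self.num_units))
        # Window of past outputs: [batch_size, attn_length, num_units]
        attn_states = np.zeros((batch_size, self.attn_length, self.num_units))
        return cell_state, attn_states

    def __call__(self, inputs, state):
        cell_state, attn_states = state
        output, cell_state = self.step_fn(inputs, cell_state)
        # Dot-product attention of the new output over the stored window
        scores = np.einsum("bln,bn->bl", attn_states, output)
        weights = np.exp(scores - scores.max(axis=1, keepdims=True))
        weights /= weights.sum(axis=1, keepdims=True)
        context = np.einsum("bl,bln->bn", weights, attn_states)
        # Slide the window: drop the oldest output, append the newest one
        attn_states = np.concatenate([attn_states[:, 1:], output[:, None, :]], axis=1)
        return np.concatenate([output, context], axis=1), (cell_state, attn_states)

rng = np.random.default_rng(0)
batch_size, input_size, num_units, attn_length = 3, 6, 8, 4
W_x = rng.standard_normal((input_size, num_units)) * 0.1
W_h = rng.standard_normal((num_units, num_units)) * 0.1

def rnn_step(x, h):
    # Vanilla tanh RNN step; the output and the new state are the same vector
    h = np.tanh(x @ W_x + h @ W_h)
    return h, h

cell = WindowAttentionCell(rnn_step, attn_length, num_units)
state = cell.zero_state(batch_size)
for t in range(10):
    x = rng.standard_normal((batch_size, input_size))
    out, state = cell(x, state)
print(out.shape, state[1].shape)  # (3, 16) (3, 4, 8)
```

Because the window lives inside the cell state, the existing rnn_decoder call can stay as it is; the attention window is carried along with the rest of the state from step to step.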