sherjilozair / char-rnn-tensorflow

Multi-layer Recurrent Neural Networks (LSTM, RNN) for character-level language models in Python using TensorFlow
MIT License
2.64k stars 960 forks

loop function #54

Open fujimotomh opened 7 years ago

fujimotomh commented 7 years ago

Hi, I was wondering if someone could confirm my suspicion. I think this code in model.py is never used with the way sampling is currently done.

        def loop(prev, _):
            # Project the previous cell output to vocabulary logits.
            prev = tf.matmul(prev, softmax_w) + softmax_b
            # Greedily pick the most likely symbol; stop_gradient keeps the
            # argmax out of the backward pass.
            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
            # Feed the embedding of that symbol back in as the next input.
            return tf.nn.embedding_lookup(embedding, prev_symbol)

        outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
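For context, sampling in this repo is driven from Python, one character per sess.run call, with the sampled character fed back in by hand. A condensed, hedged sketch of what sample() in model.py does (the names self.input_data, self.probs, self.final_state, chars, and vocab are taken from memory and may differ; priming with the seed string is omitted for brevity):

    import numpy as np
    import tensorflow as tf

    def sample(self, sess, chars, vocab, num=200, prime='The '):
        # Start from a zero hidden state for a batch of one.
        state = sess.run(self.cell.zero_state(1, tf.float32))
        char = prime[-1]
        ret = prime
        for _ in range(num):
            x = np.zeros((1, 1))
            x[0, 0] = vocab[char]  # encode the current character
            feed = {self.input_data: x, self.initial_state: state}
            probs, state = sess.run([self.probs, self.final_state], feed)
            # Sample the next character from the output distribution
            # (the real code uses a cumsum-based weighted pick).
            idx = np.random.choice(len(probs[0]), p=probs[0])
            char = chars[idx]  # fed back in Python, not inside the graph
            ret += char
        return ret

Because the next input is chosen in NumPy and fed back through feed_dict, a graph-level loop_function would be redundant even if rnn_decoder did invoke it.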

When I change it to this, training and sampling both seem to work fine:

        # def loop(prev, _):
        #     prev = tf.matmul(prev, softmax_w) + softmax_b
        #     prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
        #     return tf.nn.embedding_lookup(embedding, prev_symbol)

        outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, scope='rnnlm')

Looking at the source for seq2seq.rnn_decoder, if the input has length 1 (which it does when infer == True), the loop function is never used, because it is only applied from the second time step onward, once a previous output exists. Am I missing something? It almost looks like this code could replicate this paper.
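For reference, the decoding loop in the legacy TF 1.x tf.contrib.legacy_seq2seq.rnn_decoder looks roughly like this (paraphrased from memory, with the variable-scope handling trimmed):

    def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None):
        state = initial_state
        outputs, prev = [], None
        for i, inp in enumerate(decoder_inputs):
            # prev is None on the first iteration, so with a single
            # decoder input loop_function is never called.
            if loop_function is not None and prev is not None:
                inp = loop_function(prev, i)
            output, state = cell(inp, state)
            outputs.append(output)
            if loop_function is not None:
                prev = output
        return outputs, state

With decoder_inputs of length 1 the for loop runs exactly once, prev is still None on that sole pass, and loop_function is never invoked.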

Beitadoge commented 5 years ago

@fujimotomh I agree with you. I think the loop_function is not necessary; it is never used.