Kyubyong / transformer

A TensorFlow Implementation of the Transformer: Attention Is All You Need
Apache License 2.0

I found a small mistake in the Transformer model. #158

Open rhksgud92 opened 4 years ago

rhksgud92 commented 4 years ago

https://github.com/Kyubyong/transformer/blob/master/model.py

In this code, lines 176–181 of model.py, you are using Python's `==` on a TensorFlow tensor inside the model, which won't work:

```python
for _ in tqdm(range(self.hp.maxlen2)):
    logits, y_hat, y, sents2 = self.decode(ys, memory, src_masks, False)
    if tf.reduce_sum(y_hat, 1) == self.token2idx["<pad>"]: break

    _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
    ys = (_decoder_inputs, y, y_seqlen, sents2)
```

As a result, decoding never stops at the `<pad>` output and keeps iterating until maxlen2 is reached. This is a minor issue, but it makes the eval function slower.
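A minimal sketch of why the comparison silently fails (assuming TF 1.x graph mode, as this repo uses, and that `<pad>` maps to id 0): Python's `==` on a `tf.Tensor` falls back to object identity, so the condition is a plain `False` and the `break` can never fire.

```python
import tensorflow as tf  # TF 1.x

y_hat = tf.constant([[0, 0]])  # pretend the decoder just emitted <pad> (id 0)

# Plain Python `==` compares the Tensor object to an int -> always False,
# so `if ...: break` is dead code during graph construction.
print(tf.reduce_sum(y_hat, 1) == 0)          # False (a plain Python bool)

# tf.equal builds a boolean tensor that can actually be evaluated or branched on.
print(tf.equal(tf.reduce_sum(y_hat, 1), 0))  # Tensor("Equal:0", shape=(1,), dtype=bool)
```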

Using something like this instead would make the eval function faster:

```python
logits, y_hat, y, sents2 = tf.cond(
    tf.equal(y_hat[0][-1], self.token2idx["<pad>"]),
    lambda: (logits, y_hat, y, sents2),
    lambda: self.decode(ys, memory, src_masks, False))
```
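In context, a sketch of how this might slot into the eval loop (variable names follow the quoted code; the initial decode step before the loop is my addition so that `y_hat` exists before the first conditional, and this is an illustration, not the repo's actual patch):

```python
# One decode step up front so y_hat is defined before the first tf.cond.
logits, y_hat, y, sents2 = self.decode(ys, memory, src_masks, False)

for _ in tqdm(range(self.hp.maxlen2 - 1)):
    _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
    ys = (_decoder_inputs, y, y_seqlen, sents2)
    # Once the first sequence has emitted <pad>, tf.cond reuses the previous
    # outputs at run time instead of executing another decode subgraph.
    logits, y_hat, y, sents2 = tf.cond(
        tf.equal(y_hat[0][-1], self.token2idx["<pad>"]),
        lambda: (logits, y_hat, y, sents2),
        lambda: self.decode(ys, memory, src_masks, False))
```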

bozhenhhu commented 4 years ago

I wonder why you use `y_hat[0][-1]`. The first dimension of `y_hat` equals `self.hp.batch_size`, so why use only the first example to decide whether the whole batch has reached `<pad>` or not?

rhksgud92 commented 4 years ago

Sorry, it was supposed to be `tf.reduce_sum(y_hat, 1)`, not `y_hat[0][-1]`, since a Python `if` statement doesn't work on tensors in TensorFlow 1.

The idea is to stop the decode calculation once the sum of all elements is 0, i.e. every sequence in the batch has produced `<pad>`.
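As a sketch of that batch-wide condition (it relies on `<pad>` mapping to id 0, as the comment above implies; the name `finished` is illustrative):

```python
# y_hat has shape [batch_size, t] and holds token ids, which are >= 0.
# If self.token2idx["<pad>"] == 0, a total sum of 0 means every sequence
# in the batch has emitted nothing but <pad>.
finished = tf.equal(tf.reduce_sum(y_hat), 0)

logits, y_hat, y, sents2 = tf.cond(
    finished,
    lambda: (logits, y_hat, y, sents2),                 # keep previous outputs
    lambda: self.decode(ys, memory, src_masks, False))  # decode one more step
```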