Open rhksgud92 opened 4 years ago
I wonder why you use y_hat[0][-1]. The first dimension of y_hat equals self.hp.batch_size, so why do you use only the first example to decide whether the whole batch has reached 'pad'?
Sorry, it was supposed to be tf.reduce_sum(y_hat, 1), not y_hat[0][-1]. A plain if statement doesn't work in TensorFlow version 1, so the check is there to stop the decode computation when the sum of all elements is 0 (pad).
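To make the difference between the two checks concrete, here is a rough NumPy sketch (not the repository's code). It assumes that "<pad>" maps to index 0 and that token ids are non-negative, so a row sum of 0 means every token in that row is pad. The function names are hypothetical.

```python
import numpy as np

PAD_ID = 0  # assumption: "<pad>" is index 0 in token2idx


def first_example_is_pad(y_hat, pad_id=PAD_ID):
    """Check only the last token of the first example, like y_hat[0][-1]."""
    return bool(np.asarray(y_hat)[0][-1] == pad_id)


def whole_batch_is_pad(y_hat, pad_id=PAD_ID):
    """Check every example via per-row sums, mirroring tf.reduce_sum(y_hat, 1).

    Only valid when pad_id is 0 and all token ids are non-negative.
    """
    row_sums = np.sum(np.asarray(y_hat), axis=1)
    return bool(np.all(row_sums == pad_id))


# The first example is all pad, the second one is not.
mixed_batch = np.array([[0, 0, 0],
                        [3, 9, 4]])
# Every example is all pad.
padded_batch = np.zeros((2, 3), dtype=np.int32)
```

With `mixed_batch`, the first-example check would already report "done" while the batch-wide check would not, which is the concern raised in the question above.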
In https://github.com/Kyubyong/transformer/blob/master/model.py, lines 176 to 181 use "==" inside the TensorFlow model, which won't work.
for _ in tqdm(range(self.hp.maxlen2)):
    logits, y_hat, y, sents2 = self.decode(ys, memory, src_masks, False)
    if tf.reduce_sum(y_hat, 1) == self.token2idx["<pad>"]: break
    _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
    ys = (_decoder_inputs, y, y_seqlen, sents2)
As a result, decoding does not stop at the pad output but keeps iterating until maxlen is reached. This is a minor issue, but it makes the eval function slower.
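A small pure-Python sketch of why the break never fires. In TensorFlow 1.x graph mode, `==` on a tensor compares object identity rather than values, so comparing a tensor with an integer is always False. `SymbolicTensor` and `count_decode_steps` below are hypothetical stand-ins for illustration, not TensorFlow APIs.

```python
class SymbolicTensor:
    """Hypothetical stand-in for a graph-mode tensor: a named node with no value yet."""

    def __init__(self, name):
        self.name = name
    # No __eq__ is defined, so Python falls back to identity comparison,
    # which is how "==" behaves on tensors in TensorFlow 1.x.


def count_decode_steps(maxlen, pad_id=0):
    """Mimic the eval loop and count how many decode steps get built."""
    steps = 0
    for _ in range(maxlen):
        # Placeholder for tf.reduce_sum(y_hat, 1): a graph node, not a number.
        y_hat_sum = SymbolicTensor("reduce_sum")
        steps += 1
        if y_hat_sum == pad_id:  # always False: an object is never identical to an int
            break
    return steps
```

Under these assumptions `count_decode_steps(10)` builds all 10 steps, matching the behaviour described above.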
Using something like this instead would make the eval function faster:
logits, y_hat, y, sents2 = tf.cond(
    tf.equal(y_hat[0][-1], self.token2idx["<pad>"]),
    lambda: (logits, y_hat, y, sents2),
    lambda: self.decode(ys, memory, src_masks, False))