tensorflow / nmt

TensorFlow Neural Machine Translation Tutorial
Apache License 2.0

Using inference-style GreedyEmbeddingHelper instead of TrainingHelper during training #242

Closed: ghost closed this issue 6 years ago

ghost commented 6 years ago

Because of https://github.com/tensorflow/nmt/issues/241, I am trying to use an inference-style GreedyEmbeddingHelper instead of the standard TrainingHelper during training, evaluation, and inference (see model code snippet below).

This works most of the time, but certain batches crash with the following error:

```
tensorflow.python.framework.errors_impl.InvalidArgumentError: logits and labels must have the same first dimension, got logits shape [700,11] and labels shape [76200]
```

This probably stems from my "misuse" of GreedyEmbeddingHelper: for some batches it appears to produce predictions whose length doesn't match the targets passed to the softmax cross-entropy. Is there a viable way to avoid this error? I realize GreedyEmbeddingHelper is not ideal for training (there's a reason TrainingHelper exists), but hopefully the linked issue explains why I'm trying it.
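A toy, pure-Python simulation of what is probably happening. With a greedy helper, `dynamic_decode` stops as soon as every sequence in the batch has emitted EOS, so the decoded time dimension can be shorter than `maximum_iterations` (here `tf.reduce_max(target_lengths)`) and therefore shorter than `target_out`'s time dimension. The loss flattens both tensors to batch×time rows, which would explain first dimensions such as 700 vs 76200. One possible workaround, sketched below under stated assumptions, is to right-pad the logits along the time axis up to the target length before computing the loss; in the TF 1.x graph that would presumably be a `tf.pad` of `logits` on axis 1 by `tf.shape(target_out)[1] - tf.shape(logits)[1]` steps. The token ids, the toy decoder's stopping rule, and the helper names `greedy_decode`, `pad_time` and `flat_rows` are all hypothetical.

```python
EOS = 1    # hypothetical end-of-sequence id
VOCAB = 3  # hypothetical target vocabulary size


def greedy_decode(start_tokens, step_fn, eos, max_iters):
    """Toy stand-in for dynamic_decode with a greedy helper.

    Each step feeds back the previous argmax token. Decoding stops once every
    sequence has emitted `eos` or after `max_iters` steps, so the number of
    steps actually run can be smaller than max_iters.
    """
    batch = len(start_tokens)
    prev = list(start_tokens)
    finished = [False] * batch
    logits = [[] for _ in range(batch)]  # logits[b][t] is a list of VOCAB scores
    steps = 0
    while steps < max_iters and not all(finished):
        for b in range(batch):
            scores = step_fn(b, prev[b], steps)
            logits[b].append(scores)
            prev[b] = max(range(len(scores)), key=scores.__getitem__)  # greedy argmax
            finished[b] = finished[b] or prev[b] == eos
        steps += 1
    return logits, steps


def pad_time(logits, target_steps, vocab):
    """Right-pad every sequence with all-zero score rows up to target_steps."""
    return [seq + [[0.0] * vocab for _ in range(target_steps - len(seq))] for seq in logits]


def flat_rows(logits):
    """First dimension after flattening [batch, time, vocab] to [batch*time, vocab]."""
    return sum(len(seq) for seq in logits)


# Hypothetical toy decoder: sequence b scores EOS highest from step stop_at[b] on.
stop_at = [1, 2]


def toy_step(b, prev_token, t):
    return [0.0, 1.0, 0.0] if t >= stop_at[b] else [0.0, 0.0, 1.0]


max_target_len = 5  # plays the role of tf.reduce_max(target_lengths)
logits, steps = greedy_decode([0, 0], toy_step, EOS, max_target_len)
labels_rows = len(stop_at) * max_target_len
print(steps, flat_rows(logits), labels_rows)  # 3 6 10: mismatched first dimensions
padded = pad_time(logits, max_target_len, VOCAB)
print(flat_rows(padded), labels_rows)  # 10 10
```

All-zero padded scores correspond to a uniform softmax, so target positions the decoder never reached still contribute roughly log(vocab size) each to the masked loss instead of being silently dropped.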

```python
import tensorflow as tf  # TensorFlow 1.x: tf.contrib.seq2seq does not exist in TF 2.x
from tensorflow.python.layers import core as layers_core


def __init__(self, hparams, iterator, mode):
    tf.set_random_seed(hparams.graph_seed)
    source, target_in, target_out, source_lengths, target_lengths = iterator.get_next()
    true_batch_size = tf.size(source_lengths)

    # Look up embeddings
    embedding_encoder = tf.get_variable("embedding_encoder", [hparams.src_vsize, hparams.src_emsize])
    encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, source)
    embedding_decoder = tf.get_variable("embedding_decoder", [hparams.tgt_vsize, hparams.tgt_emsize])
    decoder_emb_inp = tf.nn.embedding_lookup(embedding_decoder, target_in)  # only used by TrainingHelper

    # Build and run encoder LSTM
    encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
    encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
        encoder_cell, encoder_emb_inp, sequence_length=source_lengths, dtype=tf.float32)

    # Build and run decoder LSTM with a Helper and an output projection layer
    decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
    projection_layer = layers_core.Dense(hparams.tgt_vsize, use_bias=False)

    # if mode == 'TRAIN' or mode == 'EVAL':  # then decode using TrainingHelper
    #     helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, sequence_length=target_lengths)
    # elif mode == 'INFER':  # then decode greedily
    #     helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding_decoder, tf.fill([true_batch_size], hparams.sos), hparams.eos)

    # ALWAYS USE INFERENCE-STYLE DECODER INSTEAD OF TRAININGHELPER
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding_decoder, tf.fill([true_batch_size], hparams.sos), hparams.eos)

    decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state, output_layer=projection_layer)
    # dynamic_decode returns (outputs, final_state, final_sequence_lengths)
    outputs, _, self.test = tf.contrib.seq2seq.dynamic_decode(
        decoder, maximum_iterations=tf.reduce_max(target_lengths))
    logits = outputs.rnn_output  # [batch, decoded_steps, tgt_vsize]

    # Compare strings with ==, not `is` (identity of str objects is not guaranteed)
    if mode == 'TRAIN' or mode == 'EVAL':  # then calculate loss
        crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_out, logits=logits)
        target_weights = tf.sequence_mask(target_lengths, maxlen=tf.shape(target_out)[1], dtype=logits.dtype)
        self.loss = tf.reduce_sum(crossent * target_weights) / tf.cast(true_batch_size, tf.float32)

    if mode == 'TRAIN':  # then calculate/clip gradients, then optimize model
        params = tf.trainable_variables()
        gradients = tf.gradients(self.loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, hparams.max_gradient_norm)

        optimizer = tf.train.AdamOptimizer(hparams.l_rate)
        self.update_step = optimizer.apply_gradients(zip(clipped_gradients, params))

    if mode == 'EVAL' or mode == 'INFER':  # then expose input/output tensors for printing
        self.src = source
        self.tgt = target_out
        self.preds = tf.argmax(logits, axis=2)

    # Designate a saver operation
    self.saver = tf.train.Saver(tf.global_variables())
```
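The loss in the TRAIN/EVAL branch masks padded target positions with `tf.sequence_mask` and divides the summed cross-entropy by the batch size (not by the number of real tokens). A plain-Python sketch of the same arithmetic on nested lists of shape [batch][time][vocab] may make the shape requirement visible: every target step `t` indexes `logits[b][t]`, so logits and labels must share the same time dimension. The function names are hypothetical and this is not TensorFlow's implementation.

```python
import math


def sparse_softmax_xent(scores, label):
    """-log(softmax(scores)[label]), using the log-sum-exp trick for stability."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[label]


def sequence_mask(lengths, maxlen):
    """1.0 for real target positions, 0.0 for padding (like tf.sequence_mask)."""
    return [[1.0 if t < n else 0.0 for t in range(maxlen)] for n in lengths]


def masked_loss(logits, labels, lengths):
    """Sum of masked per-token cross-entropy divided by the batch size.

    logits: [batch][time][vocab] scores; labels: [batch][time] ids.
    Raises IndexError if logits has fewer time steps than labels.
    """
    batch, maxlen = len(labels), len(labels[0])
    mask = sequence_mask(lengths, maxlen)
    total = sum(
        sparse_softmax_xent(logits[b][t], labels[b][t]) * mask[b][t]
        for b in range(batch)
        for t in range(maxlen)
    )
    return total / batch


# Example: uniform scores over a 4-token vocabulary, two padded targets.
logits = [[[0.0] * 4 for _ in range(3)] for _ in range(2)]
labels = [[2, 3, 0], [1, 2, 3]]
lengths = [2, 3]
loss = masked_loss(logits, labels, lengths)
print(round(loss, 4))  # 5 real tokens * log(4) / 2, about 3.4657
```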
ghost commented 6 years ago

Doing this is a bad idea and loses too much information during training.
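One plausible reading of "loses too much information" (an interpretation, not spelled out in the thread): TrainingHelper feeds the decoder the ground-truth previous token at every step (teacher forcing), whereas GreedyEmbeddingHelper feeds back the model's own argmax prediction. A single early mistake then shifts every later decoder input away from the gold context, while the labels the loss compares against stay the same. The sketch below only contrasts the two input sequences; the token ids and `toy_predict` are hypothetical.

```python
SOS, EOS = 0, 1  # hypothetical special-token ids


def teacher_forced_inputs(target_out):
    """TrainingHelper-style inputs: SOS, then the gold token from the previous step."""
    return [SOS] + target_out[:-1]


def free_running_inputs(predict, steps):
    """Greedy-style inputs: SOS, then the model's own previous prediction."""
    inputs, prev = [], SOS
    for _ in range(steps):
        inputs.append(prev)
        prev = predict(prev)
    return inputs


def toy_predict(prev):
    # Hypothetical model: gets the first token wrong (4 instead of 5), then emits prev + 1.
    return 4 if prev == SOS else prev + 1


target_out = [5, 6, 7, EOS]
forced_in = teacher_forced_inputs(target_out)
free_in = free_running_inputs(toy_predict, len(target_out))
print(forced_in)  # [0, 5, 6, 7]
print(free_in)    # [0, 4, 5, 6]
```

Under free running, step 1 asks the model to predict the gold token 6 from the input 4 rather than 5, and every later step inherits the same offset.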