tensorflow / nmt

TensorFlow Neural Machine Translation Tutorial
Apache License 2.0

How to get the encoder hidden states? #298

Open peterpan2018 opened 6 years ago

peterpan2018 commented 6 years ago

Hello,

I want to train the model and then use it as an auto-encoder, but I can't work out how to get the encoder hidden states at inference time, i.e., given a new sentence, print out the encoder hidden states for that sentence.

I have been struggling with this issue for a week. Any help would be much appreciated.

FallakAsad commented 6 years ago

Any updates on this issue?

FallakAsad commented 6 years ago

I was able to get the encoder states by storing the 'encoder_state' argument of the '_build_decoder' function as an attribute of the 'BaseModel' class (self.encoder_state) and then changing the 'BaseModel.infer' function at https://github.com/tensorflow/nmt/blob/master/nmt/model.py#L670 as follows:

Current definition of infer()

  def infer(self, sess):
    assert self.mode == tf.contrib.learn.ModeKeys.INFER
    output_tuple = InferOutputTuple(infer_logits=self.infer_logits,
                                    infer_summary=self.infer_summary,
                                    sample_id=self.sample_id,
                                    sample_words=self.sample_words)
    return sess.run(output_tuple)

Updated definition of infer()

  def infer(self, sess):
    assert self.mode == tf.contrib.learn.ModeKeys.INFER
    output_tuple = InferOutputTuple(infer_logits=self.infer_logits,
                                    infer_summary=self.infer_summary,
                                    sample_id=self.sample_id,
                                    sample_words=self.sample_words)
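    # Fetch the encoder states in the same run call; this assumes a
    # two-layer LSTM encoder whose state is a tuple of (c, h) pairs.
    result = sess.run([output_tuple,
                       self.encoder_state[0].c, self.encoder_state[0].h,
                       self.encoder_state[1].c, self.encoder_state[1].h])
    # result[1:] holds the encoder states; result[0] is the usual output.
    return result[0]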

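The result list now contains not only the output_tuple but also the encoder states, in result[1:]. Note that infer() still returns only result[0], so you need to print or return result[1:] as well to actually use the states.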
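
For anyone who wants the states back in a structured form, here is a minimal sketch of how the list returned by sess.run could be regrouped into one (c, h) pair per layer. The helper name split_infer_result is hypothetical (it is not part of nmt), and the namedtuples below are only stand-ins for the real InferOutputTuple and TensorFlow's LSTMStateTuple so the snippet runs on its own. It assumes the fetch order used in the updated infer() above.

```python
from collections import namedtuple

# Stand-ins so the sketch is self-contained; in nmt the real types come
# from model.py and TensorFlow. These definitions are assumptions.
InferOutputTuple = namedtuple(
    "InferOutputTuple",
    ["infer_logits", "infer_summary", "sample_id", "sample_words"])
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def split_infer_result(result, num_layers=2):
    """Hypothetical helper: split a sess.run result into outputs and states.

    `result` is assumed to be laid out as
    [output_tuple, c_0, h_0, c_1, h_1, ...], i.e. the fetch order used
    in the updated infer().
    """
    output_tuple = result[0]
    flat_states = result[1:]
    if len(flat_states) != 2 * num_layers:
        raise ValueError("expected %d state arrays, got %d"
                         % (2 * num_layers, len(flat_states)))
    # Regroup the flat [c_0, h_0, c_1, h_1, ...] list into one pair per layer.
    states = tuple(
        LSTMStateTuple(c=flat_states[2 * i], h=flat_states[2 * i + 1])
        for i in range(num_layers))
    return output_tuple, states
```

Inside infer() you could then end with "return split_infer_result(result)" instead of "return result[0]". That changes the return type, so every caller of infer() would have to unpack two values.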