yaserkl / RLSeq2Seq

Deep Reinforcement Learning For Sequence to Sequence Models
https://arxiv.org/abs/1805.09461
MIT License
767 stars, 160 forks

How to train an NMT model? #12

Closed kobenaxie closed 6 years ago

kobenaxie commented 6 years ago

For an NMT task, does this mean I just need to replace the data (for the summarization task) with data for an NMT task, like WMT14?

yaserkl commented 6 years ago

For NMT, you have to make sure you are not using the pointer generator, since it is not the right choice for training a machine translation model. What you end up with after deactivating the pointer generator is a simple attention-based seq2seq model, which is not going to give you state-of-the-art results on this task. Since this library has no module for the Transformer (which currently yields the state-of-the-art results in NMT), I would suggest working with a project like tensor2tensor and adding RL training the way I did in this project.
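
In practice, deactivating the pointer generator means passing the corresponding flag at launch. A minimal sketch, assuming the repository keeps the pointer_gen flag and the usual data/vocab flags from the pointer-generator codebase it builds on (all paths and the experiment name are placeholders):

python src/run_summarization.py \
  --mode=train \
  --data_path=/path/to/nmt/train_*.bin \
  --vocab_path=/path/to/nmt/vocab \
  --log_root=/path/to/logs \
  --exp_name=nmt_baseline \
  --pointer_gen=False

Note that the pipeline presumably still expects its binary source/target input format, so WMT14 text would have to be converted accordingly rather than dropped in as raw files.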

kobenaxie commented 6 years ago

OK, thank you for your suggestion ~

yashkumaratri commented 5 years ago

Did anyone set pointer generation to False and run into this?

max_size of vocab was specified as 50000; we now have 50000 words. Stopping reading.
Finished constructing vocabulary of 50000 total words. Last word added: deadlines,
creating model...
INFO:tensorflow:Building graph...
Writing word embedding metadata file to /home/paperspace/intradecoder-temporalattention-withpretraining/train/vocab_metadata.tsv...
WARNING:tensorflow:From /home/paperspace/.local/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py:430: calling reverse_sequence (from tensorflow.python.ops.array_ops) with seq_dim is deprecated and will be removed in a future version.
Instructions for updating:
seq_dim is deprecated, use seq_axis instead
WARNING:tensorflow:From /home/paperspace/.local/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py:454: calling reverse_sequence (from tensorflow.python.ops.array_ops) with batch_dim is deprecated and will be removed in a future version.
Instructions for updating:
batch_dim is deprecated, use batch_axis instead
Traceback (most recent call last):
  File "src/run_summarization.py", line 795, in <module>
    tf.app.run()
  File "/home/paperspace/.local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "src/run_summarization.py", line 792, in main
    seq2seq.main(unused_argv)
  File "src/run_summarization.py", line 745, in main
    self.setup_training()
  File "src/run_summarization.py", line 273, in setup_training
    self.model.build_graph() # build the graph
  File "/home/paperspace/RL/RLSeq2Seq/src/model.py", line 468, in build_graph
    self._add_seq2seq()
  File "/home/paperspace/RL/RLSeq2Seq/src/model.py", line 279, in _add_seq2seq
    self.sampling_rewards, self.greedy_rewards) = self._add_decoder(emb_dec_inputs, embedding)
  File "/home/paperspace/RL/RLSeq2Seq/src/model.py", line 194, in _add_decoder
    self._max_art_oovs,
AttributeError: 'SummarizationModel' object has no attribute '_max_art_oovs'
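
The AttributeError follows from the code path: _max_art_oovs is a pointer-generator-only placeholder, created only when pointer_gen is enabled, yet _add_decoder still references it unconditionally. A minimal sketch of a workaround, not the maintainer's official fix, assuming the flag is reachable as self._hps.pointer_gen as in the pointer-generator codebase:

# model.py, inside _add_decoder: only forward the pointer-generator
# tensors when the feature is enabled; they are never created otherwise.
if self._hps.pointer_gen:
    max_art_oovs = self._max_art_oovs                 # placeholder exists
    enc_batch_extend_vocab = self._enc_batch_extend_vocab
else:
    max_art_oovs = None                               # pointer gen is off:
    enc_batch_extend_vocab = None                     # nothing to copy from

# ...then pass max_art_oovs / enc_batch_extend_vocab into the decoder call
# instead of self._max_art_oovs / self._enc_batch_extend_vocab.

Whether the decoder accepts None for these arguments depends on its signature in this repo, so treat this as a starting point rather than a drop-in patch.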