hunkim / word-rnn-tensorflow

Multi-layer Recurrent Neural Networks (LSTM, RNN) for word-level language models in Python using TensorFlow.
MIT License

Beam search fixes #51

Closed: victorkwan closed this 7 years ago

victorkwan commented 7 years ago

Fixes beam search through a number of changes:

  1. Beam search as a pick option is no longer grouped with the other sampling methods. This was problematic for two reasons. First, beam search does not return a single word pick: it searches for the best beam of up to num words. Second, adding the prime text outside of the beamsearch() method meant that the prime text had no impact on the beam search scores. As such, I've separated beam search from the other pick options.
  2. We now pass a predict function to the BeamSearch class, along with the initial state and the prime text. The predict() function performs a single step of the RNN, returning the next-step probabilities and the updated state. Inside beamsearch(), we use predict() to advance each beam. We treat the state variable as more or less opaque, using it only as an input to and output of predict().
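The design in item 2 can be illustrated with a minimal, framework-free sketch. This is not the repository's code: it assumes a hypothetical `predict(token, state)` that feeds one token through the model and returns `(next-token probabilities, new state)`, and the `search` method, its `beam_width`/`num` parameters, and the toy bigram "model" are illustrative stand-ins. The prime text is fed inside the search so it conditions the state every beam starts from, and the state is only ever handed back to `predict`. For simplicity the sketch has no end-of-sequence handling, so it always returns exactly `num` tokens rather than "up to" `num`.

```python
import heapq
import math
from typing import Any, Callable, List, Sequence, Tuple

# Hypothetical signature: predict(token, state) runs one RNN step on `token`
# and returns (probabilities for the next token, new state).
PredictFn = Callable[[int, Any], Tuple[Sequence[float], Any]]


class BeamSearch:
    def __init__(self, predict: PredictFn, initial_state: Any, prime_labels: Sequence[int]):
        if not prime_labels:
            raise ValueError("need at least one prime token")
        self.predict = predict
        self.initial_state = initial_state
        self.prime_labels = list(prime_labels)

    def search(self, beam_width: int, num: int) -> List[int]:
        """Return the highest-scoring continuation of exactly `num` tokens."""
        # Feed the prime text here so it conditions every beam's starting state.
        state = self.initial_state
        for token in self.prime_labels:
            probs, state = self.predict(token, state)
        # Each beam: (cumulative log-probability, generated tokens, next-token probs, state).
        beams = [(0.0, [], probs, state)]
        for _ in range(num):
            candidates = []
            for score, tokens, beam_probs, beam_state in beams:
                for token, p in enumerate(beam_probs):
                    if p > 0.0:
                        candidates.append((score + math.log(p), tokens + [token], beam_state))
            # Keep only the beam_width best extensions across all beams.
            best = heapq.nlargest(beam_width, candidates, key=lambda c: c[0])
            # Advance each surviving beam by one step; the state stays opaque.
            beams = []
            for score, tokens, beam_state in best:
                next_probs, next_state = self.predict(tokens[-1], beam_state)
                beams.append((score, tokens, next_probs, next_state))
        return max(beams, key=lambda b: b[0])[1]


# Toy "model": a fixed bigram table over a 3-token vocabulary. The state is just
# a step counter standing in for the RNN's hidden state.
BIGRAM = [
    [0.1, 0.6, 0.3],
    [0.5, 0.1, 0.4],
    [0.3, 0.3, 0.4],
]


def toy_predict(token, state):
    return BIGRAM[token], state + 1


searcher = BeamSearch(toy_predict, initial_state=0, prime_labels=[0])
print(searcher.search(beam_width=2, num=3))  # → [1, 0, 1]
```

Because scores are summed log-probabilities, the prime text changes the result only through the state and probabilities it produces: priming with `[2]` instead of `[0]` makes `2` the best first token in this toy model.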
hunkim commented 7 years ago

Thanks for the fix. I also really appreciate the test case. Could you also add some examples (in README.md) with and without beam search?

victorkwan commented 7 years ago

Done. I've updated the README accordingly.

normanheckscher commented 7 years ago

I've merged this into the master1.0 branch, and it appears to work with the TensorFlow release candidate from earlier today. I'm a little hesitant to merge master1.0 into the master branch just yet, as there appears to be ongoing work on the seq2seq sections of TensorFlow in the TensorFlow master branch.