martin-gorner / tensorflow-rnn-shakespeare

Code from the "Tensorflow and deep learning - without a PhD, Part 2" session on Recurrent Neural Networks.
Apache License 2.0
534 stars 249 forks

Arg max vs. sample_from_probabilities #33

Closed: jammygrams closed 6 years ago

jammygrams commented 6 years ago

This isn't an issue but a question on model understanding - please let me know if I should raise this somewhere else.

When training, we input a string of characters (length SEQLEN), and predict the next character one at a time. Our prediction at each step is a softmax, i.e. the probabilities that the next character is each of the ALPHASIZE characters. During training, we take the arg max of this distribution, and then calculate accuracy by comparing that prediction with ground truth. However, accuracy plateaus at ~65%, and if we look at the predictions they're not fluent English (with 35% of characters being wrong), even after many epochs.
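To make the metric concrete, here is a minimal sketch of arg max accuracy in plain Python. The function names and the toy three-character alphabet are illustrative assumptions, not code from the repository.

```python
def argmax(probs):
    """Index of the largest probability in one softmax output."""
    return max(range(len(probs)), key=lambda i: probs[i])

def argmax_accuracy(prob_rows, targets):
    """Fraction of steps where the top-probability character equals the target."""
    correct = sum(1 for p, t in zip(prob_rows, targets) if argmax(p) == t)
    return correct / len(targets)

# Toy softmax outputs over a 3-character alphabet (made-up numbers).
prob_rows = [
    [0.7, 0.2, 0.1],
    [0.1, 0.6, 0.3],
    [0.4, 0.5, 0.1],  # target 0 gets 40% but still counts as fully wrong
]
targets = [0, 1, 0]
accuracy = argmax_accuracy(prob_rows, targets)  # 2 of 3 steps correct
```

Note that the third step is scored as a miss even though the model gave the correct character 40% probability; the metric discards that information.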

For inference, we start with a random input (some character) and generate characters one at a time, using each generated character as input to the next time step. Here, our prediction is not the arg max of the softmax distribution; instead we randomly choose from the 'top n' probabilities ('sample_from_probabilities' function in my_txtutils.py). Because of this, at inference time the same weights that couldn't produce fluent English in training (and only 65% accuracy) can produce completely fluent English words and phrases, even after a few batches. What's the reason for this difference?
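The 'top n' sampling step described here can be sketched as follows. This is a standard-library approximation of the idea, not the repository's 'sample_from_probabilities' implementation; the name `sample_top_n` and the example probabilities are assumptions.

```python
import random

def sample_top_n(probs, top_n, rng=random):
    """Sample an index from the top_n most probable entries of probs.

    Entries outside the top_n are never chosen; the rest are drawn in
    proportion to their probability (random.choices accepts relative weights).
    """
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_n]
    weights = [probs[i] for i in keep]
    return rng.choices(keep, weights=weights, k=1)[0]

# Made-up distribution over a 4-character alphabet.
probs = [0.5, 0.35, 0.1, 0.05]
rng = random.Random(0)  # seeded generator so runs are repeatable
samples = [sample_top_n(probs, top_n=2, rng=rng) for _ in range(1000)]
# Only indices 0 and 1 can appear; with top_n=1 this reduces to arg max.
```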

I thought the intention of 'sample_from_probabilities' is just to introduce randomness, so we can generate lots of different samples. However, arg max doesn't generate fluent English while 'sample_from_probabilities' does, so I'm confused how it does this.

Please let me know if I can clarify or if I've misunderstood anything.

martin-gorner commented 6 years ago

Hi

Yes, you are correct, argmax+accuracy is a very bad metric. Fortunately, it is not used anywhere during training other than for display, to give you an indication that something good is happening because some metric is going up. A better metric to use here would be "perplexity".
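For reference, perplexity is commonly defined as the exponential of the average negative log probability the model assigns to the true next character. A minimal sketch, assuming softmax outputs are available as plain lists (the function name is illustrative, and this is not code from the repository):

```python
import math

def perplexity(prob_rows, targets):
    """exp of the mean negative log-likelihood of the target characters."""
    nll = [-math.log(p[t]) for p, t in zip(prob_rows, targets)]
    return math.exp(sum(nll) / len(nll))

# A model that is uniform over 4 characters has perplexity of about 4:
# it is "as confused as" a fair choice among 4 options at every step.
uniform_rows = [[0.25, 0.25, 0.25, 0.25]] * 3
uniform_ppl = perplexity(uniform_rows, [0, 1, 2])
```

Unlike arg max accuracy, this rewards a model for putting 40% on the right character even when another character scores higher.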

During inference, "sampling from probabilities" is a critical step. If you just use argmax (top probability) you get a sequence of the most probable words in English. Something like "the a for the is the...". That is not what you want. Sampling from probabilities indeed introduces randomness, which opens up the vocabulary of words and generates much more English-sounding sentences.

An even better approach would be to use an algorithm like "beam search" to generate the most probable sequence of n characters instead of just the most probable next character. For example, after "for" the most probable next char is " " (space) because "for" is such a frequent word. However, based on prior context, the most probable sequence could be "form" or "fortran" or indeed "for_". Beam search takes multiple paths through the sequence of probabilities and selects the most probable sequences. In a way, sampling from probabilities implements a weak form of beam search, one that will consider the most probable sequences by chance: there is a chance it picks "form" rather than "for " even though the most probable char after "for" is " " (space). If " " (space) had 50% probability and "m" had 35%, sampling from probabilities could still pick the "m".
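The beam search idea can be sketched on a toy example. Everything here is an assumption for illustration: `toy_next_probs` is a hand-written lookup table standing in for the trained network, and its probabilities are made up (reusing the 50% / 35% figures above).

```python
import math

def toy_next_probs(text):
    """Hypothetical stand-in for the model: next-char probabilities per prefix."""
    table = {
        "for":  {" ": 0.5, "m": 0.35, "t": 0.15},
        "for ": {"a": 0.3, "t": 0.3, "y": 0.4},
        "form": {" ": 0.9, "s": 0.1},
        "fort": {"r": 1.0},
    }
    return table[text]

def beam_search(next_probs, start, steps, beam_width):
    """Keep the beam_width most probable continuations at every step."""
    beams = [(0.0, start)]  # (log probability of the continuation, text)
    for _ in range(steps):
        candidates = []
        for logp, text in beams:
            for ch, p in next_probs(text).items():
                if p > 0:
                    candidates.append((logp + math.log(p), text + ch))
        candidates.sort(reverse=True)  # highest log probability first
        beams = candidates[:beam_width]
    return beams

greedy = beam_search(toy_next_probs, "for", steps=2, beam_width=1)
wider = beam_search(toy_next_probs, "for", steps=2, beam_width=2)
greedy_text = greedy[0][1]  # commits to " " first, ends at "for y" (p = 0.5 * 0.4)
wider_text = wider[0][1]    # keeps "form" alive, ends at "form " (p = 0.35 * 0.9)
```

With a beam width of 1 the search is just arg max and ends on a two-character continuation with probability 0.20; a width of 2 keeps the "m" branch alive and finds one with probability 0.315.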

jammygrams commented 6 years ago

Thanks for the prompt answer - very helpful!