suragnair / seqGAN

A simplified PyTorch implementation of "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient." (Yu, Lantao, et al.)

Monte Carlo Rollouts in PyTorch? #4

Closed cbsudux closed 6 years ago

cbsudux commented 6 years ago

Hi,

I am planning to add MCTS to your implementation to improve the training of the Generator. Do you have any example implementations or suggestions?

Cheers

suragnair commented 6 years ago

Hi

You can refer to the original TensorFlow implementation by the author: https://github.com/LantaoYu/SeqGAN

Essentially, you'll have to change the sample function in the generator so that it generates conditional samples (i.e. given the first T tokens, it generates the remaining tokens). Then you'll have to change the train_generator_PG function in main.py: write a loop in which every iteration appends a new token to the sentence, performs rollouts from that prefix, and collects the reward for the new token. Good luck!
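To make the loop concrete, here is a rough, framework-agnostic sketch in plain Python. It is not code from this repo or from the original TensorFlow implementation. The names sample_next, complete, discriminator_score and generate_with_rewards, the constants VOCAB_SIZE and SEQ_LEN, and the toy scoring rule are all hypothetical stand-ins; in practice the generator and discriminator would be the PyTorch models and the sequences would be tensors.

```python
import random

VOCAB_SIZE = 10  # toy vocabulary size (assumption for illustration)
SEQ_LEN = 8      # fixed sentence length (assumption for illustration)


def sample_next(prefix, rng):
    """Hypothetical stand-in for one generator step given the prefix."""
    return rng.randrange(VOCAB_SIZE)


def complete(prefix, rng):
    """Conditional sample: keep the given prefix and generate the remaining tokens."""
    seq = list(prefix)
    while len(seq) < SEQ_LEN:
        seq.append(sample_next(seq, rng))
    return seq


def discriminator_score(seq):
    """Hypothetical stand-in for the discriminator: a score in [0, 1]."""
    return sum(1 for tok in seq if tok % 2 == 0) / len(seq)


def generate_with_rewards(num_rollouts, rng):
    """Build a sentence token by token and estimate a reward for each new token."""
    sentence, rewards = [], []
    for _ in range(SEQ_LEN):
        # Append a new token to the sentence.
        sentence.append(sample_next(sentence, rng))
        if len(sentence) == SEQ_LEN:
            # The sentence is complete, so score it directly.
            rewards.append(discriminator_score(sentence))
        else:
            # Roll out the unfinished sentence several times and average the scores.
            scores = [
                discriminator_score(complete(sentence, rng))
                for _ in range(num_rollouts)
            ]
            rewards.append(sum(scores) / num_rollouts)
    return sentence, rewards


rng = random.Random(0)
sentence, rewards = generate_with_rewards(num_rollouts=16, rng=rng)
```

In a real policy-gradient update, each per-token reward would presumably weight the log-probability of the corresponding token in the generator loss, instead of applying one sentence-level reward to every token.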