
seqGAN

A PyTorch implementation of "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient" (Yu, Lantao, et al.). The code is highly simplified, commented, and (hopefully) straightforward to understand. The policy gradient used is also much simpler than in the original work (https://github.com/LantaoYu/SeqGAN/) and does not involve rollouts: a single reward is assigned to the entire sentence (inspired by the examples in http://karpathy.github.io/2016/05/31/rl/).
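Concretely, the simplified update is plain REINFORCE with one sentence-level reward. A minimal sketch of such a loss (the function name and tensor shapes here are illustrative, not the repository's exact code):

def pg_loss(log_probs, reward):
    # log_probs: (batch, seq_len) log-probabilities of the sampled tokens
    # reward:    (batch,) a single discriminator score per whole sentence
    # Scale every token's log-probability by the one sentence reward and
    # minimize the negative expected reward (REINFORCE, no rollouts).
    return -(log_probs.sum(dim=1) * reward).mean()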

The architectures used differ from those in the original work. Specifically, a recurrent bidirectional GRU network is used as the discriminator.
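A minimal sketch of such a discriminator (the layer sizes and names are illustrative; the actual model in the code may differ):

import torch
import torch.nn as nn

class GRUDiscriminator(nn.Module):
    # Embed the token sequence, run it through a bidirectional GRU, and
    # map the final hidden states of both directions to P(real).
    def __init__(self, vocab_size, emb_dim=32, hidden_dim=64):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.gru = nn.GRU(emb_dim, hidden_dim, batch_first=True,
                          bidirectional=True)
        self.out = nn.Linear(2 * hidden_dim, 1)

    def forward(self, x):                     # x: (batch, seq_len) token ids
        _, h = self.gru(self.emb(x))          # h: (2, batch, hidden_dim)
        h = torch.cat([h[0], h[1]], dim=1)    # concatenate both directions
        return torch.sigmoid(self.out(h)).squeeze(1)  # (batch,) P(real)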

The code performs the experiment on synthetic data as described in the paper.
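In that experiment, a randomly initialized LSTM serves as an "oracle" that generates the training data, and the generator is judged by the negative log-likelihood of its samples under that oracle. A sketch of the metric, assuming hypothetical sample/log-prob helpers rather than the repository's exact API:

def oracle_nll(oracle, generator, n_samples=512):
    # Sample sentences from the generator and score them under the fixed
    # random oracle; lower NLL means the generator's distribution is
    # closer to the oracle's. Both .sample() and .log_prob() are assumed
    # helper methods here, not necessarily the actual interface.
    samples = generator.sample(n_samples)      # (n_samples, seq_len)
    return -oracle.log_prob(samples).mean().item()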

You are encouraged to raise any questions about how the code works as Issues.

To run the code:

python main.py

main.py should be your entry point into the code.
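At a high level, main.py pretrains the generator with MLE and then alternates generator and discriminator updates. An illustrative outline, with hypothetical method names and epoch counts standing in for the actual training functions:

def train(gen, disc, oracle, mle_epochs=100, adv_epochs=50):
    # 1. Pretrain the generator with maximum likelihood on oracle data.
    for _ in range(mle_epochs):
        gen.mle_step(oracle.sample_batch())
    # 2. Adversarial phase: a policy-gradient update for the generator,
    #    then retrain the discriminator on fresh real/fake samples.
    for _ in range(adv_epochs):
        gen.pg_step(disc)
        disc.train_on(oracle, gen)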

Hacks and Observations

Several hacks borrowed from https://github.com/soumith/ganhacks seem to have worked in this case.

Sample Learning Curve

Learning curve obtained after MLE training for 100 epochs followed by adversarial training. (Your results may vary!)

[learning curve plot]