ChenChengKuan / SeqGAN_tensorflow

SeqGAN tensorflow implementation
MIT License

SeqGAN_tensorflow

This code reproduces the results of the synthetic data experiments in "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient" (Yu et al.). It replaces the original tensor array implementation with higher-level TensorFlow APIs for better flexibility.

Introduction

The basic idea of SeqGAN is to treat the sequence generator as an agent in reinforcement learning. The generator is trained with the REINFORCE algorithm (Williams, 1992), and a discriminator is trained to provide the reward. To calculate the reward of a partially generated sequence, Monte Carlo sampling is used to roll out the unfinished sequence and estimate its reward.
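The rollout-based reward described above can be illustrated with a minimal, dependency-free sketch. It is not the code in this repository: `sample_next_token`, `toy_discriminator`, `rollout_rewards`, and `policy_gradient_loss` are hypothetical names, the generator is assumed to be a uniform policy, and the discriminator is a toy scoring function standing in for a trained network.

```python
import math
import random

VOCAB_SIZE = 5   # assumed toy vocabulary size
SEQ_LEN = 6      # assumed toy sequence length


def sample_next_token(prefix, rng):
    """Hypothetical generator policy: uniform over the vocabulary."""
    return rng.randrange(VOCAB_SIZE)


def toy_discriminator(sequence):
    """Hypothetical discriminator: the fraction of even tokens,
    standing in for the probability that the sequence is real."""
    return sum(1 for tok in sequence if tok % 2 == 0) / len(sequence)


def rollout_rewards(sequence, n_rollouts, rng):
    """Estimate one reward per time step of `sequence`.

    Each unfinished prefix is completed `n_rollouts` times with the
    generator policy, and the discriminator scores are averaged.
    The full sequence is scored directly, without rollout.
    """
    rewards = []
    for t in range(1, len(sequence) + 1):
        prefix = sequence[:t]
        if t == len(sequence):
            rewards.append(toy_discriminator(prefix))
            continue
        total = 0.0
        for _ in range(n_rollouts):
            completed = list(prefix)
            while len(completed) < len(sequence):
                completed.append(sample_next_token(completed, rng))
            total += toy_discriminator(completed)
        rewards.append(total / n_rollouts)
    return rewards


def policy_gradient_loss(log_probs, rewards):
    """REINFORCE-style surrogate loss: -sum_t log p(y_t) * reward_t."""
    return -sum(lp * r for lp, r in zip(log_probs, rewards))


rng = random.Random(0)
sequence = [sample_next_token([], rng) for _ in range(SEQ_LEN)]
rewards = rollout_rewards(sequence, n_rollouts=16, rng=rng)

# Under the assumed uniform policy every token has log-probability -log(V).
log_probs = [-math.log(VOCAB_SIZE)] * SEQ_LEN
loss = policy_gradient_loss(log_probs, rewards)
print(sequence, rewards, loss)
```

In a real setup the log-probabilities would come from the generator network and the loss would be minimized with respect to its parameters, with the rewards treated as constants.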

Some works based on the training method used in SeqGAN:

Prerequisites

Results

The output in experiment.log should be similar to the following, which is close to the result reported for the original implementation:

pre-training...
epoch:  0   nll:    10.1971
epoch:  5   nll:    9.4694
epoch:  10  nll:    9.2169
epoch:  15  nll:    9.17986
epoch:  20  nll:    9.16206
epoch:  25  nll:    9.1344
epoch:  30  nll:    9.12127
epoch:  35  nll:    9.0948
epoch:  40  nll:    9.10186
epoch:  45  nll:    9.10108
epoch:  50  nll:    9.0971
epoch:  55  nll:    9.11246
epoch:  60  nll:    9.1182
epoch:  65  nll:    9.10095
epoch:  70  nll:    9.09244
epoch:  75  nll:    9.08816
epoch:  80  nll:    9.10319
epoch:  85  nll:    9.08916
epoch:  90  nll:    9.08348
epoch:  95  nll:    9.09661
epoch:  100 nll:    9.10361
epoch:  105 nll:    9.11718
epoch:  110 nll:    9.10492
epoch:  115 nll:    9.1038
adversarial training...
epoch:  0   nll:    9.09558
epoch:  5   nll:    9.03083
epoch:  10  nll:    8.96725
epoch:  15  nll:    8.91415
epoch:  20  nll:    8.87554
epoch:  25  nll:    8.82305
epoch:  30  nll:    8.76805
epoch:  35  nll:    8.73597
epoch:  40  nll:    8.71933
epoch:  45  nll:    8.71653
epoch:  50  nll:    8.71746
epoch:  55  nll:    8.7036
epoch:  60  nll:    8.68666
epoch:  65  nll:    8.68931
epoch:  70  nll:    8.68588
epoch:  75  nll:    8.69977
epoch:  80  nll:    8.69636
epoch:  85  nll:    8.69916
epoch:  90  nll:    8.6969
epoch:  95  nll:    8.71021
epoch:  100 nll:    8.72561
epoch:  105 nll:    8.71369
epoch:  110 nll:    8.71723
epoch:  115 nll:    8.72388
epoch:  120 nll:    8.71293
epoch:  125 nll:    8.70667
epoch:  130 nll:    8.70341
epoch:  135 nll:    8.69929
epoch:  140 nll:    8.69793
epoch:  145 nll:    8.67705
epoch:  150 nll:    8.65372

Note: Parts of this code (dataloader, discriminator, target LSTM) are based on the original implementation by Lantao Yu. Many thanks to him for sharing his code.