
Conditional Sequence GANs

The project is for conditional sequence generation, e.g., a chit-chat chatbot.

The implemented training methods include cross-entropy minimization / maximum likelihood estimation (MLE) [2], REINFORCE [3], and several stepwise evaluation methods for sequential GANs: SeqGAN (with Monte Carlo rollout) [4], MaliGAN [5], REGS [6], the update method of MaskGAN [7], and StepGAN [1]. The difference in how these methods assign reward to tokens is sketched below.
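These methods differ mainly in credit assignment: REINFORCE and vanilla SeqGAN broadcast one sequence-level score to every generated token, while stepwise methods such as StepGAN give each prefix its own score. A minimal NumPy sketch of this difference (illustrative numbers and names only, not the repository's code):

    import numpy as np

    # Log-probabilities of six generated tokens y1..y6 under the generator
    # (illustrative values).
    log_probs = np.log(np.array([0.4, 0.7, 0.5, 0.9, 0.6, 0.8]))

    # (a) Sequence-level reward (REINFORCE / vanilla SeqGAN without rollout):
    # one discriminator score for the finished sequence, shared by all steps.
    seq_reward = 0.72
    loss_seq = -np.sum(log_probs * seq_reward)

    # (b) Stepwise rewards (StepGAN-style): the discriminator scores every
    # prefix, so each token receives its own credit assignment.
    step_rewards = np.array([0.9, 0.8, 0.3, 0.7, 0.6, 0.75])
    loss_step = -np.sum(log_probs * step_rewards)

    print(loss_seq, loss_step)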

The available testing (decoding) methods include argmax, softmax sampling, beam search, and maximum mutual information (MMI) [8]; a sketch of MMI re-ranking follows.
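MMI as proposed in [8] re-ranks candidate replies by log p(y|x) - λ log p(y) to penalize generic responses. A hypothetical sketch of that re-ranking step (the candidate tuples and mmi_rerank are made-up names, not this repository's API):

    def mmi_rerank(candidates, lam=0.5):
        """Pick the candidate maximizing log p(y|x) - lam * log p(y),
        the MMI-antiLM objective of [8]."""
        return max(candidates, key=lambda c: c[1] - lam * c[2])

    # Each candidate: (tokens, log p(y|x), log p(y)); illustrative values.
    candidates = [
        ("i don't know", -2.0, -1.0),    # generic reply: high LM probability
        ("sounds great", -2.5, -3.0),
        ("see you at noon", -2.8, -4.5),
    ]
    print(mmi_rerank(candidates))  # the generic reply is penalized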

The project is the implementation of our paper "Improving Conditional Sequence Generative Adversarial Networks by Stepwise Evaluation," IEEE/ACM TASLP, 2019 [1].

StepGAN Architecture

The StepGAN architecture: x is the input message and y1 to y6 are the inferred output tokens. The details are in the paper [1].
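In StepGAN, the discriminator reads the message and the generated prefix and outputs a score at every decoding step rather than only at the end. A toy NumPy sketch of such a stepwise scorer (random weights and made-up dimensions; see the paper [1] for the actual architecture):

    import numpy as np

    rng = np.random.default_rng(0)
    T, H = 6, 8                              # six output steps, hidden size 8

    W = rng.normal(size=(H, H))              # recurrent weights
    U = rng.normal(size=(H, H))              # token-embedding weights
    v = rng.normal(size=H)                   # scoring vector
    token_embs = rng.normal(size=(T, H))     # embeddings of y1..y6

    h = np.zeros(H)                          # could be initialized from x
    step_scores = []
    for t in range(T):
        h = np.tanh(W @ h + U @ token_embs[t])            # update on prefix y1..yt
        step_scores.append(1.0 / (1.0 + np.exp(-v @ h)))  # sigmoid score per step
    print(step_scores)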

Requirements

    TensorFlow

Usage

To run the experiments with default parameters:

$ bash run.sh <GPU_ID> <TEST_TYPE> <MODEL_TYPE> <TASK_NAME>

Four optional trailing arguments <LR> <LR_DECAY> <G_STEPS> <D_STEPS> override the learning rate, the learning rate decay, and the numbers of generator and discriminator training steps, as in the SeqGAN examples below. You can change all the hyper-parameters in run.sh; the available options are listed in args.py.

Examples

training MLE for Counting

$ bash run.sh 0 None MLE Counting

training REINFORCE for Counting

$ bash run.sh 0 None REINFORCE Counting

training SeqGAN for Counting with lr=5e-5 lrdecay=0.99 Gstep=1 Dstep=5

$ bash run.sh 0 None SeqGAN Counting 5e-5 0.99 1 5

testing MLE for Counting

$ bash run.sh 0 accuracy MLE Counting

testing SeqGAN for Counting with lr=5e-5 lrdecay=0.99 Gstep=1 Dstep=5

$ bash run.sh 0 accuracy SeqGAN Counting 5e-5 0.99 1 5

testing SeqGAN for OpenSubtitles

$ bash run.sh 0 realtime_argmax SeqGAN OpenSubtitles

Hyperparameters

For Counting, the suggested hyper-parameters are:

    generator learning rate = 5e-5 ~ 1e-4
    discriminator learning rate = 5e-5 ~ 1e-4
    learning rate decay = 0.99 ~ 1.00
    generator training steps = 1
    discriminator training steps = 5

For OpenSubtitles, the hyper-parameters we used are:

    generator learning rate = 1e-3 (or 5e-5 for weighted version)
    discriminator learning rate = 1e-3 (or 5e-5 for weighted version)
    learning rate decay = 1.00
    generator training steps = 1
    discriminator training steps = 5
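
For example, following the same positional-argument pattern as the examples above, a SeqGAN training run on OpenSubtitles with these values would look like:

$ bash run.sh 0 None SeqGAN OpenSubtitles 1e-3 1.00 1 5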

Program Structures

training / testing criteria:

models:

data processing:

References

  1. Tuan, Yi-Lin, and Hung-Yi Lee. "Improving Conditional Sequence Generative Adversarial Networks by Stepwise Evaluation." IEEE/ACM Transactions on Audio, Speech, and Language Processing 27.4 (2019): 788-798.
  2. Vinyals, Oriol, and Quoc Le. "A Neural Conversational Model." arXiv preprint arXiv:1506.05869 (2015).
  3. Ranzato, Marc'Aurelio, et al. "Sequence Level Training with Recurrent Neural Networks." ICLR 2016.
  4. Yu, Lantao, et al. "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient." AAAI 2017.
  5. Che, Tong, et al. "Maximum-Likelihood Augmented Discrete Generative Adversarial Networks." arXiv preprint arXiv:1702.07983 (2017).
  6. Li, Jiwei, et al. "Adversarial Learning for Neural Dialogue Generation." EMNLP 2017.
  7. Fedus, William, Ian Goodfellow, and Andrew M. Dai. "MaskGAN: Better Text Generation via Filling in the ______." ICLR 2018.
  8. Li, Jiwei, et al. "A Diversity-Promoting Objective Function for Neural Conversation Models." NAACL-HLT 2016.