This project is for conditional sequence generation, e.g., a chit-chat chatbot.
The implemented training methods include cross-entropy minimization / maximum likelihood estimation (MLE)[2], REINFORCE[3], and several stepwise evaluation methods for sequential GANs, e.g., SeqGAN (w/ Monte-Carlo)[4], MaliGAN[5], REGS[6], the update method of MaskGAN[7], and StepGAN[1].
The available testing (decoding) methods include argmax, softmax sampling, beam search, and maximum mutual information (MMI)[8].
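To illustrate the difference between the argmax and softmax-sampling decoding methods listed above, here is a minimal NumPy sketch for a single decoding step. It is not taken from the project's test.py; the function names and the example scores are hypothetical and only show the general idea.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

def pick_argmax(logits):
    """Greedy step: return the index of the highest-scoring token."""
    return int(np.argmax(logits))

def pick_sample(logits, rng):
    """Sampling step: draw a token index from the softmax distribution."""
    probs = softmax(logits)
    return int(rng.choice(len(probs), p=probs))

# Hypothetical scores for a 4-token vocabulary at one decoding step.
step_logits = np.array([2.0, 0.5, -1.0, 0.1])
rng = np.random.default_rng(0)

greedy_token = pick_argmax(step_logits)  # always the same token
sampled_tokens = [pick_sample(step_logits, rng) for _ in range(5)]  # may vary
print(greedy_token, sampled_tokens)
```

Argmax always returns the same token for the same scores, while sampling can return lower-scoring tokens, which tends to give more varied responses.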
The project is the implementation of our paper:
The StepGAN architecture. x is an input message and y1 to y6 are inferred output tokens. See the paper for details.
pip3 install numpy argparse
Before running run.sh, put the data in ./data/, e.g., ./data/opensubtitles/opensubtitles.txt. Otherwise, check the data path in the code.
To run the experiments with default parameters:
$ bash run.sh <GPU_ID> <TEST_TYPE> <MODEL_TYPE> <TASK_NAME>
<GPU_ID> selects the GPU to use when your machine has more than one. Use the command nvidia-smi to check the IDs of your GPUs. Otherwise, set <GPU_ID> to 0 (use the 1st GPU) or -1 (use the CPU).
<TEST_TYPE> is one of None, accuracy, realtime_argmax, realtime_sample, realtime_beam_search, and realtime_MMI, where None means the process is training and accuracy is only usable for the synthetic task (Counting).
<MODEL_TYPE> can be chosen from MLE, SeqGAN/MC-SeqGAN (the Monte-Carlo version), MaliGAN/MC-SeqGAN (the Monte-Carlo version), REGS, MaskGAN, and StepGAN/StepGAN-W. In the experiments in our paper, StepGAN or StepGAN-W can perform better.
<TASK_NAME> can be chosen from OpenSubtitles and Counting, or another dataset/task you prepare.
You can change all the hyper-parameters in run.sh. The options are listed in args.py.
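As a convenience, the argument rules above can be checked before launching a run. The following is a small sketch of a hypothetical helper, build_command, that is not part of this repository. It only assembles the command string from the four positional arguments (plus any optional trailing hyper-parameters) and enforces the two constraints stated above: the test type must be one of the listed values, and accuracy is only usable with Counting.

```python
import shlex

# Test types copied from the argument description above.
TEST_TYPES = [
    "None", "accuracy", "realtime_argmax",
    "realtime_sample", "realtime_beam_search", "realtime_MMI",
]

def build_command(gpu_id, test_type, model_type, task_name, extra=()):
    """Assemble a `bash run.sh ...` command line (hypothetical helper)."""
    if test_type not in TEST_TYPES:
        raise ValueError(f"unknown TEST_TYPE: {test_type}")
    if test_type == "accuracy" and task_name != "Counting":
        raise ValueError("accuracy is only usable for the Counting task")
    # The model type is passed through unchecked, since the full list of
    # names accepted by run.sh is not assumed here.
    parts = ["bash", "run.sh", str(gpu_id), test_type, model_type, task_name]
    parts.extend(str(value) for value in extra)
    return " ".join(shlex.quote(part) for part in parts)

print(build_command(0, "None", "MLE", "Counting"))
# prints: bash run.sh 0 None MLE Counting
```

Pass extra hyper-parameters as strings (for example "5e-5") so they appear in the command exactly as written.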
training MLE for Counting
$ bash run.sh 0 None MLE Counting
training REINFORCE for Counting
$ bash run.sh 0 None REINFORCE Counting
training SeqGAN for Counting with lr=5e-5 lrdecay=0.99 Gstep=1 Dstep=5
$ bash run.sh 0 None SeqGAN Counting 5e-5 0.99 1 5
testing MLE for Counting
$ bash run.sh 0 accuracy MLE Counting
testing SeqGAN for Counting with lr=5e-5 lrdecay=0.99 Gstep=1 Dstep=5
$ bash run.sh 0 accuracy SeqGAN Counting 5e-5 0.99 1 5
testing SeqGAN for OpenSubtitles
$ bash run.sh 0 realtime_argmax SeqGAN OpenSubtitles
For Counting, the suggested hyper-parameters are:
generator learning rate = 5e-5 to 1e-4
discriminator learning rate = 5e-5 to 1e-4
learning rate decay = 0.99 to 1.00
generator training steps = 1
discriminator training steps = 5
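To get a feel for what a decay factor in the suggested range does to the learning rate, here is a minimal sketch. It assumes the decay is applied multiplicatively each time a decay event occurs; how often the project applies it (per step, per epoch, or on some other condition) is not stated here, so num_decays is a hypothetical count and the function is not the project's own code.

```python
def decayed_learning_rate(initial_lr, decay, num_decays):
    """Learning rate after `num_decays` multiplicative decays (assumed rule)."""
    return initial_lr * decay ** num_decays

# Compare the two ends of the suggested decay range for lr = 5e-5.
for decay in (0.99, 1.00):
    for num_decays in (0, 10, 100):
        lr = decayed_learning_rate(5e-5, decay, num_decays)
        print(f"decay={decay:.2f} after {num_decays:3d} decays: lr={lr:.3e}")
```

With decay = 1.00 the learning rate stays constant, while decay = 0.99 shrinks it to roughly a third of its initial value after 100 decays.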
For OpenSubtitles, the used hyper-parameters are:
generator learning rate = 1e-3 (or 5e-5 for weighted version)
discriminator learning rate = 1e-3 (or 5e-5 for weighted version)
learning rate decay = 1.00
generator training steps = 1
discriminator training steps = 5
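The generator and discriminator training steps above describe how many updates each network receives per round. The sketch below shows one plausible alternating schedule under that reading. The function name and the order within a round (generator first) are assumptions for illustration, not a description of train_gan_n_rl.py.

```python
def training_schedule(num_rounds, g_steps=1, d_steps=5):
    """Yield 'G' or 'D' for each update, alternating in rounds (assumed order)."""
    for _ in range(num_rounds):
        for _ in range(g_steps):
            yield "G"  # one generator update
        for _ in range(d_steps):
            yield "D"  # one discriminator update

# Two rounds with the settings listed above: 1 generator step, 5 discriminator steps.
print("".join(training_schedule(2)))
# prints: GDDDDDGDDDDD
```

In a real training loop, each "G" or "D" would be replaced by a call to the corresponding update function.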
training / testing criteria:
main.py
args.py
train_mle.py
train_gan_n_rl.py
train_utils.py
test.py
models:
seq2seq_model_comp.py
critic.py
units.py
data processing:
data_utils.py