weilinie / RelGAN

Implementation of RelGAN: Relational Generative Adversarial Networks for Text Generation
MIT License
120 stars 31 forks

The model generates one repeated sentence #4

Closed williamSYSU closed 5 years ago

williamSYSU commented 5 years ago

Hi,

First of all, thanks for sharing your code! I'm impressed by your solid work. However, I found some issues when running your code under different hyper-parameters.

Main issues:

  1. Synthetic data: generating repeated sentences (extreme mode collapse) when gpre_lr=0.005, while your model behaves normally under gpre_lr=0.01.
  2. Image COCO data: generating repeated sentences when temperature=1.
  3. Reasons for generating repeated sentences:
    • Very sensitive to temperature?
    • Type of adversarial training loss?
    • Or others?
  4. Could you please explain why your model gives rise to the problem of generating repeated sentences? Or which modules do you think lead to this problem?
  5. Do you have any suggestions for solving the mode collapse problem?
  6. Previous works like SeqGAN and LeakGAN don't suffer from extreme mode collapse (generating only one repeated sentence) even when the temperature is set to 1. Do you think a temperature exceeding 1 is essential to your RelGAN?

Here’s my system environment.

>> Operating system:        Ubuntu 16.04.1
>> Program environment:     Virtualenv
>> Dependencies version:
    --Tensorflow    1.5.0
    --Numpy             1.14.5
    --Scipy             1.1.0
    --NLTK              3.4
    --tqdm              4.26.0
>> NVIDIA Graphics: TITAN Xp

Here are the problems I encountered when running your code.

  1. For the synthetic data experiment, I simply changed gpre_lr from 0.01 to 0.005. After 1620 adversarial training steps, the model only generates one repeated sentence, while it behaves normally under gpre_lr=0.01.

    • Hyper-parameters settings:
    job_id=0
    gpu_id=0
    architecture='rmc_vanilla'
    gantype='RSGAN'
    opt_type='adam'
    temperature = '2'
    d_lr = '1e-4'
    gadv_lr = '1e-4'
    mem_slots = '1'
    head_size = '256'
    num_head = '2'
    bs = '64'
    seed = '124'
    gpre_lr = '0.005'    # <<< only change this parameter
    hidden_dim = '32'
    seq_len = '20'
    dataset = 'oracle'
    gsteps = '1'
    dsteps = '5'
    gen_emb_dim = '32'
    dis_emb_dim = '64'
    num_rep = '64'
    sn = False
    decay = False
    adapt = 'exp'
    npre_epochs = '200'
    nadv_steps = '3000'
    ntest = '20'
    • Here's the content of the log file experiment-log-relgan.csv:
    pre_gen_epoch:0, g_pre_loss: 7.8529, time: 21, nll_oracle: 10.0373, nll_gen: 7.7054
    pre_gen_epoch:10, g_pre_loss: 6.3980, time: 96, nll_oracle: 9.1713, nll_gen: 7.0273
    pre_gen_epoch:20, g_pre_loss: 4.9447, time: 95, nll_oracle: 9.0030, nll_gen: 6.8668
    pre_gen_epoch:30, g_pre_loss: 4.0855, time: 97, nll_oracle: 8.9014, nll_gen: 6.6844
    pre_gen_epoch:40, g_pre_loss: 3.5856, time: 98, nll_oracle: 8.6974, nll_gen: 6.5024
    pre_gen_epoch:50, g_pre_loss: 3.4276, time: 97, nll_oracle: 8.7271, nll_gen: 6.4082
    pre_gen_epoch:60, g_pre_loss: 3.1279, time: 97, nll_oracle: 8.5847, nll_gen: 6.0566
    pre_gen_epoch:70, g_pre_loss: 2.7486, time: 97, nll_oracle: 8.5072, nll_gen: 6.1834
    pre_gen_epoch:80, g_pre_loss: 2.6509, time: 96, nll_oracle: 8.5039, nll_gen: 6.5375
    pre_gen_epoch:90, g_pre_loss: 2.3952, time: 98, nll_oracle: 8.4369, nll_gen: 6.5055
    pre_gen_epoch:100, g_pre_loss: 2.2010, time: 96, nll_oracle: 8.3912, nll_gen: 6.2355
    pre_gen_epoch:110, g_pre_loss: 2.2762, time: 97, nll_oracle: 8.3952, nll_gen: 5.8913
    pre_gen_epoch:120, g_pre_loss: 2.1142, time: 96, nll_oracle: 8.3305, nll_gen: 5.6797
    pre_gen_epoch:130, g_pre_loss: 1.9376, time: 98, nll_oracle: 8.2759, nll_gen: 5.6209
    pre_gen_epoch:140, g_pre_loss: 1.8160, time: 100, nll_oracle: 8.2619, nll_gen: 5.7824
    pre_gen_epoch:150, g_pre_loss: 2.0162, time: 99, nll_oracle: 8.3108, nll_gen: 5.8881
    pre_gen_epoch:160, g_pre_loss: 1.7529, time: 96, nll_oracle: 8.2722, nll_gen: 6.0981
    pre_gen_epoch:170, g_pre_loss: 1.7227, time: 94, nll_oracle: 8.2649, nll_gen: 6.0026
    pre_gen_epoch:180, g_pre_loss: 1.8253, time: 96, nll_oracle: 8.3251, nll_gen: 6.3667
    pre_gen_epoch:190, g_pre_loss: 1.6638, time: 95, nll_oracle: 8.2878, nll_gen: 6.2536
    adv_step: 0, nll_oracle: 8.2616, nll_gen: 5.7822
    adv_step: 20, nll_oracle: 8.2877, nll_gen: 5.7955
    adv_step: 40, nll_oracle: 8.2676, nll_gen: 5.8229
    adv_step: 60, nll_oracle: 8.2573, nll_gen: 5.8185
    adv_step: 80, nll_oracle: 8.2349, nll_gen: 5.8068
    adv_step: 100, nll_oracle: 8.2085, nll_gen: 5.8215
    .
    .
    .
    adv_step: 1500, nll_oracle: 7.6447, nll_gen: 6.6017
    adv_step: 1520, nll_oracle: 7.6238, nll_gen: 6.6425
    adv_step: 1540, nll_oracle: 7.6210, nll_gen: 6.6755
    adv_step: 1560, nll_oracle: 7.6144, nll_gen: 6.7065
    adv_step: 1580, nll_oracle: 7.6039, nll_gen: 6.7313
    adv_step: 1600, nll_oracle: 7.5998, nll_gen: 6.7501
    adv_step: 1620, nll_oracle: 7.5982, nll_gen: 6.7718
    adv_step: 1640, nll_oracle: 7.5978, nll_gen: 6.7881
    adv_step: 1660, nll_oracle: 7.5963, nll_gen: 6.8009
    adv_step: 1680, nll_oracle: 7.5951, nll_gen: 6.8147
    adv_step: 1700, nll_oracle: 7.5922, nll_gen: 6.8249
    .
    .
    .
    adv_step: 2860, nll_oracle: 8.1041, nll_gen: 7.1184
    adv_step: 2880, nll_oracle: 8.1520, nll_gen: 7.1035
    adv_step: 2900, nll_oracle: 8.1901, nll_gen: 7.1025
    adv_step: 2920, nll_oracle: 8.2444, nll_gen: 7.0996
    adv_step: 2940, nll_oracle: 8.2802, nll_gen: 7.0863
    adv_step: 2960, nll_oracle: 8.3114, nll_gen: 7.0657
    adv_step: 2980, nll_oracle: 8.2215, nll_gen: 7.0739
    • Here's part of the samples from adv_samples_01620.txt at adversarial step 1620 (only one repeated sentence is generated):
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
    3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
  2. For the Image COCO caption data, I simply changed the temperature from 100 to 1. The problem of generating repeated sentences arises again. Also, the model generates diverse sentences under temperature=100.

    • Hyper-parameters settings:
    job_id=0
    gpu_id=0
    architecture='rmc_vanilla'
    gantype='RSGAN'
    opt_type='adam'
    temperature = '1'    # <<< only change this parameter
    d_lr = '1e-4'
    gadv_lr = '1e-4'
    mem_slots = '1'
    head_size = '256'
    num_head = '2'
    bs = '64'
    seed = '124'
    gpre_lr = '0.01'
    hidden_dim = '32'
    seq_len = '20'
    dataset = 'oracle'
    gsteps = '1'
    dsteps = '5'
    gen_emb_dim = '32'
    dis_emb_dim = '64'
    num_rep = '64'
    sn = False
    decay = False
    adapt = 'exp'
    npre_epochs = '150'
    nadv_steps = '3000'
    ntest = '20'
    • Here's the content of the log file experiment-log-relgan.csv. To save time, I didn't compute the BLEU-3 score.
    pre_gen_epoch:0, g_pre_loss: 2.4170, nll_gen: 1.2337
    pre_gen_epoch:10, g_pre_loss: 0.7531, nll_gen: 0.7711
    pre_gen_epoch:20, g_pre_loss: 0.6419, nll_gen: 0.6634
    pre_gen_epoch:30, g_pre_loss: 0.5984, nll_gen: 0.6540
    pre_gen_epoch:40, g_pre_loss: 0.5766, nll_gen: 0.6359
    pre_gen_epoch:50, g_pre_loss: 0.5352, nll_gen: 0.6119
    pre_gen_epoch:60, g_pre_loss: 0.5106, nll_gen: 0.6105
    pre_gen_epoch:70, g_pre_loss: 0.4824, nll_gen: 0.6155
    pre_gen_epoch:80, g_pre_loss: 0.4585, nll_gen: 0.6444
    pre_gen_epoch:90, g_pre_loss: 0.4533, nll_gen: 0.6171
    pre_gen_epoch:100, g_pre_loss: 0.4309, nll_gen: 0.5942
    pre_gen_epoch:110, g_pre_loss: 0.4150, nll_gen: 0.6225
    pre_gen_epoch:120, g_pre_loss: 0.4064, nll_gen: 0.6629
    pre_gen_epoch:130, g_pre_loss: 0.4034, nll_gen: 0.6835
    pre_gen_epoch:140, g_pre_loss: 0.3912, nll_gen: 0.6581
    Start adversarial training...
    adv_step: 0, nll_gen: 0.6736
    adv_step: 20, nll_gen: 0.6762
    adv_step: 40, nll_gen: 0.6761
    adv_step: 60, nll_gen: 0.6766
    adv_step: 80, nll_gen: 0.6811
    adv_step: 100, nll_gen: 0.6894
    adv_step: 120, nll_gen: 0.6988
    adv_step: 140, nll_gen: 0.7120
    adv_step: 160, nll_gen: 0.7251
    adv_step: 180, nll_gen: 0.7389
    adv_step: 200, nll_gen: 0.7512
    adv_step: 220, nll_gen: 0.7607
    adv_step: 240, nll_gen: 0.7719
    adv_step: 260, nll_gen: 0.7800
    .
    .
    .
    adv_step: 1720, nll_gen: 0.7246
    adv_step: 1740, nll_gen: 0.7263
    adv_step: 1760, nll_gen: 0.7266
    adv_step: 1780, nll_gen: 0.7278
    adv_step: 1800, nll_gen: 0.7268
    adv_step: 1820, nll_gen: 0.7256
    adv_step: 1840, nll_gen: 0.7253
    adv_step: 1860, nll_gen: 0.7246
    adv_step: 1880, nll_gen: 0.7233
    adv_step: 1900, nll_gen: 0.7232
    adv_step: 1920, nll_gen: 0.7234
    adv_step: 1940, nll_gen: 0.7228
    adv_step: 1960, nll_gen: 0.7239
    adv_step: 1980, nll_gen: 0.7260
    • Here's part of the samples from adv_samples_01000.txt at adversarial step 1000 (only one repeated sentence is generated):
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
    a group of people are riding motorcycles on a city street . 
weilinie commented 5 years ago

Main issues:

  1. Synthetic data: generating repeated sentences (extreme mode collapse) when gpre_lr=0.005, while your model behaves normally under gpre_lr=0.01.

Right after pre-training, the nll_gen is also very high (~0.65 as you showed). So I think it’s because by setting gpre_lr=0.005 and npre_epochs=150, the pre-training may not be sufficient. A quick suggestion is to increase the npre_epochs (and/or increase the inverse temperature) to see if you can get good results.
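
For instance, in the same style as the settings above (the exact values below are only illustrative, not ones we have verified):

    npre_epochs = '400'    # more pre-training epochs
    temperature = '100'    # larger maximum inverse temperature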

  2. Image COCO data: generating repeated sentences when temperature=1.

That would be expected, since there is a tradeoff between sample quality and diversity, tuned by the maximum inverse temperature. In the extreme case where temperature=1, which means no temperature control at all, the model will suffer from severe mode collapse.
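
To make the temperature control concrete, here is a minimal sketch of an exponential ("exp") schedule for the inverse temperature over the adversarial steps; this is an assumption about what adapt='exp' does, not necessarily the exact code in this repository:

    def exp_inverse_temperature(step, nadv_steps, max_temperature):
        # beta grows exponentially from 1 (no temperature control) at step 0
        # up to max_temperature at the final adversarial step.
        return max_temperature ** (step / float(nadv_steps))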

  3. Reasons for generating repeated sentences: very sensitive to temperature? Type of adversarial training loss? Or something else?

For the gumbel-softmax trick, the temperature control plays a crucial role in the overall performance. So yes, it is mainly because the model is very sensitive to the temperature.

  4. Could you please explain why your model gives rise to the problem of generating repeated sentences? Or which modules do you think lead to this problem?

The gumbel-softmax trick.

  5. Do you have any suggestions for solving the mode collapse problem?

I would recommend improving the gumbel-softmax trick. In this work, we just use the vanilla version of the gumbel-softmax with some temperature control. I believe there is still a lot of room for improving this module. For example, REBAR would be the first thing to try.
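
For reference, here is a minimal sketch of the vanilla gumbel-softmax relaxation with an inverse temperature beta, written against TensorFlow 1.x; the function name is hypothetical and this is not the exact code in this repository:

    import tensorflow as tf

    def gumbel_softmax_sample(logits, beta, eps=1e-20):
        # Add Gumbel(0, 1) noise to the logits, then take a softmax scaled by
        # the inverse temperature beta: beta = 1 means no temperature control,
        # while a larger beta pushes the relaxed sample closer to one-hot.
        u = tf.random_uniform(tf.shape(logits), minval=0.0, maxval=1.0)
        gumbel_noise = -tf.log(-tf.log(u + eps) + eps)
        return tf.nn.softmax(beta * (logits + gumbel_noise))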

  6. Previous works like SeqGAN and LeakGAN don't suffer from extreme mode collapse (generating only one repeated sentence) even when the temperature is set to 1. Do you think a temperature exceeding 1 is essential to your RelGAN?

Yes, temperature>1 is essential for RelGAN from the temperature control perspective. SeqGAN and LeakGAN do not rely on temperature control as they apply REINFORCE, so they are less sensitive to the temperature.

williamSYSU commented 5 years ago

I am very grateful that you took the time to answer my questions in such detail and with such patience. Your answers do help me better understand RelGAN, but I am still confused about another thing.

According to my understanding of your code, the calculation of g_pretrain_loss and nll_gen is exactly the same, except for the parameters of the generator. In fact, nll_gen is the g_pretrain_loss calculated by the generator whose parameters have already been updated during pre-training. Therefore, the values of g_pretrain_loss and nll_gen should be close, and g_pretrain_loss should generally be larger than nll_gen from the training perspective. However, in the log file with gpre_lr=0.005 on synthetic data, g_pretrain_loss is already small (~1.7) while nll_gen is still large (~6.3). According to the above analysis, this situation should not happen.

Is there a mistake in my understanding or analysis? Or do you calculate g_pretrain_loss and nll_gen differently?

weilinie commented 5 years ago

I think the difference between g_pretrain_loss and nll_gen mainly lies in how each of them is calculated over mini-batches: for nll_gen, we fix the generator parameters and take the average of g_loss over all mini-batches (please refer to nll_loss() in Nll.py). For g_pretrain_loss, however, the generator first adapts to each mini-batch, and we then take the average of the per-batch g_loss over all mini-batches (please refer to pre_train_epoch() in utils.py). This explains why g_pretrain_loss is lower than nll_gen.
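
If it helps, here is a minimal sketch of the two averaging schemes; the helper names are hypothetical and this is not the actual code in Nll.py or utils.py:

    def nll_gen_style(compute_loss, batches):
        # Generator parameters stay fixed: every mini-batch is evaluated
        # with the same, final generator.
        losses = [compute_loss(batch) for batch in batches]
        return sum(losses) / len(losses)

    def pretrain_style(compute_loss_and_update, batches):
        # Each mini-batch loss is recorded while the generator is also being
        # updated on that batch, so the average comes from a generator that
        # keeps adapting to the training data, which yields a lower value.
        losses = [compute_loss_and_update(batch) for batch in batches]
        return sum(losses) / len(losses)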

williamSYSU commented 5 years ago

Thank you again for your answers and code :)