CR-Gjx / LeakGAN

Code for the paper "Long Text Generation via Adversarial Training with Leaked Information" (AAAI 2018): text generation using GANs and hierarchical reinforcement learning.
https://arxiv.org/abs/1709.08624

Rewards don't change while training #17

Closed. Seraphli closed this issue 6 years ago.

Seraphli commented 6 years ago

When I run the code, I print the mean value of the rewards. Strangely, the mean reward doesn't change during training. Here is the code snippet I used; I just added a print below line 285:

samples = leakgan.generate(sess, 1.0, 1)
rewards = get_reward(leakgan, discriminator, sess, samples, 4, dis_dropout_keep_prob)
print('rewards: ', np.mean(rewards))  # the added line: mean reward of this batch
feed = {leakgan.x: samples, leakgan.reward: rewards, leakgan.drop_out: 1.0}
_, _, g_loss, w_loss = sess.run(
    [leakgan.manager_updates, leakgan.worker_updates, leakgan.goal_loss, leakgan.worker_loss],
    feed_dict=feed)
print('total_batch: ', total_batch, "  ", g_loss, "  ", w_loss)

The output is here:

(64, ?, 1720)
(?, ?, 1720)
2018-08-28 18:00:47.803815: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2018-08-28 18:00:47.955611: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1356] Found device 0 with properties: 
name: GeForce GTX 1080 major: 6 minor: 1 memoryClockRate(GHz): 1.8225
pciBusID: 0000:03:00.0
totalMemory: 7.92GiB freeMemory: 5.43GiB
2018-08-28 18:00:47.955644: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1435] Adding visible gpu devices: 0
2018-08-28 18:00:48.257207: I tensorflow/core/common_runtime/gpu/gpu_device.cc:923] Device interconnect StreamExecutor with strength 1 edge matrix:
2018-08-28 18:00:48.257245: I tensorflow/core/common_runtime/gpu/gpu_device.cc:929]      0 
2018-08-28 18:00:48.257253: I tensorflow/core/common_runtime/gpu/gpu_device.cc:942] 0:   N 
2018-08-28 18:00:48.257465: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 4057 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080, pci bus id: 0000:03:00.0, compute capability: 6.1)
[[1395 2108 1587 ... 4713 4964  369]
 [3043 2382 2235 ... 1873   40 3757]
 [1811  411 4354 ...  670  492 3540]
 ...
 [4757 2083 4780 ... 2464 1251 1335]
 [ 571  679 2516 ... 3131 1198 2000]
 [1581  985  414 ... 3967 1530  983]]
('epoch:', 0, '  ')
ERROR:tensorflow:Couldn't match files for checkpoint ./ckpts/leakgan_pre
None
Start pre-training discriminator...
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 9.732324)
('Groud-Truth:', 5.7501173)
('pre-train epoch ', 5, 'test_loss ', 9.205654)
('Groud-Truth:', 5.751481)
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 8.669848)
('Groud-Truth:', 5.7524385)
('pre-train epoch ', 5, 'test_loss ', 8.304642)
('Groud-Truth:', 5.7588925)
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 8.217238)
('Groud-Truth:', 5.7520146)
('pre-train epoch ', 5, 'test_loss ', 8.068525)
('Groud-Truth:', 5.7564244)
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 8.088984)
('Groud-Truth:', 5.7408185)
('pre-train epoch ', 5, 'test_loss ', 8.123215)
('Groud-Truth:', 5.7433057)
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 8.125778)
('Groud-Truth:', 5.7547755)
('pre-train epoch ', 5, 'test_loss ', 8.148893)
('Groud-Truth:', 5.7508097)
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 8.230867)
('Groud-Truth:', 5.75031)
('pre-train epoch ', 5, 'test_loss ', 8.225234)
('Groud-Truth:', 5.7536063)
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 8.287518)
('Groud-Truth:', 5.753323)
('pre-train epoch ', 5, 'test_loss ', 8.347004)
('Groud-Truth:', 5.7483764)
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 8.36616)
('Groud-Truth:', 5.761182)
('pre-train epoch ', 5, 'test_loss ', 8.400379)
('Groud-Truth:', 5.7486343)
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 8.394658)
('Groud-Truth:', 5.739507)
('pre-train epoch ', 5, 'test_loss ', 8.400157)
('Groud-Truth:', 5.749277)
Start pre-training...
('pre-train epoch ', 0, 'test_loss ', 8.370322)
('Groud-Truth:', 5.7427726)
('pre-train epoch ', 5, 'test_loss ', 8.384528)
('Groud-Truth:', 5.7542734)
#########################################################################
Start Adversarial Training...
('rewards: ', 0.12695181503855285)
('total_batch: ', 0, '  ', -0.08002976, '  ', 2.572021)
('total_batch: ', 0, 'test_loss: ', 8.316683)
('Groud-Truth:', 5.7433925)
('rewards: ', 0.12695181503855285)
('total_batch: ', 1, '  ', -0.078070164, '  ', 2.4859025)
('rewards: ', 0.12695181503855285)
('total_batch: ', 2, '  ', -0.076734476, '  ', 2.4410596)
('rewards: ', 0.12695181503855285)
('total_batch: ', 3, '  ', -0.07695893, '  ', 2.448696)
('rewards: ', 0.12695181503855285)
('total_batch: ', 4, '  ', -0.075939886, '  ', 2.4610584)
('rewards: ', 0.12695181503855285)
('total_batch: ', 5, '  ', -0.07431904, '  ', 2.3714132)
('total_batch: ', 5, 'test_loss: ', 8.068953)
('Groud-Truth:', 5.7565346)
('rewards: ', 0.12695181503855285)
('total_batch: ', 6, '  ', -0.07445905, '  ', 2.353417)
('rewards: ', 0.12695181503855285)
('total_batch: ', 7, '  ', -0.0741186, '  ', 2.3683317)
('rewards: ', 0.12695181503855285)
('total_batch: ', 8, '  ', -0.07297791, '  ', 2.277438)
('rewards: ', 0.12695181503855285)
('total_batch: ', 9, '  ', -0.072162904, '  ', 2.1770558)
('rewards: ', 0.12695181503855285)
('total_batch: ', 10, '  ', -0.072940074, '  ', 2.2332137)
('total_batch: ', 10, 'test_loss: ', 7.715553)
('Groud-Truth:', 5.749097)
('rewards: ', 0.12695181503855285)
('total_batch: ', 11, '  ', -0.072791696, '  ', 2.2014272)
('rewards: ', 0.12695181503855285)
('total_batch: ', 12, '  ', -0.071583025, '  ', 2.1169)
('rewards: ', 0.12695181503855285)
('total_batch: ', 13, '  ', -0.07056489, '  ', 2.1018846)
('rewards: ', 0.12695181503855285)
('total_batch: ', 14, '  ', -0.06847201, '  ', 1.9985498)
('rewards: ', 0.12695181503855285)
('total_batch: ', 15, '  ', -0.06597161, '  ', 1.9304808)
('total_batch: ', 15, 'test_loss: ', 7.7085085)
('Groud-Truth:', 5.7584386)
('rewards: ', 0.12695181503855285)
('total_batch: ', 16, '  ', -0.066821866, '  ', 1.8718865)
('rewards: ', 0.12695181503855285)
('total_batch: ', 17, '  ', -0.0675878, '  ', 1.90795)
('rewards: ', 0.12695181503855285)
('total_batch: ', 18, '  ', -0.06887313, '  ', 2.014595)
('rewards: ', 0.12695181503855285)
('total_batch: ', 19, '  ', -0.06716676, '  ', 1.9232063)
('rewards: ', 0.12695181503855285)
('total_batch: ', 20, '  ', -0.06863812, '  ', 1.9305614)
('total_batch: ', 20, 'test_loss: ', 7.3769546)
('Groud-Truth:', 5.7537217)
('rewards: ', 0.12695181503855285)
('total_batch: ', 21, '  ', -0.067585774, '  ', 1.8494338)
('rewards: ', 0.12695181503855285)
('total_batch: ', 22, '  ', -0.06702501, '  ', 1.8630661)
('rewards: ', 0.12695181503855288)
('total_batch: ', 23, '  ', -0.06764726, '  ', 1.8600069)
('rewards: ', 0.12695181503855285)
('total_batch: ', 24, '  ', -0.06668289, '  ', 1.8348721)
('rewards: ', 0.12695181503855285)
('total_batch: ', 25, '  ', -0.06396266, '  ', 1.6727743)
('total_batch: ', 25, 'test_loss: ', 7.459133)
('Groud-Truth:', 5.7533226)
('rewards: ', 0.12695181503855285)
('total_batch: ', 26, '  ', -0.063976176, '  ', 1.785656)
('rewards: ', 0.12695181503855285)
('total_batch: ', 27, '  ', -0.06395789, '  ', 1.832536)
('rewards: ', 0.12695181503855285)
('total_batch: ', 28, '  ', -0.061984994, '  ', 1.6773611)
('rewards: ', 0.12695181503855285)
('total_batch: ', 29, '  ', -0.062187918, '  ', 1.7238611)
('rewards: ', 0.12695181503855285)
('total_batch: ', 30, '  ', -0.061822247, '  ', 1.6979691)
('total_batch: ', 30, 'test_loss: ', 7.2630606)
('Groud-Truth:', 5.752266)
('rewards: ', 0.12695181503855285)
('total_batch: ', 31, '  ', -0.062000763, '  ', 1.6922305)
('rewards: ', 0.12695181503855285)
('total_batch: ', 32, '  ', -0.060837973, '  ', 1.6040627)
('rewards: ', 0.12695181503855285)
('total_batch: ', 33, '  ', -0.06359503, '  ', 1.6852735)
('rewards: ', 0.12695181503855285)
('total_batch: ', 34, '  ', -0.06405149, '  ', 1.6879333)
('rewards: ', 0.12695181503855285)
('total_batch: ', 35, '  ', -0.06216966, '  ', 1.6790202)
('total_batch: ', 35, 'test_loss: ', 7.3644423)
('Groud-Truth:', 5.7475533)
('rewards: ', 0.12695181503855285)
('total_batch: ', 36, '  ', -0.062462743, '  ', 1.6497834)
('rewards: ', 0.12695181503855285)
('total_batch: ', 37, '  ', -0.0619183, '  ', 1.6184376)
('rewards: ', 0.12695181503855285)
('total_batch: ', 38, '  ', -0.060982812, '  ', 1.6155812)
('rewards: ', 0.12695181503855285)
('total_batch: ', 39, '  ', -0.06237963, '  ', 1.6360941)
('rewards: ', 0.12695181503855285)
('total_batch: ', 40, '  ', -0.06056885, '  ', 1.5517352)
('total_batch: ', 40, 'test_loss: ', 7.1818643)
('Groud-Truth:', 5.7501273)
('rewards: ', 0.12695181503855285)
('total_batch: ', 41, '  ', -0.059358098, '  ', 1.5096384)
('rewards: ', 0.12695181503855285)
('total_batch: ', 42, '  ', -0.06264005, '  ', 1.6303447)
('rewards: ', 0.12695181503855288)
('total_batch: ', 43, '  ', -0.057175227, '  ', 1.4948953)
('rewards: ', 0.12695181503855285)
('total_batch: ', 44, '  ', -0.05770105, '  ', 1.4612815)
('rewards: ', 0.12695181503855285)
('total_batch: ', 45, '  ', -0.05525694, '  ', 1.5415224)
('total_batch: ', 45, 'test_loss: ', 7.37787)
('Groud-Truth:', 5.765032)
('rewards: ', 0.12695181503855285)
('total_batch: ', 46, '  ', -0.05773246, '  ', 1.5037141)
('rewards: ', 0.12695181503855285)
('total_batch: ', 47, '  ', -0.05908689, '  ', 1.5346103)
('rewards: ', 0.12695181503855285)
('total_batch: ', 48, '  ', -0.06146999, '  ', 1.5455922)
('rewards: ', 0.12695181503855285)
('total_batch: ', 49, '  ', -0.059052933, '  ', 1.454091)
('rewards: ', 0.12695181503855285)
('total_batch: ', 50, '  ', -0.060394067, '  ', 1.424764)
('total_batch: ', 50, 'test_loss: ', 6.8873477)
('Groud-Truth:', 5.7444544)
('rewards: ', 0.12695181503855285)
('total_batch: ', 51, '  ', -0.062466055, '  ', 1.4687254)
('rewards: ', 0.12695181503855285)
('total_batch: ', 52, '  ', -0.060713746, '  ', 1.4554237)
('rewards: ', 0.12695181503855285)
('total_batch: ', 53, '  ', -0.06052899, '  ', 1.3958138)
('rewards: ', 0.12695181503855285)
('total_batch: ', 54, '  ', -0.0619392, '  ', 1.4287583)
('rewards: ', 0.12695181503855285)
('total_batch: ', 55, '  ', -0.0621605, '  ', 1.4351918)
('total_batch: ', 55, 'test_loss: ', 6.9499106)
('Groud-Truth:', 5.7546964)
('rewards: ', 0.12695181503855285)
('total_batch: ', 56, '  ', -0.062033392, '  ', 1.5352631)
('rewards: ', 0.12695181503855285)
('total_batch: ', 57, '  ', -0.059066363, '  ', 1.3914597)
('rewards: ', 0.12695181503855285)
('total_batch: ', 58, '  ', -0.061237067, '  ', 1.5043005)
('rewards: ', 0.12695181503855285)
('total_batch: ', 59, '  ', -0.06429994, '  ', 1.5481747)
('rewards: ', 0.12695181503855285)
('total_batch: ', 60, '  ', -0.06365656, '  ', 1.4762405)
('total_batch: ', 60, 'test_loss: ', 7.0070457)
('Groud-Truth:', 5.74922)
('rewards: ', 0.12695181503855285)
('total_batch: ', 61, '  ', -0.063503675, '  ', 1.5692905)
('rewards: ', 0.12695181503855285)
('total_batch: ', 62, '  ', -0.06233556, '  ', 1.520729)
('rewards: ', 0.12695181503855285)
('total_batch: ', 63, '  ', -0.062474646, '  ', 1.5496463)
('rewards: ', 0.12695181503855285)
('total_batch: ', 64, '  ', -0.061706282, '  ', 1.5579293)
('rewards: ', 0.12695181503855285)
('total_batch: ', 65, '  ', -0.05994014, '  ', 1.4769189)
('total_batch: ', 65, 'test_loss: ', 7.2139378)
('Groud-Truth:', 5.746189)
('rewards: ', 0.12695181503855285)
('total_batch: ', 66, '  ', -0.059200168, '  ', 1.496091)
('rewards: ', 0.12695181503855285)
('total_batch: ', 67, '  ', -0.061548878, '  ', 1.482251)
('rewards: ', 0.12695181503855285)
('total_batch: ', 68, '  ', -0.060286995, '  ', 1.3487338)
('rewards: ', 0.12695181503855285)
('total_batch: ', 69, '  ', -0.06037658, '  ', 1.450922)
('rewards: ', 0.12695181503855285)
('total_batch: ', 70, '  ', -0.060411155, '  ', 1.4245926)
('total_batch: ', 70, 'test_loss: ', 6.9322367)
('Groud-Truth:', 5.766845)
('rewards: ', 0.12695181503855285)
('total_batch: ', 71, '  ', -0.061025728, '  ', 1.4254792)
('rewards: ', 0.12695181503855285)
('total_batch: ', 72, '  ', -0.06147579, '  ', 1.5197109)
('rewards: ', 0.12695181503855288)
('total_batch: ', 73, '  ', -0.061548937, '  ', 1.4503106)
('rewards: ', 0.12695181503855285)
('total_batch: ', 74, '  ', -0.06145633, '  ', 1.4222437)
('rewards: ', 0.12695181503855285)
('total_batch: ', 75, '  ', -0.060437728, '  ', 1.4054346)
('total_batch: ', 75, 'test_loss: ', 6.9886317)
('Groud-Truth:', 5.7475686)
('rewards: ', 0.12695181503855285)
('total_batch: ', 76, '  ', -0.050933838, '  ', 1.2105244)
('rewards: ', 0.12695181503855285)
('total_batch: ', 77, '  ', -0.051188022, '  ', 1.2495371)
('rewards: ', 0.12695181503855285)
('total_batch: ', 78, '  ', -0.05532831, '  ', 1.2443336)
('rewards: ', 0.12695181503855285)
('total_batch: ', 79, '  ', -0.053034663, '  ', 1.177541)
('rewards: ', 0.12695181503855285)
('total_batch: ', 80, '  ', -0.05738743, '  ', 1.3191526)
('total_batch: ', 80, 'test_loss: ', 6.9316854)
('Groud-Truth:', 5.7573166)
('rewards: ', 0.12695181503855285)
('total_batch: ', 81, '  ', -0.057465132, '  ', 1.259404)
('rewards: ', 0.12695181503855285)
('total_batch: ', 82, '  ', -0.060087766, '  ', 1.3599504)
('rewards: ', 0.12695181503855285)
('total_batch: ', 83, '  ', -0.061464902, '  ', 1.3595842)
('rewards: ', 0.12695181503855285)
('total_batch: ', 84, '  ', -0.06157413, '  ', 1.3710985)
('rewards: ', 0.12695181503855285)
('total_batch: ', 85, '  ', -0.055594403, '  ', 1.2769436)
('total_batch: ', 85, 'test_loss: ', 6.83688)
('Groud-Truth:', 5.758091)
('rewards: ', 0.12695181503855285)
('total_batch: ', 86, '  ', -0.05882445, '  ', 1.3580688)
('rewards: ', 0.12695181503855285)
('total_batch: ', 87, '  ', -0.06124729, '  ', 1.4289982)
('rewards: ', 0.12695181503855285)
('total_batch: ', 88, '  ', -0.06239751, '  ', 1.3973529)
('rewards: ', 0.12695181503855285)
('total_batch: ', 89, '  ', -0.062367536, '  ', 1.3817501)
('rewards: ', 0.12695181503855285)
('total_batch: ', 90, '  ', -0.055310644, '  ', 1.3466204)
('total_batch: ', 90, 'test_loss: ', 7.1743336)
('Groud-Truth:', 5.7502723)
('rewards: ', 0.12695181503855285)
('total_batch: ', 91, '  ', -0.058700252, '  ', 1.3156176)
('rewards: ', 0.12695181503855285)
('total_batch: ', 92, '  ', -0.057320654, '  ', 1.2937081)
('rewards: ', 0.12695181503855285)
('total_batch: ', 93, '  ', -0.0583059, '  ', 1.3408545)
('rewards: ', 0.12695181503855285)
('total_batch: ', 94, '  ', -0.058293354, '  ', 1.3830068)
('rewards: ', 0.12695181503855285)
('total_batch: ', 95, '  ', -0.05834304, '  ', 1.2609603)
('total_batch: ', 95, 'test_loss: ', 7.2239175)
('Groud-Truth:', 5.7552676)
('rewards: ', 0.12695181503855285)
('total_batch: ', 96, '  ', -0.059853982, '  ', 1.2576574)
('rewards: ', 0.12695181503855285)
('total_batch: ', 97, '  ', -0.062316217, '  ', 1.2934114)
('rewards: ', 0.12695181503855285)
('total_batch: ', 98, '  ', -0.062229063, '  ', 1.2857419)
('rewards: ', 0.12695181503855285)
('total_batch: ', 99, '  ', -0.06375175, '  ', 1.3294318)
('rewards: ', 0.12695181503855285)
('total_batch: ', 100, '  ', -0.062345475, '  ', 1.366978)
('total_batch: ', 100, 'test_loss: ', 7.2802305)
('Groud-Truth:', 5.7551565)
('rewards: ', 0.12695181503855285)
('total_batch: ', 101, '  ', -0.06322038, '  ', 1.4014984)
('rewards: ', 0.12695181503855285)
('total_batch: ', 102, '  ', -0.06465603, '  ', 1.4416918)
('rewards: ', 0.12695181503855285)
('total_batch: ', 103, '  ', -0.06457795, '  ', 1.4671445)
('rewards: ', 0.12695181503855285)
('total_batch: ', 104, '  ', -0.06226616, '  ', 1.3307029)
('rewards: ', 0.12695181503855285)
('total_batch: ', 105, '  ', -0.060927063, '  ', 1.3127974)
('total_batch: ', 105, 'test_loss: ', 7.2767563)
('Groud-Truth:', 5.7502513)
('rewards: ', 0.12695181503855285)
('total_batch: ', 106, '  ', -0.063435964, '  ', 1.3460172)
('rewards: ', 0.12695181503855285)
('total_batch: ', 107, '  ', -0.06490766, '  ', 1.406265)
('rewards: ', 0.12695181503855285)
('total_batch: ', 108, '  ', -0.066277735, '  ', 1.4620061)
('rewards: ', 0.12695181503855285)
('total_batch: ', 109, '  ', -0.06379251, '  ', 1.4881936)
('rewards: ', 0.12695181503855285)
('total_batch: ', 110, '  ', -0.06467301, '  ', 1.475858)
('total_batch: ', 110, 'test_loss: ', 7.504943)
('Groud-Truth:', 5.739826)
('rewards: ', 0.12695181503855285)
('total_batch: ', 111, '  ', -0.06350755, '  ', 1.4781302)
('rewards: ', 0.12695181503855285)
('total_batch: ', 112, '  ', -0.065257765, '  ', 1.4885843)
('rewards: ', 0.12695181503855285)
('total_batch: ', 113, '  ', -0.066707864, '  ', 1.3226095)
('rewards: ', 0.12695181503855285)
('total_batch: ', 114, '  ', -0.06552046, '  ', 1.4152005)
('rewards: ', 0.12695181503855285)
('total_batch: ', 115, '  ', -0.06892277, '  ', 1.5064849)
('total_batch: ', 115, 'test_loss: ', 7.555312)
('Groud-Truth:', 5.7560487)
Seraphli commented 6 years ago

I also ran the SeqGAN code. In the SeqGAN experiment, the rewards do change.

LeeJuly30 commented 6 years ago

It's because of the Bootstrapped Rescaled Activation trick.

Seraphli commented 6 years ago

@LeeJuly30 Could you explain in more detail? I don't understand why the Bootstrapped Rescaled Activation trick causes this, or how the model can train at all with fixed rewards.

LeeJuly30 commented 6 years ago

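From the paper:

"For each timestep t, we rescale the t-th column vector R^t via

$$R^t_i = \sigma\left(\delta \cdot \left(0.5 - \frac{\operatorname{rank}(i)}{B}\right)\right)$$

where rank(i) denotes the i-th element's high-to-low ranking in this column vector." Here δ is a hyperparameter controlling the smoothness of the rescaling and σ is an activation function.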

If you look at the formula, you will see that the rescaled reward depends only on the batch size B and the sample's rank, so the expectation and variance of the rewards within a mini-batch never change. As the paper puts it, "By doing this, the rescale activation serves as a value stabilizer that is helpful for algorithms that are sensitive in numerical variance."
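To see why the printed mean stays fixed, here is a minimal NumPy sketch of the rescaling as the paper's formula describes it, with σ taken to be the sigmoid. It is not the repository's own implementation: the function name `rescale_rewards`, the default δ = 12.0, and the random "raw score" batches are assumptions chosen for illustration. Under this formula every column always contains the same set of values {σ(δ(0.5 − r/B)) : r = 1..B}; the discriminator scores only decide which sample receives which value. Summing those values in a different order can perturb the last floating-point digit, which may be why a few log lines above show ...288 instead of ...285.

```python
import numpy as np


def rescale_rewards(raw, delta=12.0):
    """Rank-based rescaling of a (B, T) reward matrix, one column (timestep) at a time.

    Implements R^t_i = sigmoid(delta * (0.5 - rank(i) / B)), where rank(i) is the
    1-based high-to-low rank of sample i within column t. Ties are broken arbitrarily.
    """
    raw = np.asarray(raw, dtype=np.float64)
    batch_size = raw.shape[0]
    # argsort(-raw) lists sample indices from highest to lowest score per column;
    # a second argsort converts that ordering into each sample's 0-based rank.
    ranks = np.argsort(np.argsort(-raw, axis=0), axis=0) + 1
    return 1.0 / (1.0 + np.exp(-delta * (0.5 - ranks / batch_size)))


rng = np.random.default_rng(0)
# Two hypothetical batches of raw discriminator scores at very different levels.
batch_a = rng.uniform(0.0, 0.1, size=(64, 4))
batch_b = rng.uniform(0.4, 0.9, size=(64, 4))
ra, rb = rescale_rewards(batch_a), rescale_rewards(batch_b)

print("mean:", np.mean(ra), np.mean(rb))  # same value, up to float rounding
print("std: ", np.std(ra), np.std(rb))    # same value, up to float rounding
print("identical matrices:", np.array_equal(ra, rb))  # False: per-sample rewards differ
```

Because only the ordering depends on the discriminator, under this formula `np.mean(rewards)` after rescaling carries no information about training progress; the raw scores before rescaling are the more informative quantity to log.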

Seraphli commented 6 years ago

Regarding "within a mini-batch": @LeeJuly30, the run above spans many mini-batches (total_batch goes up to 115), and the rewards didn't change at all. So when do they change?

LeeJuly30 commented 6 years ago

The individual rewards do change; it is the expectation and variance of the rewards that don't.

Seraphli commented 6 years ago

@LeeJuly30 Thanks for the tips. I found that the individual rewards do change, and the magnitude of the rewards before rescaling changes a lot.