igul222 / improved_wgan_training

Code for reproducing experiments in "Improved Training of Wasserstein GANs"
MIT License
2.35k stars 668 forks

Does the ideal gradient penalty loss point exist? #76

Closed li-zemin closed 6 years ago

li-zemin commented 6 years ago

Hello, I trained a WGAN-GP on the 3D voxel dataset, and my loss code is:

########## Gradient penalty calculations ##############
G_fake_ = tf.reshape(G_fake, shape=[args.batchsize, -1])
real_models_ = tf.reshape(real_models, shape=[args.batchsize, -1])
# One interpolation coefficient per sample, shared by all voxels of that sample
alpha = tf.random_uniform(shape=[args.batchsize, 1], minval=0., maxval=1.)
difference = G_fake_ - real_models_
# Broadcasting [batch, 1] * [batch, N] replaces the per-sample Python loop
interpolates_ = real_models_ + alpha * difference
interpolates = tf.reshape(interpolates_, shape=[args.batchsize, output_size, output_size, output_size])
gradients = tf.gradients(
    discriminator(voxel_input=interpolates,  output_size=output_size, batch_size=args.batchsize, improved=True, is_train=False, reuse=True)[1],
    [interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2, 3]))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)

########### Loss calculations #########################
w = 0.85
mse = tf.reduce_mean(tf.reduce_mean(w * real_models * tf.log(G_fake + 1e-8) + (1 - w) * (1 - real_models) * tf.log(1-G_fake + 1e-8), reduction_indices=[1, 2, 3]))
recon_loss = -mse
d_r_loss = -tf.reduce_mean(D_legit)
d_f_loss = tf.reduce_mean(D_fake)
d_gp_loss = 10. * gradient_penalty
d_loss = d_r_loss + d_f_loss + d_gp_loss  # discriminator loss
g_loss = -tf.reduce_mean(D_fake) * 0.5 + recon_loss * 1000  # generator loss
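
The penalty term defined in the code above can be evaluated by hand on toy gradients. The following is a minimal sketch in plain Python, not part of the training code: `gradient_penalty` and the sample gradient vectors are hypothetical and chosen only for illustration. It computes `lam * mean((slope - 1)^2)`, the same quantity as `d_gp_loss`, so the value can be related to the per-sample gradient norms.

```python
import math

def gradient_penalty(gradients, lam=10.0):
    """Return lam * mean((||g|| - 1)^2) over per-sample gradient vectors."""
    # Per-sample L2 norm, the analogue of `slopes` in the TensorFlow code.
    slopes = [math.sqrt(sum(x * x for x in g)) for g in gradients]
    return lam * sum((s - 1.0) ** 2 for s in slopes) / len(slopes)

# Hypothetical toy batches: every gradient has norm 1, or every one has norm 2.
unit_norm = [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0]]
double_norm = [[2.0, 0.0, 0.0], [0.0, 0.0, 2.0]]

print(gradient_penalty(unit_norm))    # prints 0.0
print(gradient_penalty(double_norm))  # prints 10.0
```

With `lam = 10`, the term is 0 when every gradient norm equals 1, and it equals 10 when every norm is off by exactly 1 from that target.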

I train the generator once and the discriminator five times per iteration. After 20k iterations, the distance between the fake loss and the real loss is stable at a fixed value, about 30. d_loss, which consists of d_f_loss, d_r_loss, and the gradient penalty loss, is decreasing very slowly, but the distance between the fake loss and the real loss barely changes; only the gradient penalty loss is decreasing, to less than 10.

Since the gradient penalty uses lambda = 10 and the ideal gradient norm is 1, I think the desired gradient penalty loss is 10. So although d_loss is decreasing, I think the generator is not being updated, because d_r_loss + d_f_loss is stable. What do you think? Should I stop training early, or train the generator more times per iteration? Here is part of my training log.

Epoch: [36/1500] [1109/1606] time: 2.7238, d_loss: -27.6084, d_r: 223.7334, d_f: -260.7666, d_gp: 9.4248, g_loss: 168.3651, r_loss: 0.0365
Epoch: [36/1500] [1110/1606] time: 1.5762, d_loss: -26.0944, d_r: 233.6555, d_f: -270.5552, d_gp: 10.8053, g_loss: 168.3651, r_loss: 0.0449
Epoch: [36/1500] [1111/1606] time: 1.5797, d_loss: -27.1066, d_r: 214.9756, d_f: -251.3403, d_gp: 9.2582, g_loss: 168.3651, r_loss: 0.0469
Epoch: [36/1500] [1112/1606] time: 1.5777, d_loss: -28.6665, d_r: 231.7237, d_f: -270.2897, d_gp: 9.8996, g_loss: 168.3651, r_loss: 0.0438
Epoch: [36/1500] [1113/1606] time: 1.5862, d_loss: -30.7233, d_r: 242.5401, d_f: -284.5739, d_gp: 11.3105, g_loss: 168.3651, r_loss: 0.0441
Epoch: [36/1500] [1114/1606] time: 2.7297, d_loss: -27.7861, d_r: 242.9819, d_f: -280.5803, d_gp: 9.8123, g_loss: 179.5107, r_loss: 0.0444
Epoch: [36/1500] [1115/1606] time: 1.5874, d_loss: -30.1238, d_r: 235.1404, d_f: -275.1939, d_gp: 9.9298, g_loss: 179.5107, r_loss: 0.0410
Epoch: [36/1500] [1116/1606] time: 1.5807, d_loss: -30.6379, d_r: 258.5479, d_f: -301.9879, d_gp: 12.8022, g_loss: 179.5107, r_loss: 0.0438
Epoch: [36/1500] [1117/1606] time: 1.5841, d_loss: -30.1702, d_r: 234.5535, d_f: -275.1239, d_gp: 10.4002, g_loss: 179.5107, r_loss: 0.0424
Epoch: [36/1500] [1118/1606] time: 1.5880, d_loss: -31.2223, d_r: 246.4716, d_f: -289.9604, d_gp: 12.2665, g_loss: 179.5107, r_loss: 0.0442
Epoch: [36/1500] [1119/1606] time: 2.7268, d_loss: -29.5199, d_r: 260.3143, d_f: -300.7540, d_gp: 10.9198, g_loss: 178.7275, r_loss: 0.041

Epoch: [165/1500] [1411/1606] time: 1.7041, d_loss: -16.3198, d_r: 175.1243, d_f: -197.0403, d_gp: 5.5962, g_loss: 126.1464, r_loss: 0.0260
Epoch: [165/1500] [1412/1606] time: 1.6986, d_loss: -16.4496, d_r: 167.0591, d_f: -187.4827, d_gp: 3.9740, g_loss: 126.1464, r_loss: 0.0261
Epoch: [165/1500] [1413/1606] time: 1.6954, d_loss: -17.3041, d_r: 177.2479, d_f: -198.3689, d_gp: 3.8169, g_loss: 126.1464, r_loss: 0.0265
Epoch: [165/1500] [1414/1606] time: 1.6957, d_loss: -17.7293, d_r: 188.3765, d_f: -211.0421, d_gp: 4.9362, g_loss: 126.1464, r_loss: 0.0268
Epoch: [165/1500] [1415/1606] time: 2.9074, d_loss: -17.7915, d_r: 179.3470, d_f: -202.2335, d_gp: 5.0950, g_loss: 125.2281, r_loss: 0.0274
Epoch: [165/1500] [1416/1606] time: 1.6955, d_loss: -18.2220, d_r: 170.6599, d_f: -193.3871, d_gp: 4.5052, g_loss: 125.2281, r_loss: 0.0251
Epoch: [165/1500] [1417/1606] time: 1.7053, d_loss: -19.9087, d_r: 185.1930, d_f: -209.7985, d_gp: 4.6968, g_loss: 125.2281, r_loss: 0.0248
Epoch: [165/1500] [1418/1606] time: 1.7191, d_loss: -18.3079, d_r: 191.1158, d_f: -215.6472, d_gp: 6.2235, g_loss: 125.2281, r_loss: 0.0262
Epoch: [165/1500] [1419/1606] time: 1.7017, d_loss: -18.4072, d_r: 179.7151, d_f: -202.2856, d_gp: 4.1633, g_loss: 125.2281, r_loss: 0.0234
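
The relation between the logged terms can be checked directly from lines like the ones above. This is a minimal sketch that assumes the log format shown here; `parse_line` and `critic_gap` are hypothetical helper names, not part of the repository. It parses one line and reports `d_r + d_f`, the part of `d_loss` that excludes the gradient penalty.

```python
import re

# Matches "name: value" pairs such as "d_loss: -27.6084" in a log line.
FIELD = re.compile(r"(\w+): (-?\d+\.\d+)")

def parse_line(line):
    """Return a dict of the float fields found in one training-log line."""
    return {name: float(value) for name, value in FIELD.findall(line)}

def critic_gap(fields):
    """Return d_r + d_f, i.e. d_loss without the gradient penalty term."""
    return fields["d_r"] + fields["d_f"]

line = ("Epoch: [36/1500] [1109/1606] time: 2.7238, d_loss: -27.6084, "
        "d_r: 223.7334, d_f: -260.7666, d_gp: 9.4248, g_loss: 168.3651, "
        "r_loss: 0.0365")
fields = parse_line(line)
gap = critic_gap(fields)

print(round(gap, 4))                   # d_r + d_f
print(round(gap + fields["d_gp"], 4))  # should reproduce the logged d_loss
```

For the first logged line this gives -37.0332, and adding `d_gp` reproduces the logged `d_loss` of -27.6084. Applying the same helper to every line gives the trend of `d_r + d_f` separately from the penalty.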