dragen1860 / TensorFlow-2.x-Tutorials

TensorFlow 2.x version's Tutorials and Examples, including CNN, RNN, GAN, Auto-Encoders, FasterRCNN, GPT, BERT examples, etc. Introductory example code and hands-on tutorials for TF 2.0.

Two questions about the WGAN-gp source code #45

Open donpromax opened 4 years ago

donpromax commented 4 years ago

While reading the source code I noticed a couple of small problems.

  1. wgan_train.py still applies a sigmoid followed by a cross-entropy loss, but WGAN should use the Discriminator's raw output logits directly as the loss:
import tensorflow as tf

def d_loss_fn(generator, discriminator, batch_z, real_image):
    fake_image = generator(batch_z, training=True)
    d_fake_score = discriminator(fake_image, training=True)
    d_real_score = discriminator(real_image, training=True)

    loss = tf.reduce_mean(d_fake_score - d_real_score)
    # lambda = 10
    gp = gradient_penalty(discriminator, real_image, fake_image) * 10.

    loss = loss + gp
    return loss, gp

def g_loss_fn(generator, discriminator, batch_z):
    fake_image = generator(batch_z, training=True)
    d_fake_logits = discriminator(fake_image, training=True)
    # loss = celoss_ones(d_fake_logits)
    loss = -tf.reduce_mean(d_fake_logits)
    return loss
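A quick numeric aside on why the raw logits matter (a minimal pure-Python sketch, not from the repo): with sigmoid + cross-entropy, the gradient with respect to the logit vanishes once the critic's output saturates, whereas the linear Wasserstein term keeps a constant gradient.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# Gradient of -log(sigmoid(x)) (cross-entropy toward label 1) w.r.t. the
# logit x is sigmoid(x) - 1, which vanishes as x grows large.
def bce_grad(x):
    return sigmoid(x) - 1.0

# Gradient of the linear Wasserstein term -x w.r.t. x is a constant -1,
# so the critic never saturates.
def wass_grad(x):
    return -1.0

for x in [0.0, 5.0, 20.0]:
    print(f"logit={x:5.1f}  bce grad={bce_grad(x):+.6f}  wasserstein grad={wass_grad(x):+.1f}")
```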

2. After switching the loss to the raw logits as WGAN requires, I found training would not converge. After repeated checking, the cause was a problem in the gradient penalty computation; replacing the original function with the version below makes training converge nicely:

def gradient_penalty(discriminator, real_image, fake_image):
    batchsz = real_image.shape[0]
    # an inconsistent dtype may have caused the divergence
    t = tf.random.uniform([batchsz, 1, 1, 1], minval=0., maxval=1., dtype=tf.float32)
    x_hat = t * real_image + (1. - t) * fake_image
    with tf.GradientTape() as tape:
        tape.watch(x_hat)
        Dx = discriminator(x_hat, training=True)
    grads = tape.gradient(Dx, x_hat)
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    gp = tf.reduce_mean((slopes - 1.) ** 2)
    return gp
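As a sanity check, the corrected penalty can be exercised with a toy critic on random tensors (the tiny Conv2D/Flatten/Dense stack and the 8×8 input shapes below are hypothetical, chosen only for illustration):

```python
import tensorflow as tf

def gradient_penalty(discriminator, real_image, fake_image):
    batchsz = real_image.shape[0]
    # interpolate between real and fake samples
    t = tf.random.uniform([batchsz, 1, 1, 1], minval=0., maxval=1., dtype=tf.float32)
    x_hat = t * real_image + (1. - t) * fake_image
    with tf.GradientTape() as tape:
        tape.watch(x_hat)
        Dx = discriminator(x_hat, training=True)
    grads = tape.gradient(Dx, x_hat)
    # per-sample gradient norm, penalized toward 1
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    return tf.reduce_mean((slopes - 1.) ** 2)

# Toy critic on 8x8 single-channel images (hypothetical shapes).
critic = tf.keras.Sequential([
    tf.keras.layers.Conv2D(4, 3, padding='same', input_shape=(8, 8, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),
])

real = tf.random.normal([2, 8, 8, 1])
fake = tf.random.normal([2, 8, 8, 1])
gp = gradient_penalty(critic, real, fake)
print(gp)  # a non-negative scalar
```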
donpromax commented 4 years ago

wgan_gp-160000 Before the fix: around epoch 50k the gradients would explode and the generator could only produce noise. After the fix: WGAN's training stability shows; I have trained for 160k epochs so far and the output is still improving steadily.

donpromax commented 4 years ago

Other improvement: with Deconvolution, if you enlarge the output and look closely, a faint checkerboard pattern seems visible, likely overlap caused by Conv_Transpose. Changing the discriminator to an upsampling + Conv2D structure should eliminate it; I am still training with this change, so its actual effect remains to be confirmed.
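The suggested replacement can be sketched as a small Keras block (a hypothetical upsample_conv helper, not from the repo): nearest-neighbor UpSampling2D followed by a stride-1 Conv2D, standing in for a stride-2 Conv2DTranspose, so adjacent output pixels share the same input support and the checkerboard overlap disappears.

```python
import tensorflow as tf

# Hypothetical block: nearest-neighbor upsampling followed by a stride-1
# convolution, as a drop-in replacement for a stride-2 Conv2DTranspose.
def upsample_conv(filters, kernel_size=3):
    return tf.keras.Sequential([
        tf.keras.layers.UpSampling2D(size=2, interpolation='nearest'),
        tf.keras.layers.Conv2D(filters, kernel_size, strides=1, padding='same'),
    ])

block = upsample_conv(16)
x = tf.random.normal([1, 8, 8, 32])
y = block(x)
print(y.shape)  # (1, 16, 16, 16)
```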