chrisdonahue / wavegan

WaveGAN: Learn to synthesize raw audio with generative adversarial networks

Keras Loss implementation #95

Closed Arqit closed 3 years ago

Arqit commented 3 years ago

Hi,

I'm currently trying to implement WaveGAN in Keras using the WGAN-GP loss. I'm not too familiar with TensorFlow 1, so I suspect my incorrect results are rooted in a poor translation.

I've verified that the two models are identical, so I believe the remaining error is in how the loss is calculated.
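
For reference, the WGAN-GP objective I'm trying to reproduce (as I understand it from the TF1 code, with $\lambda = 10$) is:

\mathcal{L}_D = \mathbb{E}_{z}\big[D(G(z))\big] - \mathbb{E}_{x}\big[D(x)\big] + \lambda\,\mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big], \quad \hat{x} = x + \alpha\,(G(z) - x),\ \alpha \sim U[0, 1]

\mathcal{L}_G = -\mathbb{E}_{z}\big[D(G(z))\big]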

Because the loss is custom, I've opted to implement the loss calculation and model training as follows:

import tensorflow as tf

# Keras ports of the TF1 WaveGAN models (defined elsewhere)
G = WaveGANGenerator()
D = WaveGANDiscriminator()

# Optimizers with the WGAN-GP hyperparameters used by the TF1 code
G_opt = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)
D_opt = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)

batch_size = 64
writer = tf.summary.create_file_writer("/content/drive/My Drive/logs")

@tf.function
def train_step(step):

  for _ in range(5):  # 5 discriminator updates per generator update
    # NOTE: loader.get_batch builds a tf.data pipeline; constructing the
    # dataset once outside train_step and passing batches in may be safer
    # under tf.function.
    x = loader.get_batch(fps, batch_size, 16384)
    # WaveGAN samples z ~ U[-1, 1]; random_uniform defaults to [0, 1)
    z = tf.random.uniform((batch_size, 100), minval=-1., maxval=1.)

    # Train discriminator (critic)
    with tf.GradientTape() as disc_tape:
      G_z = G(z)
      d_g_z = D(G_z)
      d_x = D(x)
      # Wasserstein critic loss: E[D(G(z))] - E[D(x)]
      d_loss = tf.reduce_mean(d_g_z) - tf.reduce_mean(d_x)

      # Random points on the lines between real and generated examples,
      # for the gradient penalty
      alpha = tf.random.uniform(shape=[batch_size, 1, 1], minval=0., maxval=1.)
      differences = G_z - x
      interpolates = x + (alpha * differences)

      with tf.GradientTape() as interp_tape:
        interp_tape.watch(interpolates)
        d_interped = D(interpolates)

      # Gradient penalty: penalize deviation of the critic's gradient
      # norm at the interpolates from 1
      LAMBDA = 10
      interp_grads = interp_tape.gradient(d_interped, interpolates)
      slopes = tf.sqrt(tf.reduce_sum(tf.square(interp_grads), axis=[1, 2]))
      gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
      d_loss += LAMBDA * gradient_penalty

    discriminator_gradients = disc_tape.gradient(d_loss, D.trainable_weights)
    D_opt.apply_gradients(zip(discriminator_gradients, D.trainable_weights))

  # Train generator
  z = tf.random.uniform((batch_size, 100), minval=-1., maxval=1.)
  with tf.GradientTape() as gen_tape:
    predictions = D(G(z))
    g_loss = -tf.reduce_mean(predictions)  # maximize E[D(G(z))]

  generator_gradients = gen_tape.gradient(g_loss,G.trainable_weights)
  G_opt.apply_gradients(zip(generator_gradients,G.trainable_weights))

  # save summaries every 10 steps only
  if step % 10 == 0:
    with writer.as_default():
      G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
      x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
      tf.summary.audio('x', x, sample_rate=16000, step=step)
      tf.summary.audio('G_z', G_z, sample_rate=16000, step=step)
      tf.summary.scalar('G_loss', g_loss, step=step)
      tf.summary.scalar('D_loss', d_loss, step=step)
      tf.summary.histogram('x_rms_batch', x_rms, step=step)
      tf.summary.histogram('G_z_rms_batch', G_z_rms, step=step)
      tf.summary.scalar('x_rms', tf.reduce_mean(x_rms), step=step)
      tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms), step=step)

  return d_loss, g_loss
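
For completeness, I drive train_step from a plain Python loop (a sketch; the step count is arbitrary, and I pass the step in as an int64 tensor so tf.summary gets a valid step and the tf.function isn't retraced for every new Python int):

for step in range(100000):  # arbitrary number of steps
  d_loss, g_loss = train_step(tf.constant(step, dtype=tf.int64))
  if step % 100 == 0:
    print(f"step {step}: d_loss={float(d_loss):.4f} g_loss={float(g_loss):.4f}")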

[Screenshot: TensorBoard training curves (keras_wavegan_loss)]

Do the graphs give an intuitive idea of what could be going wrong, or is my implementation wrong altogether?
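
For reference, here is a standalone version of the gradient-penalty term so it can be sanity-checked in isolation (a sketch; the gradient_penalty function and the toy linear critic lin_D are just for illustration, and a unit-norm linear critic should give a penalty of ~0):

import tensorflow as tf

def gradient_penalty(D, x, G_z, batch_size):
  # WGAN-GP term: mean over the batch of (||grad_xhat D(xhat)||_2 - 1)^2,
  # with xhat sampled uniformly on lines between real and fake examples
  alpha = tf.random.uniform([batch_size, 1, 1], 0., 1.)
  interpolates = x + alpha * (G_z - x)
  with tf.GradientTape() as tape:
    tape.watch(interpolates)
    d_interp = D(interpolates)
  grads = tape.gradient(d_interp, interpolates)
  slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2]))
  return tf.reduce_mean((slopes - 1.) ** 2)

# Sanity check: a linear critic with a unit-norm weight vector has
# gradient norm exactly 1 everywhere, so the penalty should be ~0
w = tf.random.normal([16384, 1])
w = w / tf.norm(w)
lin_D = lambda a: tf.reduce_sum(a * w[None, :, :], axis=[1, 2])
x = tf.random.normal([4, 16384, 1])
G_z = tf.random.normal([4, 16384, 1])
print(gradient_penalty(lin_D, x, G_z, 4))  # expect ~0.0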