denizyuret / AutoGrad.jl

Julia port of the Python autograd package.

Multiple functions sharing some gradients #119

Closed ludvigk closed 3 years ago

ludvigk commented 4 years ago

Hi, I was wondering if there is a way to track gradients for multiple functions that partially share some computation. As an example, in GANs, the generator and discriminator losses both depend on the output of the discriminator, so there is no need to compute that part twice per update. In TensorFlow, this can be done with multiple gradient tapes.

This example is taken from the TensorFlow tutorials page: https://www.tensorflow.org/tutorials/generative/dcgan.

```python
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
```
ludvigk commented 4 years ago

I just realized that in a simple case like this one can just take the gradient of the discriminator output separately and use the chain rule to avoid the extra computation. It won't be pretty when the functions are more entangled, but I suppose it's the best option.
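To illustrate the scalar case, here is a minimal NumPy sketch (a toy model with hypothetical loss functions, not the actual GAN losses): the gradient of the shared scalar is computed once, then scaled by each outer loss's derivative via the chain rule.

```python
import numpy as np

# Toy shared computation: s = f(x) is a scalar.
def shared(x):
    return float(np.sum(x ** 2))          # s = sum(x_i^2)

def shared_grad(x):
    return 2.0 * x                        # ds/dx, computed only once

# Two hypothetical losses that depend only on the shared scalar s.
def dgen_ds(s):
    return -1.0 / s                       # derivative of gen_loss(s) = -log(s)

def ddisc_ds(s):
    return 2.0 * (s - 1.0)                # derivative of disc_loss(s) = (s - 1)^2

x = np.array([1.0, 2.0])
s = shared(x)                             # shared forward pass, done once
ds_dx = shared_grad(x)                    # shared backward pass, done once

# Chain rule: dL/dx = (dL/ds) * (ds/dx) for each loss.
grad_gen = dgen_ds(s) * ds_dx
grad_disc = ddisc_ds(s) * ds_dx
```

Because `s` is a scalar, `dL/ds` is a single number, so reusing `ds_dx` for both losses is just a scalar multiplication.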


I just want to add that the chain-rule method doesn't work when the shared computation does not produce a scalar. I am wondering if there is a more general way of doing these calculations.
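To make the limitation concrete: if the shared output y is a vector, each loss gradient is a vector-Jacobian product ∇ₓL = Jᵀ ∇_y L, so reusing the shared work this way means materializing the full Jacobian J of the shared computation rather than a single gradient vector. A toy NumPy illustration (all names hypothetical):

```python
import numpy as np

# Shared computation y = A @ x has a vector output (y in R^3, x in R^2).
A = np.array([[1.0, 0.0],
              [0.0, 2.0],
              [1.0, 1.0]])

def shared(x):
    return A @ x

# The Jacobian dy/dx is the full 3x2 matrix A.  For a scalar output this
# would be a single gradient vector; here the scalar shortcut no longer applies.
J = A

x = np.array([1.0, 1.0])
y = shared(x)                  # y = [1, 2, 2]

# Two losses on the shared vector output, with their gradients wrt y:
dL1_dy = 2.0 * y               # L1 = ||y||^2
dL2_dy = np.ones_like(y)       # L2 = sum(y)

# Each gradient wrt x is a vector-Jacobian product J^T @ (dL/dy).
grad_L1 = J.T @ dL1_dy
grad_L2 = J.T @ dL2_dy
```

A tape-based API avoids forming `J` explicitly by propagating each `dL/dy` backward through the recorded operations, which is what makes the TensorFlow multi-tape pattern efficient here.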