google / objax

Apache License 2.0

Computing gradients for a generator and discriminator #90

Closed (bilal2vec closed this issue 4 years ago)

bilal2vec commented 4 years ago

Hi,

I'm trying to get a working DCGAN implementation in objax. Since the discriminator and generator need to be optimized separately, I'm taking the gradients with respect to each module in two different functions, like this:

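    import jax.numpy as jnp
    import objax

    # `generator` and `discriminator` are the objax modules defined in the notebook.

    def d_loss(x, z):
        # Real images should be classified as real (label 1).
        d_loss_real = objax.functional.loss.sigmoid_cross_entropy_logits(
            discriminator(x, training=True), jnp.ones([x.shape[0], 1])).mean()

        # training=True is what I'd like here, but it raises
        # IndexError: tuple index out of range, so training=False is used for now.
        # fake_img = generator(z, training=True)
        fake_img = generator(z, training=False)
        # Generated images should be classified as fake (label 0).
        d_loss_fake = objax.functional.loss.sigmoid_cross_entropy_logits(
            discriminator(fake_img, training=True), jnp.zeros([x.shape[0], 1])).mean()

        return d_loss_real + d_loss_fake

    def g_loss(x, z):
        # The generator wants its fakes to be classified as real (label 1).
        fake_img = generator(z, training=True)
        # Same IndexError with training=True:
        # logits = discriminator(fake_img, training=True)
        logits = discriminator(fake_img, training=False)
        return objax.functional.loss.sigmoid_cross_entropy_logits(
            logits, jnp.ones([x.shape[0], 1])).mean()

    # Each gradient function differentiates only with respect to its own module's variables.
    d_gv = objax.GradValues(d_loss, discriminator.vars())
    g_gv = objax.GradValues(g_loss, generator.vars())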

Would this be the preferred way of doing this, or is there a way to return two values (the generator loss and the discriminator loss) from one function and then compute the gradient of each loss separately (e.g. d_loss with respect to discriminator.vars())?
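For reference, the single-function variant could look roughly like this, sketched in plain JAX rather than objax (so it says nothing about how objax's GradValues handles it). The toy linear networks and the names `bce_logits`, `d_loss_fn`, `g_loss_fn` and `both_losses` are hypothetical. Both losses are computed in one function, and `jax.lax.stop_gradient` on the other network's parameters makes each loss contribute gradients only to its own network:

```python
import jax
import jax.numpy as jnp


def bce_logits(logits, labels):
    # Numerically stable elementwise sigmoid cross-entropy on logits.
    return jnp.maximum(logits, 0) - logits * labels + jnp.log1p(jnp.exp(-jnp.abs(logits)))


def generator(g_params, z):
    # Hypothetical toy generator: one linear layer from latent (3) to data (2).
    return z @ g_params["w"] + g_params["b"]


def discriminator(d_params, x):
    # Hypothetical toy discriminator: one linear layer producing a single logit.
    return x @ d_params["w"] + d_params["b"]


def d_loss_fn(d_params, g_params, x, z):
    fake = generator(g_params, z)
    real_loss = bce_logits(discriminator(d_params, x), jnp.ones((x.shape[0], 1))).mean()
    fake_loss = bce_logits(discriminator(d_params, fake), jnp.zeros((z.shape[0], 1))).mean()
    return real_loss + fake_loss


def g_loss_fn(d_params, g_params, x, z):
    fake = generator(g_params, z)
    return bce_logits(discriminator(d_params, fake), jnp.ones((z.shape[0], 1))).mean()


def both_losses(params, x, z):
    d_params, g_params = params
    sg = jax.lax.stop_gradient
    # Each loss sees the *other* network only through stop_gradient, so the
    # d_loss term yields discriminator grads only and g_loss generator grads only.
    d_loss = d_loss_fn(d_params, sg(g_params), x, z)
    g_loss = g_loss_fn(sg(d_params), g_params, x, z)
    # Differentiate the sum; return the individual losses as auxiliary outputs.
    return d_loss + g_loss, (d_loss, g_loss)


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
d_params = {"w": jax.random.normal(k1, (2, 1)), "b": jnp.zeros((1,))}
g_params = {"w": jax.random.normal(k2, (3, 2)), "b": jnp.zeros((2,))}
x = jax.random.normal(k3, (8, 2))
z = jax.random.normal(k4, (8, 3))

(d_grads, g_grads), (d_loss, g_loss) = jax.grad(both_losses, has_aux=True)(
    (d_params, g_params), x, z)
```

Note that this yields gradients for both networks from the same forward pass, i.e. a simultaneous rather than an alternating update.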

The code runs, and both parts of the GAN seem to train (their losses go down), but the discriminator's loss quickly converges to 0. My guess is that this happens because I had to comment out the lines above and use training=False instead, since training=True raises IndexError: tuple index out of range. With training=False, batch norm doesn't use the current batch's statistics, which might be why the discriminator converges so quickly.
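To make the training/eval distinction concrete, here is a minimal batch norm sketched in plain JAX. It is not objax's BatchNorm (whose defaults and learned scale/offset differ); `batch_norm`, its state dict and the momentum value are hypothetical. In training mode it normalizes with the current batch's statistics and produces updated running averages that have to be stored somewhere; in eval mode it uses the stored averages and leaves the state unchanged:

```python
import jax.numpy as jnp


def batch_norm(x, state, training, momentum=0.9, eps=1e-5):
    """Hypothetical minimal batch norm over the batch axis (no scale/offset).

    Returns (output, new_state).
    """
    if training:
        # Use the current batch's statistics and update the running averages.
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        new_state = {
            "mean": momentum * state["mean"] + (1 - momentum) * mean,
            "var": momentum * state["var"] + (1 - momentum) * var,
        }
    else:
        # Use the stored running averages; the state is not modified.
        mean, var = state["mean"], state["var"]
        new_state = state
    return (x - mean) / jnp.sqrt(var + eps), new_state


state = {"mean": jnp.zeros(3), "var": jnp.ones(3)}
x = jnp.array([[1.0, 2.0, 3.0], [3.0, 6.0, 9.0]])

y_train, state_after_train = batch_norm(x, state, training=True)
y_eval, state_after_eval = batch_norm(x, state, training=False)
```

Here the training-mode output is roughly [[-1, -1, -1], [1, 1, 1]] and the running mean moves to [0.2, 0.4, 0.6], while the eval-mode output is x itself, because the initial running statistics are mean 0 and variance 1.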

colab notebook: https://colab.research.google.com/drive/1WTBKHqZWAg-TpXJmZWOn_7VCVOZsDP2F?usp=sharing

david-berthelot commented 4 years ago

Your approach is how I would do it myself: two losses and two corresponding gradient functions, one for the generator and one for the discriminator. It looks very clean and readable to me.
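The two-loss, two-gradient-function structure can be sketched in plain JAX (not objax) as one jitted alternating step. The toy linear networks, the plain SGD update and the learning rate `LR` are hypothetical stand-ins for the notebook's DCGAN modules and optimizers. Each gradient function differentiates only with respect to one network, selected with `argnums`:

```python
import jax
import jax.numpy as jnp

LR = 0.05  # hypothetical learning rate for plain SGD


def bce_logits(logits, labels):
    # Numerically stable elementwise sigmoid cross-entropy on logits.
    return jnp.maximum(logits, 0) - logits * labels + jnp.log1p(jnp.exp(-jnp.abs(logits)))


def generator(g, z):      # hypothetical toy linear generator
    return z @ g["w"] + g["b"]


def discriminator(d, x):  # hypothetical toy linear discriminator (one logit)
    return x @ d["w"] + d["b"]


def d_loss(d, g, x, z):
    fake = generator(g, z)
    return (bce_logits(discriminator(d, x), jnp.ones((x.shape[0], 1))).mean()
            + bce_logits(discriminator(d, fake), jnp.zeros((z.shape[0], 1))).mean())


def g_loss(d, g, z):
    return bce_logits(discriminator(d, generator(g, z)), jnp.ones((z.shape[0], 1))).mean()


# Two gradient functions: d_loss w.r.t. the discriminator, g_loss w.r.t. the generator.
d_grad_fn = jax.value_and_grad(d_loss, argnums=0)
g_grad_fn = jax.value_and_grad(g_loss, argnums=1)


@jax.jit
def train_step(d, g, x, z):
    # 1) Discriminator update; the generator is an input but is not updated here.
    dl, dg = d_grad_fn(d, g, x, z)
    d = jax.tree_util.tree_map(lambda p, gr: p - LR * gr, d, dg)
    # 2) Generator update against the freshly updated discriminator.
    gl, gg = g_grad_fn(d, g, z)
    g = jax.tree_util.tree_map(lambda p, gr: p - LR * gr, g, gg)
    return d, g, dl, gl


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
d = {"w": jax.random.normal(k1, (2, 1)), "b": jnp.zeros((1,))}
g = {"w": jax.random.normal(k2, (3, 2)), "b": jnp.zeros((2,))}
x = jax.random.normal(k3, (16, 2)) + 3.0  # hypothetical "real" data
z = jax.random.normal(k4, (16, 3))

d_new, g_new, dl, gl = train_step(d, g, x, z)
```

Because every network's parameters are explicit inputs and outputs of `train_step`, nothing the compiled function depends on is silently frozen as a constant.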

That said, with GANs it's better to start by replicating an existing working implementation, because it's very difficult to get one written from scratch to converge.

bilal2vec commented 4 years ago

Thanks,

I'm not sure how objax handles this, but could the error from setting training=True for the generator in d_loss be caused by GradValues trying to calculate the gradients of the generator as well? Or could updating the generator's running statistics (mean and variance) mess with the jitted generator() call in g_loss()?

david-berthelot commented 4 years ago

It could be due to this line:

    # Current version: only the gradient and optimizer variables are passed to Jit.
    d_train_op = objax.Jit(d_train_op, d_gv.vars() + d_opt.vars())
    # It should probably be the following, since d_train_op also runs the generator:
    d_train_op = objax.Jit(d_train_op, d_gv.vars() + d_opt.vars() + generator.vars())

Generally speaking, you can pass all the vars to Jit: unused vars don't matter, whereas any vars you omit are treated as constants. In this particular case, with training=True the generator's batch norm statistics are updated, but since generator.vars() wasn't passed, they were treated as constants. In a nutshell, the code is writing to a constant, which leads to the error.
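The "omitted vars are treated as constants" behavior has a direct analogue in plain JAX, sketched below. This is not objax's Jit; `running_mean`, `MOMENTUM`, `normalize_closure` and `normalize_and_update` are hypothetical. A value read from outside a jitted function is baked in at trace time, so changing state such as batch-norm running statistics has to be passed in and returned explicitly. Assigning a traced value to outside state from inside the function instead leaks a tracer, which JAX typically reports as an error (UnexpectedTracerError) when the leaked value is used later.

```python
import jax
import jax.numpy as jnp

MOMENTUM = 0.9  # hypothetical running-average momentum

# Hypothetical running statistic stored outside the compiled function.
running_mean = jnp.zeros(3)


@jax.jit
def normalize_closure(x):
    # running_mean is not an argument, so tracing bakes its current value
    # into the compiled function as a constant.
    return x - running_mean


@jax.jit
def normalize_and_update(x, mean):
    # The statistic is an explicit input and output: the caller receives the
    # updated value instead of the function writing to a captured constant.
    new_mean = MOMENTUM * mean + (1 - MOMENTUM) * x
    return x - mean, new_mean


x = jnp.ones(3)
first = normalize_closure(x)     # traced while running_mean is all zeros
running_mean = jnp.full(3, 0.5)  # rebinding the global after tracing...
stale = normalize_closure(x)     # ...is not seen: the cached trace still subtracts 0

state = jnp.zeros(3)
out, state = normalize_and_update(x, state)  # out is x - 0, state becomes 0.1
```

The explicit-state version is the functional counterpart of passing generator.vars() to Jit: the compiled function can both read and update the statistics.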

bilal2vec commented 4 years ago

Thanks for all the help. JAX still seems to throw a tracer error because of the batch norm layers in both networks, so I'll switch back to Flax or Haiku for a while until I can take another look at this :)