Closed bilal2vec closed 4 years ago
Your method seems to be how I would do that myself, e.g. two losses and two respective gradient functions for the generator and the discriminator. It looks very clean and readable to me.
Now when it comes to GANs, it's better to start by replicating an existing working implementation, because it's very difficult to get one written from scratch to converge.
Thanks,
I'm not sure how objax handles this, but could the error in setting training=True for the generator in d_loss be that GradValues would try to calculate the gradients of the generator as well? Or could updating the generator's running mean and stddev mess with the jit-ed generator() call in g_loss()?
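For concreteness, here is a minimal plain-JAX sketch (made-up names, not the notebook code) of the kind of side effect I mean: a jitted function that writes a traced value into Python state, the way batchnorm writes its running statistics when training=True, ends up leaking a tracer:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for batchnorm running statistics kept outside jit.
stats = {"running_mean": jnp.zeros(())}

@jax.jit
def forward(x):
    # Side effect at trace time: the traced batch mean is written into
    # Python state, analogous to a running-stats update with training=True.
    stats["running_mean"] = 0.9 * stats["running_mean"] + 0.1 * x.mean()
    return x * 2.0

forward(jnp.ones(4))
# stats["running_mean"] now holds a leaked tracer; touching it afterwards
# is what triggers JAX's tracer errors.
```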
It could be due to this line:
d_train_op = objax.Jit(d_train_op, d_gv.vars() + d_opt.vars())
# it should probably be
d_train_op = objax.Jit(d_train_op, d_gv.vars() + d_opt.vars() + generator.vars())
# since you use the generator to compute in d_train_op
Generally speaking, for Jit you can pass all the vars, since unused vars don't matter, while any vars you omit are treated as constants. In this particular case, when training is True, the batch norm statistics are updated, and since you didn't pass generator.vars(), they were treated as constants. So in a nutshell, the code is writing to a constant, which leads to an error.
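The "treated as constants" behavior can be seen in plain JAX too (a minimal sketch with made-up names, not the notebook code): a value the jitted function captures by closure, instead of receiving as a tracked input, gets baked into the compiled function at trace time:

```python
import jax
import jax.numpy as jnp

# Mutable state held outside the jitted function, analogous to objax
# variables that were NOT passed to objax.Jit.
state = {"scale": jnp.asarray(2.0)}

@jax.jit
def f(x):
    # state["scale"] is read once at trace time and baked in as a constant.
    return x * state["scale"]

print(f(jnp.asarray(3.0)))   # 6.0
state["scale"] = jnp.asarray(10.0)
print(f(jnp.asarray(3.0)))   # still 6.0: the cached trace kept the old constant
```

objax.Jit avoids this by threading the variables you list through the jitted function as real inputs, which is why every module used inside the function has to be in the var collection.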
Thanks for all the help. There still seems to be a tracer error that JAX is throwing because of the batchnorm layers in both networks. I guess I'll switch back to using flax or haiku for a while until I can take another look at this :)
Hi,
I'm trying to get a working DCGAN implementation in objax. Since the discriminator and generator need to be optimized separately, I'm taking the gradients with respect to both modules in two different functions like this:
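(The actual code is in the notebook linked below; as a rough plain-JAX sketch of the two-loss structure, with toy linear nets and hypothetical names standing in for the real DCGAN modules:)

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
g_params = {"w": jnp.full((4, 2), 0.1)}   # toy "generator" weights
d_params = {"w": jnp.full((2, 1), 0.1)}   # toy "discriminator" weights

def generator(g_params, z):
    return z @ g_params["w"]

def discriminator(d_params, x):
    return x @ d_params["w"]  # raw logits

def d_loss(d_params, g_params, z, x_real):
    # Discriminator: push real logits up, fake logits down (BCE via softplus).
    fake = generator(g_params, z)
    return (jax.nn.softplus(-discriminator(d_params, x_real)).mean()
            + jax.nn.softplus(discriminator(d_params, fake)).mean())

def g_loss(g_params, d_params, z):
    # Generator: make the discriminator output 1 on fakes.
    fake = generator(g_params, z)
    return jax.nn.softplus(-discriminator(d_params, fake)).mean()

z = jax.random.normal(key, (8, 4))
x_real = jnp.ones((8, 2))

# jax.grad differentiates w.r.t. the first argument only, so each step
# updates one network even though both networks appear in both losses.
d_grads = jax.grad(d_loss)(d_params, g_params, z, x_real)
g_grads = jax.grad(g_loss)(g_params, d_params, z)
```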
Would this be the preferred way of doing this, or is there a way of returning two values (the loss for both the generator and discriminator) in one function and then computing the gradients of the generator and discriminator losses separately (e.g. d_loss wrt discriminator.vars())?

The code runs, and both parts of the GAN seem to train and their losses go down, but the discriminator's loss quickly converges to 0. I'm guessing this is caused by having to comment out the lines above that set training=False (which would prevent batchnorm from using the current batch's stats and might be the cause of the discriminator converging so quickly), since those lines cause an IndexError: tuple index out of range error.

Colab notebook: https://colab.research.google.com/drive/1WTBKHqZWAg-TpXJmZWOn_7VCVOZsDP2F?usp=sharing