bilal2vec closed this issue 4 years ago.
It's probably a side effect (in your code or in objax) that stashes a JAX tracer. I've requested access to the notebook!
Sorry about that, I should have checked whether the link worked: https://colab.research.google.com/drive/1gG0naJz_JbFHQwNxL9jTKeifwzVne6KE?usp=sharing
I got a different error when running the cells:
Could you double-check that it reproduces the bug you care about? Btw, minimizing the example would increase our chances of making progress :)
This is the right place to ask. I got the same error as @mattjj when trying to reproduce it.
In your code you use both objax.functional.loss.cross_entropy_logits and objax.functional.loss.cross_entropy_logits_sparse. In both cases, the error comes from how the API is called.
In short, the nicest way to write g_loss would be:
return objax.functional.loss.cross_entropy_logits_sparse(discriminator(fake_img, training=False), 1).mean()
And for d_loss, you could change the code accordingly (just use an int for the label).
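A minimal NumPy sketch of what the int-label version computes, assuming the discriminator returns two logits per image (index 1 = "real", index 0 = "fake") and assuming the sparse cross entropy reduces to logsumexp(logits) minus the logit at the label index. NumPy stands in for jax.numpy here; sparse_ce, g_loss_from_logits and d_loss_from_logits are hypothetical helpers, not objax API.

```python
import numpy as np

def logsumexp(x, axis=-1):
    # Numerically stable log(sum(exp(x))) along `axis`.
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))

def sparse_ce(logits, label):
    # Hypothetical stand-in for cross_entropy_logits_sparse with an int label:
    # one loss value per example, shape logits.shape[:-1].
    return logsumexp(logits, axis=-1) - logits[..., label]

def g_loss_from_logits(fake_logits):
    # The generator wants the discriminator to call fakes "real" (label 1).
    return sparse_ce(fake_logits, 1).mean()

def d_loss_from_logits(real_logits, fake_logits):
    # The discriminator wants real -> 1 and fake -> 0; the int label applies
    # to the whole batch, so no label tensor has to be built.
    return sparse_ce(real_logits, 1).mean() + sparse_ce(fake_logits, 0).mean()

batch = 4
fake_logits = np.zeros((batch, 2))  # discriminator undecided on every image
real_logits = np.zeros((batch, 2))
print(g_loss_from_logits(fake_logits))                # log(2), about 0.693
print(d_loss_from_logits(real_logits, fake_logits))  # 2 * log(2), about 1.386
```

Using an int label sidesteps the shape question entirely: there is no label tensor whose dimensions could disagree with the logits.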
That said, you can still pass a tensor if you really want to; just make sure it has the right shape, like this:
return objax.functional.loss.cross_entropy_logits(discriminator(fake_img, training=False), jnp.ones(x.shape[0])).mean()
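The shape rule for the tensor version can be checked in isolation. This NumPy sketch assumes the dense cross entropy computes logsumexp(logits) minus sum(logits * labels) over the last axis, so the labels must broadcast against logits of shape (batch, 2). The image batch x (shape assumed to be (batch, 3, 32, 32)), dense_ce and sparse_ce are hypothetical. It shows that labels built with ones_like(x) cannot broadcast (NumPy raises ValueError; JAX reports its own error), while one-hot labels reproduce the int-label result. If your discriminator returns a different number of logits, the one-hot width changes accordingly.

```python
import numpy as np

def logsumexp(x, axis=-1):
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))

def dense_ce(logits, labels):
    # Hypothetical stand-in for cross_entropy_logits: labels must broadcast
    # to logits.shape (typically one-hot or probability vectors).
    return logsumexp(logits, axis=-1) - np.sum(logits * labels, axis=-1)

def sparse_ce(logits, label):
    # Int-label version, used here only as a reference value.
    return logsumexp(logits, axis=-1) - logits[..., label]

rng = np.random.default_rng(0)
batch = 4
x = rng.normal(size=(batch, 3, 32, 32))  # stand-in for the image batch
logits = rng.normal(size=(batch, 2))     # stand-in discriminator output

# Labels shaped like the images: (4, 3, 32, 32) vs logits (4, 2) cannot broadcast.
try:
    dense_ce(logits, np.ones_like(x))
except ValueError as err:
    print("shape mismatch:", err)

# One-hot labels for class 1, shape (batch, 2), agree with the int-label loss.
one_hot = np.eye(2)[np.ones(batch, dtype=int)]
print(np.allclose(dense_ce(logits, one_hot), sparse_ce(logits, 1)))
```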
More generally, one way I debug this kind of problem is to add print statements in various places to check that my variables have the shapes I expect. The stack trace tells you where to look closely:
File "dev/dcgan.py", line 137, in g_loss
return objax.functional.loss.cross_entropy_logits(discriminator(fake_img, training=False), jnp.ones_like(x)).mean()
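In the same spirit as the print-statement approach, a small shape check placed right before the loss call turns a deep broadcasting error into a readable message. check_loss_inputs is a hypothetical helper written for this sketch, with NumPy arrays standing in for JAX arrays; inside a jitted function a plain print still shows the traced shape and dtype at trace time, which is usually what you need here.

```python
import numpy as np

def check_loss_inputs(logits, labels, sparse):
    """Raise a descriptive error if labels cannot pair with logits.

    sparse=True: labels is an int class index, or an integer array of
    shape logits.shape[:-1].
    sparse=False: labels must have exactly logits.shape (e.g. one-hot).
    This is deliberately stricter than broadcasting.
    """
    print("logits:", logits.shape, "labels:", np.shape(labels))
    if sparse:
        if isinstance(labels, int):
            ok = 0 <= labels < logits.shape[-1]
        else:
            ok = np.shape(labels) == logits.shape[:-1]
    else:
        ok = np.shape(labels) == logits.shape
    if not ok:
        raise ValueError(
            f"labels {np.shape(labels)} do not fit logits {logits.shape} "
            f"(sparse={sparse})")

logits = np.zeros((4, 2))
check_loss_inputs(logits, 1, sparse=True)                         # passes
check_loss_inputs(logits, np.eye(2)[[1, 1, 1, 1]], sparse=False)  # passes
try:
    # Labels shaped like an image batch, as in the failing g_loss above.
    check_loss_inputs(logits, np.ones((4, 3, 32, 32)), sparse=False)
except ValueError as err:
    print(err)
```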
Hope it helps!
Thank you! The code runs now!
Hi, I'm working on implementing DCGAN in Objax and I'm running into this error:
It looks like this is being triggered by the running mean and stddev averages of the batchnorm layer in the discriminator during the generator update step.
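Why the discriminator's batch-norm statistics matter during the generator step can be seen with a toy layer. This is a hypothetical NumPy class, not objax.nn.BatchNorm: it tracks a running variance rather than a standard deviation, and the momentum of 0.9 is an assumption. In training mode it normalizes with batch statistics and mutates its running averages; in eval mode it only reads them. Calling the discriminator with training=False inside g_loss, as in the fix discussed in this thread, is what keeps the generator update from writing to that state.

```python
import numpy as np

class ToyBatchNorm:
    """Hypothetical batch norm over axis 0, for illustration only."""

    def __init__(self, num_features, momentum=0.9, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

    def __call__(self, x, training):
        if training:
            mean, var = x.mean(axis=0), x.var(axis=0)
            # Side effect: the layer's state changes on every training call.
            m = self.momentum
            self.running_mean = m * self.running_mean + (1 - m) * mean
            self.running_var = m * self.running_var + (1 - m) * var
        else:
            # Eval mode: read the stored averages, never write them.
            mean, var = self.running_mean, self.running_var
        return (x - mean) / np.sqrt(var + self.eps)

bn = ToyBatchNorm(num_features=3)
fake_features = np.full((8, 3), 2.0)

bn(fake_features, training=False)
print(bn.running_mean)  # [0. 0. 0.] unchanged
bn(fake_features, training=True)
print(bn.running_mean)  # [0.2 0.2 0.2] updated from the fake batch
```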
Here's a Colab notebook that reproduces this: https://colab.research.google.com/drive/1gG0naJz_JbFHQwNxL9jTKeifwzVne6KE?usp=sharing
Also, I'm not sure whether this is a good place to look for help with this bug. Would the JAX discussions page be more appropriate?
Thanks,
Bilal