google / objax

Apache License 2.0

Tracer error when using a random variable #88

Closed bilal2vec closed 4 years ago

bilal2vec commented 4 years ago

Hi, I'm working on implementing DCGAN in objax and I'm running into this error:

```
UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: Different traces at same level: Traced<ShapedArray(float32[1,128,1,1])>with<JVPTrace(level=2/1)>
  with primal = Traced<ShapedArray(float32[1,128,1,1])>with<DynamicJaxprTrace(level=0/1)>
       tangent = Traced<ShapedArray(float32[1,128,1,1]):JaxprTrace(level=1/1)>, JVPTrace(level=2/1).
```

It looks like this is triggered by the BatchNorm layers in the discriminator updating their running mean and stddev averages during the generator update step.
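
Here is a minimal, hypothetical sketch of that failure mode, with toy shapes and stand-in modules rather than the notebook's real models, assuming objax's `BatchNorm0D` and `GradValues` API:

```python
import objax

class Disc(objax.Module):
    """Hypothetical toy discriminator with a BatchNorm layer."""
    def __init__(self):
        self.lin1 = objax.nn.Linear(2, 8)
        self.bn = objax.nn.BatchNorm0D(8)
        self.lin2 = objax.nn.Linear(8, 2)

    def __call__(self, x, training):
        x = objax.functional.relu(self.bn(self.lin1(x), training=training))
        return self.lin2(x)

disc = Disc()
gen = objax.nn.Linear(4, 2)  # toy stand-in for the generator

def g_loss(z):
    fake = gen(z)
    # With training=True, BatchNorm would write its running averages into
    # StateVars while g_loss is being traced for the generator update; those
    # StateVars are not part of gen.vars(), so traced values would be stashed
    # in global state (the UnexpectedTracerError above). training=False only
    # reads the stored averages, so nothing leaks.
    logits = disc(fake, training=False)
    return objax.functional.loss.cross_entropy_logits_sparse(logits, 1).mean()

g_grad = objax.GradValues(g_loss, gen.vars())
grads, values = g_grad(objax.random.normal((16, 4)))  # values[0] is the loss
```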

Here's a Colab notebook to reproduce this: https://colab.research.google.com/drive/1gG0naJz_JbFHQwNxL9jTKeifwzVne6KE?usp=sharing

Also, I'm not sure if this is the right place to look for help on this bug. Would posting it on the JAX discussions page be more appropriate?

Thanks,

Bilal

mattjj commented 4 years ago

It's probably a side effect (in your code or in objax) stashing a JAX tracer. I requested access to the notebook!
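
A generic, hypothetical illustration of the kind of side effect meant here (not the notebook's code): a traced function writing a tracer into global state.

```python
import jax
import jax.numpy as jnp

stash = {}  # global state, analogous to BatchNorm's running averages

@jax.jit
def loss(x):
    stash['mean'] = x.mean()  # side effect: stores a tracer while tracing
    return (x ** 2).sum()

loss(jnp.ones(3))
# stash['mean'] is now an escaped tracer; using it inside later traced
# code raises UnexpectedTracerError.
```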

bilal2vec commented 4 years ago

Sorry about that, I should've checked that the link worked: https://colab.research.google.com/drive/1gG0naJz_JbFHQwNxL9jTKeifwzVne6KE?usp=sharing

mattjj commented 4 years ago

I got a different error when running the cells:

[screenshot of the error]

Could you double check that it's a repro of the bug you care about? Btw, minimizing it would increase our chance of making progress :)

david-berthelot commented 4 years ago

This is the right place to ask. I got the same error as @mattjj when trying to reproduce it. In your code you use both objax.functional.loss.cross_entropy_logits and objax.functional.loss.cross_entropy_logits_sparse, and in both cases the way the API is called is the cause.
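(For reference: cross_entropy_logits takes dense label tensors, while cross_entropy_logits_sparse takes integer class labels, hence the plain int in the example below.)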

In short, the nicest way to write g_loss would be:

```python
return objax.functional.loss.cross_entropy_logits_sparse(discriminator(fake_img, training=False), 1).mean()
```

And for d_loss, you could change the code accordingly (just use an int for the label); see the sketch below.
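
A hypothetical d_loss in the same style, assuming the notebook's discriminator and generator modules (label 1 for real images, 0 for generated ones):

```python
import objax

# Hypothetical sketch; discriminator and generator are the notebook's modules.
# training=True is okay here, assuming the discriminator's variables
# (including its BatchNorm state) are passed to GradValues / the optimizer.
def d_loss(x, z):
    real_logits = discriminator(x, training=True)             # label 1 = real
    fake_logits = discriminator(generator(z), training=True)  # label 0 = fake
    loss_real = objax.functional.loss.cross_entropy_logits_sparse(real_logits, 1).mean()
    loss_fake = objax.functional.loss.cross_entropy_logits_sparse(fake_logits, 0).mean()
    return loss_real + loss_fake
```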

Generally speaking, you could still use a tensor if you really wanted to; just make sure it has the proper dimensionality, like this:

```python
return objax.functional.loss.cross_entropy_logits(discriminator(fake_img, training=False), jnp.ones(x.shape[0])).mean()
```

More generally, one way I found to debug is to add print statements in various places to make sure my variables are okay; the stack trace shows you the places to look at closely:

  File "dev/dcgan.py", line 137, in g_loss
    return objax.functional.loss.cross_entropy_logits(discriminator(fake_img, training=False), jnp.ones_like(x)).mean()
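
A tiny, hypothetical illustration of this print-debugging (a stand-in function, not the notebook's code): inside a transformed function, print shows Traced<...> objects, so you can check shapes and spot values that are still tracers when they shouldn't be.

```python
import jax
import jax.numpy as jnp

def g_loss_debug(z):
    fake = z * 2.0  # stand-in for generator(z)
    print('fake:', fake)  # prints a Traced<...> object while tracing
    return jnp.sum(fake ** 2)

jax.grad(g_loss_debug)(jnp.ones((1, 128, 1, 1)))
```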

Hope it helps!

bilal2vec commented 4 years ago

Thank you! The code runs now!