google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

UnexpectedTracerError in optics code #4714

Open benjaminpope opened 3 years ago

benjaminpope commented 3 years ago

Dear Jax team,

I am building an optics package, morphine, that uses JAX to do autodiff for telescope simulations.

I'm encountering a bug

UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function. The functions being transformed should not save traced values to global state. Details: Can't lift sublevels 2 to 1.

when I use morphine. Here is a minimal example that reproduces it: https://github.com/benjaminpope/morphine/blob/stable/notebooks/morphine_jit_example.ipynb

Basically, this seems to happen when I try to execute something involving self.wavefront, an attribute of the BasicWavefront object that should just be a NumPy array containing the phase and amplitude of the light going through the telescope. jit seems to work just fine when producing a BasicWavefront object that has this attribute, but fails when you try to return the attribute itself from inside a function! What's happening?

All the best,

Ben

mattjj commented 3 years ago

This indicates that there's a side-effect happening, so that a JAX-transformed function isn't pure and as a result a Tracer object is being stashed in global state somewhere.

At the very least we should make this error message better so that it provides as much info as possible about where this side-effect is happening. I think we have the tools to do that now.

Hopefully we can also advise on how to fix morphine if there's an issue in there!

mattjj commented 3 years ago

To put a finer point on it: this indicates a bug in the usage of JAX, rather than in JAX itself, in that JAX requires all transformed functions to be pure (and in particular not to have side-effects on global state). While the issue presents itself as being about the tracing mechanism, really this indicates that the function can't be correctly transformed.
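A minimal sketch of the kind of side effect described here. It is not taken from morphine; the names `leaked`, `impure_double`, and `pure_double` are hypothetical. The impure function stashes its argument in a global list while it is being traced, so a Tracer escapes and raises the error as soon as it is used later:

```python
import jax
import jax.numpy as jnp

leaked = []  # hypothetical global state, for illustration only


@jax.jit
def impure_double(x):
    # Side effect: while jit traces this function, x is a Tracer,
    # and that Tracer is saved outside the function.
    leaked.append(x)
    return 2 * x


@jax.jit
def pure_double(x):
    # Pure: data only flows in through arguments and out through the return value.
    return 2 * x


# The call itself succeeds, but a Tracer is now sitting in `leaked`.
impure_double(jnp.arange(3.0))

try:
    leaked[0] + 1  # using the escaped Tracer after tracing has finished
    raised = False
except jax.errors.UnexpectedTracerError:
    raised = True

result = pure_double(jnp.arange(3.0))
```

The usual fix is to return every value the caller needs from the transformed function rather than writing it to an attribute or global.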

benjaminpope commented 3 years ago

Thanks for the help. The attribute where it chokes is Wavefront.wavefront, which should just be an array. The funny thing is that if you return the Wavefront object itself (psf in the code I sent) from the jitted function, and then look at psf.wavefront outside the jitted function, it has the correct result.

So is the Wavefront object psf doing some global state thing? I'm not sure it should be, but I don't know how the JAX tracers work under the hood well enough to know.

(Incidentally, this is probably a neat thing to resolve: generating psf with jit seems to be 100x faster than without!)

benjaminpope commented 3 years ago

OK, it turns out the whole problem was that I was returning objects. I didn't realize that a jitted function can only return arrays. D'oh!
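As a rough illustration of that distinction, assuming a plain Python class that has not been registered with JAX (`PlainWavefront` is a hypothetical stand-in, not morphine's class): returning the object from a jitted function is rejected with a `TypeError`, while returning the array it holds works.

```python
import jax
import jax.numpy as jnp


class PlainWavefront:
    """Hypothetical stand-in for a wavefront object; JAX knows nothing about it."""

    def __init__(self, wavefront):
        self.wavefront = wavefront


@jax.jit
def make_object(x):
    # Returns an arbitrary Python object, which jit cannot treat as an output.
    return PlainWavefront(x * 2)


@jax.jit
def make_array(x):
    # Builds the same object internally but returns only the array it holds.
    return PlainWavefront(x * 2).wavefront


try:
    make_object(jnp.ones(3))
    object_error = None
except TypeError as err:
    object_error = err  # the output is not a valid JAX type

array_out = make_array(jnp.ones(3))
```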

mattjj commented 3 years ago

Well, in general you can have transformed functions return pytrees of arrays. If your custom objects act like containers of arrays, in that they're effectively isomorphic to tuples of arrays, you can register them as custom pytrees (i.e. custom container types), and all JAX transformations will then work with them. See the pytrees page in the JAX documentation for more info.
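Here is a minimal sketch of what that registration could look like, using `jax.tree_util.register_pytree_node_class`. The class `ToyWavefront`, its fields, and `apply_phase` are hypothetical and only loosely modeled on the object discussed above; the real morphine class may need different children and auxiliary data.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyWavefront:
    """Hypothetical container: one array leaf plus static metadata."""

    def __init__(self, wavefront, wavelength):
        self.wavefront = wavefront    # array, treated as a pytree child
        self.wavelength = wavelength  # hashable metadata, treated as aux data

    def tree_flatten(self):
        children = (self.wavefront,)  # values JAX should trace
        aux_data = self.wavelength    # values JAX should treat as static
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (wavefront,) = children
        return cls(wavefront, aux_data)


@jax.jit
def apply_phase(wf, phase):
    # Both the input and the output are ToyWavefront objects.
    return ToyWavefront(wf.wavefront * jnp.exp(1j * phase), wf.wavelength)


wf = ToyWavefront(jnp.ones((4, 4), dtype=jnp.complex64), 1e-6)
out = apply_phase(wf, jnp.pi)
intensity = jnp.abs(out.wavefront) ** 2
```

Anything placed in the aux data is treated as static under `jit`, so it should be hashable and should not be an array you want to differentiate through.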

I want to improve this error message, so let's leave the issue open until we do, or else I'll forget!

mattjj commented 3 years ago

(The text on that page I linked is a bit self-contradictory, in that it both says that a leaf is and is not a pytree. Seems like it should be debugged, but hopefully the intent is clear if read in its totality!)

mattjj commented 3 years ago

(Will try to remove the contradiction in the docs in #4739.)