google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

custom_jvp leaks tracers if they're marked as nondiff_argnum #23065

Open jakevdp opened 4 weeks ago

jakevdp commented 4 weeks ago
import jax
import jax.numpy as jnp
from functools import partial

jax.config.update("jax_check_tracer_leaks", True)

@partial(jax.custom_jvp, nondiff_argnums=(1,))
def f(x, indices):
  return x[indices]

@f.defjvp
def f_jvp(indices, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(x, indices), x_dot[indices]

x = jnp.arange(10.0)
indices = jnp.array([1, 3, 5])

jax.jit(jax.jacobian(f))(x, indices)
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-7-6ccf5735404c> in <cell line: 20>()
     18 indices = jnp.array([1, 3, 5])
     19 
---> 20 jax.jit(jax.jacobian(f))(x, indices)

    [... skipping hidden 17 frame]

1 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in new_main(trace_type, dynamic, **payload)
   1201     if t() is not None:
   1202       leaked_tracers = maybe_find_leaked_tracers(t())
-> 1203       if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
   1204 
   1205 @contextmanager

Exception: Leaked trace MainTrace(2,JaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[10]):JaxprTrace(level=2/0)>
<JaxprTracer 132784156284128> is referred to by <list 132784156610688>[0]
<list 132784156610688> is referred to by <frame 132784230502464>
<frame 132784230502464> is referred to by <frame 97010772184064>
<frame 97010772184064> is referred to by <generator 132784155884816>

Traced<ShapedArray(float32[3]):JaxprTrace(level=2/0)>
<JaxprTracer 132784156291008> is referred to by <list 132784155950080>[0]
<list 132784155950080> is referred to by <frame 132784230502464>
<frame 132784230502464> is referred to by <list 132784232647808>[3]
<list 132784232647808> is referred to by <FramesList 132785276236272>._frames
<FramesList 132785276236272> is referred to by <frame 97010772184064>
<frame 97010772184064> is referred to by <generator 132784155884816>

Traced<ShapedArray(float32[3]):JaxprTrace(level=2/0)>
<JaxprTracer 132784156290608> is referred to by <tuple 132784231897008>[0]
<tuple 132784231897008> is referred to by <JaxprEqnRecipe 132784156756448>[1]
<JaxprEqnRecipe 132784156756448> is referred to by <JaxprTracer 132784156291008>

Traced<ShapedArray(int32[3,1]):JaxprTrace(level=2/0)>
<JaxprTracer 132784156287888> is referred to by <tuple 132784155944192>[1]
<tuple 132784155944192> is referred to by <JaxprEqnRecipe 132784156756560>[1]
<JaxprEqnRecipe 132784156756560> is referred to by <JaxprTracer 132784156290608>
<JaxprTracer 132784156290608> is referred to by <tuple 132784231897008>[0]
<tuple 132784231897008> is referred to by <JaxprEqnRecipe 132784156756448>[1]
<JaxprEqnRecipe 132784156756448> is referred to by <JaxprTracer 132784156291008>

The problem here is that indices is a traced array (it is an argument of the jit-compiled function), but it is listed in nondiff_argnums. This is user error, but we should fail with a more informative error than a leaked-tracer report.
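For reference, a minimal sketch of one way to write the same function without nondiff_argnums, assuming it is acceptable to pass indices as an ordinary argument and ignore its tangent in the JVP rule. This is an illustration of the usual pattern for array-valued arguments, not a confirmed fix for the leak check above:

```python
import jax
import jax.numpy as jnp

# Sketch: `indices` is an ordinary argument, so no nondiff_argnums is needed.
@jax.custom_jvp
def f(x, indices):
  return x[indices]

@f.defjvp
def f_jvp(primals, tangents):
  x, indices = primals
  # The tangent for the integer-valued `indices` carries no information,
  # so it is unpacked and ignored.
  x_dot, _ = tangents
  return f(x, indices), x_dot[indices]

x = jnp.arange(10.0)
indices = jnp.array([1, 3, 5])

# jacobian differentiates with respect to argument 0 (x) by default.
jac = jax.jit(jax.jacobian(f))(x, indices)  # shape (3, 10)
```

Each row of jac should select one entry of x, so it can be compared against the matching rows of an identity matrix.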

dfm commented 4 weeks ago

I think this is actually a bug (although @mattjj and @froystig will know better, of course!). If we update f_jvp to:

@f.defjvp
def f_jvp(indices, primals, tangents):
  x, = primals
  x_dot, = tangents
  return x[indices], x_dot[indices]

everything works as it should, without any leaked tracers. I haven't had a chance to dig into why this is, but there's a lot of subtlety around handling recursive patterns like the call to f inside its own JVP rule. Regardless, I don't think we need to require that "nondiff" args be "static".

dfm commented 4 weeks ago

OK maybe I take that back! It's clear from this error message:

https://github.com/google/jax/blob/82d3cfb3c6f88321f0b29b4cc41134a464de82c2/jax/_src/custom_derivatives.py#L656-L662

that custom_vjp, at least, requires that nondiff args not be tracers. I think it also follows that they can't be tracers because they're handled by baking them into the functions, rather than by binding them. Judging from my example above, there are some cases where things might still work, but they shouldn't be expected to.

Perhaps it would be reasonable to add the tracer check from custom_vjp to custom_jvp?
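To make the idea concrete, here is a rough user-level sketch of what such a check does. The helper name check_nondiff_args_not_traced, the error type, and the message are all hypothetical; this is not the code in custom_derivatives.py, and a check this strict would also reject cases that may currently work:

```python
import jax
import jax.numpy as jnp
from jax.core import Tracer

def check_nondiff_args_not_traced(fun_name, nondiff_argnums, args):
  """Hypothetical helper: reject tracers in nondiff_argnums positions."""
  for i in nondiff_argnums:
    if isinstance(args[i], Tracer):
      raise TypeError(
          f"{fun_name}: argument {i} is listed in nondiff_argnums but is a "
          f"traced value ({type(args[i]).__name__}); pass it as a regular "
          "argument instead.")

def f(x, indices):
  # Illustration only: the check runs at call time on the raw arguments.
  check_nondiff_args_not_traced("f", (1,), (x, indices))
  return x[indices]

x = jnp.arange(10.0)
indices = jnp.array([1, 3, 5])

eager_out = f(x, indices)  # concrete arrays pass the check

traced_error = None
try:
  jax.jit(f)(x, indices)   # under jit, indices is a tracer
except TypeError as e:
  traced_error = str(e)
```

With a check like this, the failure is reported at the call site and names the offending argument, instead of surfacing later as a leaked-tracer report.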

Edited to add: These are two relevant tests:

https://github.com/google/jax/blob/82d3cfb3c6f88321f0b29b4cc41134a464de82c2/tests/api_test.py#L7262-L7303

and it looks like we currently expect the custom_jvp to work with some tracers (BatchTracer) but not others. Reading the PR adding that comment https://github.com/google/jax/pull/14263, I think perhaps it makes sense to add that check to custom_jvp, but perhaps there are people depending on the current more relaxed behavior.