Closed patrick-kidger closed 5 months ago
This is probably my bug; I'll look into it. I am pretty sure it's happening because the tracer is being converted into an argument instead of being embedded into `consts` in ClosedJaxpr, since it is closed over in f.
This is because `make_jaxpr` previously embedded the `Tracer` into consts, which was wrong. I changed it to unify the codepaths of make_jaxpr and `jit(f).specialize(*args).jaxpr`. The behavior is correct but we should fix this bug. Thanks for the concise repro!
Actually, thinking about it more: tracers should be passed as arguments, so if you change `g` above to take the tracer as an argument, the error should go away.
So maybe the fix should be in diffrax?
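As a minimal sketch of that suggestion (not the diffrax code itself), the closed-over value can be turned into an explicit argument of `g`. The names `f`, `g`, `run`, and `some_tracer` follow the repro in this thread; the extra parameter `t` and the module-level `x` are introduced here for illustration, and the sketch assumes a JAX version that exposes `jax.core.eval_jaxpr`:

```python
import jax
import jax.numpy as jnp

x = jnp.asarray(1)  # an ordinary, non-traced input


def f(x, y):
    return x + y


def run(some_tracer):
    # `g` receives the would-be closed-over value as an explicit argument `t`,
    # so nothing traced has to be captured from the enclosing scope.
    g = lambda x, t: f(x, t)
    jaxpr = jax.make_jaxpr(g)(x, some_tracer)
    # The tracer is supplied again at evaluation time, alongside the real arg.
    (out,) = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, some_tracer)
    return out


result = jax.vmap(run)(jnp.arange(2))
```

This only works when the caller controls the signature of `g`, which is exactly what a library accepting user-provided functions cannot assume.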
How should that be done? `g` is a user-provided function. I don't think there is any way to extract the closed-over tracer. Previously this was in `jax.make_jaxpr(g).consts`, but that is now an empty list.
Now it's in `jaxpr.invars`. Would it be enough if I gave you a count of how many closed-over tracers there are and how many real args exist, so that you can make things work in diffrax?
Note that we have this assert that is currently disabled but that we want to enable: https://github.com/google/jax/blob/f768cb74b94ab36587a8930be8afe8a34460ca6b/jax/_src/core.py#L208
Hmm, I think I'm only seeing the `Var` objects:
jaxpr.jaxpr.invars  # [Var(id=4790894848):int32[], Var(id=4790895168):int32[]]
but not any way of grabbing `some_tracer` out of `jaxpr`.
To approach this a different way: in <=0.4.26, it is the case that for all functions `g`, the following code will work:
jaxpr = jax.make_jaxpr(g)(*args)
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
Which I think is quite a nice invariant, actually! Now that this is no longer the case, what is the equivalent invariant for >=0.4.27? :)
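For reference, here is a small sketch of that invariant in the simplest case, where the closed-over value is a concrete array rather than a tracer. The function `g` and the array `offset` are made up for illustration, and the sketch assumes a JAX version that exposes `jax.core.eval_jaxpr`:

```python
import jax
import jax.numpy as jnp

# A concrete (non-traced) array that `g` closes over.
offset = jnp.array([10.0, 20.0, 30.0])


def g(x):
    return x * 2.0 + offset


x = jnp.ones(3)

# Trace `g` to a ClosedJaxpr, then evaluate the jaxpr with its own consts.
jaxpr = jax.make_jaxpr(g)(x)
(out,) = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)

# Evaluating the jaxpr with its consts reproduces a direct call to `g`.
direct = g(x)
```

The question in this thread is what the analogous round trip looks like once the closed-over value is itself a tracer.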
In `jaxpr.jaxpr.invars`, the first `Var` is the closed-over tracer. So if I can give you a count of how many tracers come before the actual args start, would that be enough?
I don't think so. At least right now, I don't see a way to grab the tracer out of the `Var`:
(Pdb) jaxpr.jaxpr.invars[0]
Var(id=4790894848):int32[]
(Pdb) dir(jaxpr.jaxpr.invars[0])
['__annotations__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', 'aval', 'count', 'suffix']
(Pdb) jaxpr.jaxpr.invars[0].aval
ShapedArray(int32[])
(Pdb) jaxpr.jaxpr.invars[0].count
14
(Pdb) jaxpr.jaxpr.invars[0].suffix
''
# no tracers here!
What am I missing?
Sorry, I meant that the first Var is the tracer (which would have been in consts before). Right now there is no indication that it is a tracer unless I give you more information from JAX. That is what my question was about: would that be enough for you to fix diffrax?
Before:
In [1]: import jax
   ...: import jax.numpy as jnp
   ...:
   ...: def run(some_tracer):
   ...:     def f(x, y):
   ...:         return x + y
   ...:
   ...:     g = lambda x: f(x, some_tracer)
   ...:     jaxpr = jax.make_jaxpr(g)(1)
   ...:     print(jaxpr.in_avals, jaxpr.consts)
   ...:     jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)
   ...:
   ...: jax.vmap(run)(jnp.arange(2))
[ShapedArray(int32[], weak_type=True)] [Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
val = Array([0, 1], dtype=int32)
batch_dim = 0]
Now:
In [2]: import jax
   ...: import jax.numpy as jnp
   ...:
   ...: def run(some_tracer):
   ...:     def f(x, y):
   ...:         return x + y
   ...:
   ...:     g = lambda x: f(x, some_tracer)
   ...:     jaxpr = jax.make_jaxpr(g)(1)
   ...:     print(jaxpr.in_avals, jaxpr.consts)
   ...:     jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)
   ...:
   ...: jax.vmap(run)(jnp.arange(2))
[ShapedArray(int32[]), ShapedArray(int32[], weak_type=True)] []
Right! But I don't need to know which `Var`s correspond to tracers; I need the actual `Tracer` object itself. I don't think this is available anywhere any more.
Sorry, I think I'm realising where the confusion is coming from: I'm not calling `make_jaxpr` to get just the jaxpr. My goal here is to take an arbitrary user-provided function, which may include closed-over tracers, and to perform closure conversion to get both a pure function (`.jaxpr`) and all the closed-over values (`.consts`). Both are needed for the subsequent `eval_jaxpr` call. This closure conversion is what is no longer possible at all.
(Note that `jax.closure_convert` isn't suitable, at least right now: it's branded as an AD-specific tool, and as such only closure-converts with respect to floating point arrays. In general we actually need more than that, e.g. vmap'd integer arrays. IIRC this was deliberately removed from `jax.closure_convert` a couple of years ago for some performance-related reason. I think Matt might recall?)
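For readers unfamiliar with it, a rough sketch of how `jax.closure_convert` is called is given below. The names `run`, `g`, `scale`, and `x0` are illustrative only. The example closes over a floating-point value traced by `jax.jit`, which is the kind of value the function is described above as handling; which closed-over values end up in the returned list is decided by JAX internals and is not asserted here:

```python
import jax
import jax.numpy as jnp


def run(scale):
    # `g` closes over `scale`, which is a traced float32 value under `jit`.
    g = lambda x: x * scale

    x0 = jnp.float32(1.0)  # example argument used to trace `g`
    # Returns a function taking the original args followed by hoisted values,
    # plus the list of values that were hoisted out of the closure.
    converted_g, hoisted = jax.closure_convert(g, x0)
    return converted_g(x0, *hoisted)


result = jax.jit(run)(jnp.float32(3.0))
```

Calling `converted_g(x0, *hoisted)` should give the same value as `g(x0)`; the limitation raised in this thread is about integer-dtyped closed-over tracers, which this sketch does not exercise.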
Taking a step back from what's needed to solve this one particular bug in the short term: in general, abstract evaluation can return four things: the jaxpr, the output pytree/avals, the effects, and the closed-over values. Right now, these pieces are incompletely mix-and-match'd across the public API. The state of affairs is:
                      | Jaxpr | Output Pytree/Avals | Effects | Closed-Over Values
----------------------|-------|---------------------|---------|--------------------
`jax.make_jaxpr`      |   x   |          x          |         |        (x)
`jax.eval_shape`      |       |          x          |         |
`jax.closure_convert` |   ~   |                     |         |         ~
Where an `x` means it gives you everything, a `~` means you only get some of it, and `(x)` indicates that this is the capability we lost with the 0.4.27 release.
Speculating: perhaps this is a state of affairs that could be simplified?
(Side note, credit to Gemini for kindly making this table for me ^^ )
> I need the actual Tracer object itself

Why do you need the actual tracer object? Embedding tracers into consts was a mistake to begin with. That should never have happened, and this assert needs to be enabled: https://github.com/google/jax/blob/f768cb74b94ab36587a8930be8afe8a34460ca6b/jax/_src/core.py#L208 which would make it so.
> Speculating: perhaps this is a state of affairs that could be simplified?

Eventually, we are going to merge make_jaxpr and eval_shape into `jit(f).specialize(*args).jaxpr | .out_shapes`. Currently both eval_shape and jaxpr are available on `jitted` functions.
https://github.com/google/jax/pull/21140 should roll it back. We are going to do another release tomorrow, so this should be fixed in 0.4.28
That said I think we're still trying to figure out the path forward here, this rollback is mostly to unbreak you.
> Eventually, we are going to merge make_jaxpr and eval_shape into `jit(f).specialize(*args).jaxpr | .out_shapes`

Ah, interesting! I like the unifying of this. Is this something where we'll be able to grab the `out_shapes` without recording the `jaxpr`? (Which IIUC is the advantage of `eval_shape` over `make_jaxpr(..., return_shape=True)` today?)
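To make the comparison concrete, here is a small sketch of the two existing entry points mentioned above. The function `h` and its input are invented for illustration; the sketch only shows that both calls report the same output structure, and makes no claim about what each one records internally:

```python
import jax
import jax.numpy as jnp


def h(x):
    # Returns a pytree (a dict) so the output structure is visible.
    return {"sum": x.sum(axis=0), "outer": x[:, None] * x[None, :]}


x = jnp.ones(3)

# Output shapes/dtypes only.
shapes_only = jax.eval_shape(h, x)

# Jaxpr plus output shapes/dtypes in one call.
closed_jaxpr, shapes_with_jaxpr = jax.make_jaxpr(h, return_shape=True)(x)
```

Both results are pytrees of `jax.ShapeDtypeStruct` with the same structure as the output of `h`.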
> Why do you need the actual tracer object? Embedding tracers into consts was a mistake to begin with.

I have no preference on whether tracers are placed in jaxprs at all. E.g. I would be equally happy to get them via something looking like `jaxpr, closed_over_values = jax.make_jaxpr(g, closed_over_values=True)`.
The actual tracer object is required to perform closure conversion prior to crossing the boundaries of higher order primitives, custom AD etc.
> That said I think we're still trying to figure out the path forward here, this rollback is mostly to unbreak you.

Thank you! I really appreciate this.
I think if you're looking for a concrete suggestion on a path forward from me:
- short term: adjust `jax.closure_convert` to grab all constants, by having this unconditionally return `True`:
  https://github.com/google/jax/blob/1e88e2f86298e85269e15ea24e797a0f56386ad8/jax/_src/custom_derivatives.py#L1122
  IIUC the original rationale for this filtering was performance, so that `custom_vjp` wouldn't need to compute unnecessary cotangents (#6415). However for that purpose we now have symbolic zeros available in custom AD, so this should no longer be a concern.
  Then I can rewrite my code to perform closure conversion via this function instead.
- long term: have the `jit(f).specialize` interface also provide effects and closed-over values. Then many other things (including `jax.closure_convert`!) could be implemented through this.
What if we expose `jex.trace_to_jaxpr` and you can use that to get the behavior you want, and we can roll forward with the new behavior where tracers won't be in consts?
cc @froystig @mattjj
We just released jax 0.4.28, which has the rollback.
> What if we expose `jex.trace_to_jaxpr` and you can use that to get the behavior you want and we can roll forward with the new behavior where tracers won't be in consts?

Do you mean `jax.interpreters.partial_eval.trace_to_jaxpr`? This acts on a lot of JAX-internal types though. (`WrappedFun`, `PartialVal`, `Value`.)
If you like, I could write a quick PR adding a flag to `jax.closure_convert` that hoists integer-dtyped tracers as well?
> We just released jax 0.4.28, which has the rollback.

Thank you! I appreciate it.
Rolling forward with https://github.com/google/jax/pull/21734, which also fixes the reported error. So we should be good to go without any changes on your side.
Wonderful stuff! Thank you so much :)
Haha.. no problem.
Description
produces:
System info (python version, jaxlib version, accelerator, etc.)
JAX 0.4.27