jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Crash in `eval_jaxpr` with 0.4.27 #21116

Closed patrick-kidger closed 5 months ago

patrick-kidger commented 6 months ago

Description

import jax
import jax.numpy as jnp

def run(some_tracer):
    def f(x, y):
        return x + y

    g = lambda x: f(x, some_tracer)
    jaxpr = jax.make_jaxpr(g)(1)
    jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)

jax.vmap(run)(jnp.arange(2))

produces:

Traceback (most recent call last):
  File ".../file.py", line 13, in <module>
    jax.vmap(run)(jnp.arange(2))
  File ".../file.py", line 11, in run
    jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)
ValueError: safe_map() argument 2 is shorter than argument 1

System info (python version, jaxlib version, accelerator, etc.)

JAX 0.4.27

yashk2810 commented 6 months ago

This is probably my bug; I'll look into it. I am pretty sure it's happening because the tracer is being converted into an argument instead of being embedded into the consts of the ClosedJaxpr, since it is being closed over in f.

yashk2810 commented 6 months ago

This is because make_jaxpr previously embedded the Tracer into consts, which was wrong. I changed this to unify the codepaths of make_jaxpr and jit(f).specialize(*args).jaxpr. The new behavior is correct, but we should fix this bug. Thanks for the concise repro!

yashk2810 commented 6 months ago

Actually, thinking about it more, tracers should be passed as arguments, so if you change g above to take the tracer as an argument, the error should go away.

So maybe the fix should be in diffrax?
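
A minimal sketch of the suggestion above, assuming the caller controls the function's signature (here f is traced directly, so the tracer can be handed over as an explicit argument rather than closed over). The names mirror the repro; the return value is added only so the result can be inspected.

```python
import jax
import jax.numpy as jnp

def run(some_tracer):
    def f(x, y):
        return x + y

    # The tracer is an explicit argument, so it shows up in the jaxpr's
    # inputs and is supplied again at evaluation time.
    jaxpr = jax.make_jaxpr(f)(1, some_tracer)
    (out,) = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1, some_tracer)
    return out

out = jax.vmap(run)(jnp.arange(2))
print(out)  # [1 2]
```

This only helps when the function's arguments can be changed; it does not cover a function that already closes over the tracer.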

patrick-kidger commented 6 months ago

How should that be done?

g is a user-provided function. I don't think there is any way to extract the closed-over tracer. Previously this was in jax.make_jaxpr(g).consts, but that is now an empty list.

yashk2810 commented 6 months ago

Now it's in jaxpr.invars. Would it be enough if I gave you a count of how many closed-over tracers there are and how many real args exist, so that you can make things work in diffrax?

Note that we have this assert that is disabled, but we want to enable it: https://github.com/google/jax/blob/f768cb74b94ab36587a8930be8afe8a34460ca6b/jax/_src/core.py#L208

patrick-kidger commented 6 months ago

Hmm, I think I'm only seeing the Var objects:

jaxpr.jaxpr.invars  # [Var(id=4790894848):int32[], Var(id=4790895168):int32[]]

but not any way of grabbing some_tracer out of jaxpr.


To approach this a different way: in <=0.4.26, it is the case that for all functions g, the following code will work:

jaxpr = jax.make_jaxpr(g)(*args)
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)

Which I think is quite a nice invariant, actually! Now that this is no longer the case, what is the equivalent invariant for >=0.4.27? :)
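
As a small illustration of that invariant, here is a sketch using a function that closes over a concrete array rather than a tracer (the function g, the array offset, and the inputs are made up for the example). Whatever make_jaxpr places in consts is passed straight back to eval_jaxpr along with the original args:

```python
import jax
import jax.numpy as jnp

offset = jnp.arange(3.0)  # a concrete (non-tracer) closed-over value

def g(x):
    return x * 2.0 + offset

args = (jnp.ones(3),)

# The invariant: trace with make_jaxpr, then replay with eval_jaxpr
# using the returned jaxpr, its consts, and the same args.
closed = jax.make_jaxpr(g)(*args)
outs = jax.core.eval_jaxpr(closed.jaxpr, closed.consts, *args)

# eval_jaxpr returns a list of flat outputs; g has exactly one.
assert jnp.allclose(outs[0], g(*args))
```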

yashk2810 commented 6 months ago

jaxpr.jaxpr.invars

The first Var is the closed-over tracer. So if I can give you the count of how many tracers there are before the actual args start, would that be enough?

patrick-kidger commented 6 months ago

I don't think so. At least right now, I don't see a way to grab the tracer out of the Var:

(Pdb) jaxpr.jaxpr.invars[0]
Var(id=4790894848):int32[]
(Pdb) dir(jaxpr.jaxpr.invars[0])
['__annotations__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', 'aval', 'count', 'suffix']
(Pdb) jaxpr.jaxpr.invars[0].aval
ShapedArray(int32[])
(Pdb) jaxpr.jaxpr.invars[0].count
14
(Pdb) jaxpr.jaxpr.invars[0].suffix
''
# no tracers here!

What am I missing?

yashk2810 commented 6 months ago

Sorry, I meant that the first Var is the tracer (which would have been in consts before). There is no indication of that being a tracer right now, unless I give you more information from JAX. That was my question: would that be enough for you to fix diffrax?

Before:

In [1]: import jax
   ...: import jax.numpy as jnp
   ...:
   ...: def run(some_tracer):
   ...:     def f(x, y):
   ...:         return x + y
   ...:
   ...:     g = lambda x: f(x, some_tracer)
   ...:     jaxpr = jax.make_jaxpr(g)(1)
   ...:     print(jaxpr.in_avals, jaxpr.consts)
   ...:     jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)
   ...:
   ...: jax.vmap(run)(jnp.arange(2))

[ShapedArray(int32[], weak_type=True)] [Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
  val = Array([0, 1], dtype=int32)
  batch_dim = 0]

Now:

In [2]: import jax
   ...: import jax.numpy as jnp
   ...:
   ...: def run(some_tracer):
   ...:     def f(x, y):
   ...:         return x + y
   ...:
   ...:     g = lambda x: f(x, some_tracer)
   ...:     jaxpr = jax.make_jaxpr(g)(1)
   ...:     print(jaxpr.in_avals, jaxpr.consts)
   ...:     jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)
   ...:
   ...: jax.vmap(run)(jnp.arange(2))
[ShapedArray(int32[]), ShapedArray(int32[], weak_type=True)] []

patrick-kidger commented 6 months ago

Right! But I don't need to know which Vars correspond to tracers, I need the actual Tracer object itself. I don't think this is available anywhere any more.

Sorry, I think I'm realising where the confusion is coming from: I'm not calling make_jaxpr to get just the jaxpr. My goal here is to take an arbitrary user-provided function, which may include closed-over tracers, and to perform closure conversion to get both a pure function (.jaxpr) and all the closed-over values (.consts). Both are needed for the subsequent eval_jaxpr call. This closure-conversion is what is no longer possible at all.

(Note that jax.closure_convert isn't suitable, at least right now -- it's branded as an AD-specific tool, and as such only closure-converts with respect to floating point arrays. In general we actually need more than that, e.g. vmap'd integer arrays. IIRC this was deliberately removed from jax.closure_convert a couple of years ago for some performance-related reason. I think Matt might recall?)
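
For reference, a minimal sketch of what jax.closure_convert does in the floating-point case it does support. The function run and the value names are illustrative; the call returns a converted function plus the hoisted closed-over values, which are then passed back explicitly:

```python
import jax
import jax.numpy as jnp

def run(scale):
    # While `run` is being vmapped, `scale` is a floating-point tracer
    # that `g` closes over.
    g = lambda x: x * scale

    # closure_convert returns a function taking the hoisted values as
    # extra trailing arguments, together with those hoisted values.
    converted_g, hoisted = jax.closure_convert(g, 1.0)
    return converted_g(1.0, *hoisted)

out = jax.vmap(run)(jnp.arange(3.0))
assert jnp.allclose(out, jnp.arange(3.0))
```

Per the note above, an integer-dtyped closed-over tracer would not be hoisted this way, which is why this is not a drop-in replacement for the make_jaxpr-based approach.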

Taking a step back from what's needed to solve this one particular bug in the short term: in general, abstract evaluation can return four things: the jaxpr, the output pytree/avals, the effects, and the closed-over values. Right now, these pieces are incompletely mix-and-match'd across the public API. The state of affairs is:

                      | Jaxpr | Output Pytree/Avals | Effects | Closed-Over Values
----------------------|-------|---------------------|---------|--------------------
`jax.make_jaxpr`      |   x   |         x           |         |      (x)
`jax.eval_shape`      |       |         x           |         |
`jax.closure_convert` |   ~   |                     |         |       ~

Where an x means it gives you everything and a ~ means you only get some of it, and (x) indicates that this is the capability we lost with the 0.4.27 release. Speculating: perhaps this is a state of affairs that could be simplified?

(Side note, credit to Gemini for kindly making this table for me ^^ )

yashk2810 commented 6 months ago

I need the actual Tracer object itself

Why do you need the actual tracer object? Embedding tracers into consts was a mistake to begin with. That should have never happened and this assert needs to be enabled: https://github.com/google/jax/blob/f768cb74b94ab36587a8930be8afe8a34460ca6b/jax/_src/core.py#L208 which would make it so.

Speculating: perhaps this is a state of affairs that could be simplified?

Eventually, we are going to merge make_jaxpr and eval_shape into jit(f).specialize(*args).jaxpr | .out_shapes. Currently both eval_shape and jaxpr are available on jitted functions.

yashk2810 commented 6 months ago

https://github.com/google/jax/pull/21140 should roll it back. We are going to do another release tomorrow, so this should be fixed in 0.4.28

hawkinsp commented 6 months ago

That said I think we're still trying to figure out the path forward here, this rollback is mostly to unbreak you.

patrick-kidger commented 6 months ago

Eventually, we are going to merge make_jaxpr and eval_shape into jit(f).specialize(*args).jaxpr | .out_shapes

Ah, interesting! I like the unifying of this. Is this something where we'll be able to grab the out_shapes without recording the jaxpr? (Which IIUC is the advantage of eval_shape over make_jaxpr(..., return_shape=True) today?)
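
For context, a short sketch of the two existing ways to obtain output shapes mentioned here. The function f and its input are made up for illustration:

```python
import jax
import jax.numpy as jnp

def f(x):
    return {"sum": x.sum(), "outer": jnp.outer(x, x)}

x = jnp.ones(4, dtype=jnp.float32)

# Shapes and dtypes only.
shapes_only = jax.eval_shape(f, x)

# The jaxpr together with the same shape/dtype information.
closed_jaxpr, shapes_with_jaxpr = jax.make_jaxpr(f, return_shape=True)(x)

print(shapes_only["outer"].shape, shapes_with_jaxpr["outer"].shape)
```

Both calls return a pytree of jax.ShapeDtypeStruct objects matching the structure of f's output.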

Why do you need the actual tracer object? Embedding tracers into consts was a mistake to begin with.

I have no preference on whether tracers are placed in jaxprs at all. E.g. I would be equally happy to get them via something looking like jaxpr, closed_over_values = jax.make_jaxpr(g, closed_over_values=True).

The actual tracer object is required to perform closure conversion prior to crossing the boundaries of higher order primitives, custom AD etc.

That said I think we're still trying to figure out the path forward here, this rollback is mostly to unbreak you.

Thank you! I really appreciate this.

I think if you're looking for a concrete suggestion on a path forward from me:

yashk2810 commented 6 months ago

What if we expose jex.trace_to_jaxpr and you can use that to get the behavior you want and we can roll forward with the new behavior where tracers won't be in consts?

cc @froystig @mattjj

hawkinsp commented 6 months ago

We just released jax 0.4.28, which has the rollback.

patrick-kidger commented 6 months ago

What if we expose jex.trace_to_jaxpr and you can use that to get the behavior you want and we can roll forward with the new behavior where tracers won't be in consts?

Do you mean jax.interpreters.partial_eval.trace_to_jaxpr? This acts on a lot of JAX-internal types though. (WrappedFun, PartialVal, Value.)

If you like, I could write a quick PR adding a flag to jax.closure_convert that hoists integer-dtyped tracers as well?

We just released jax 0.4.28, which has the rollback.

Thank you! I appreciate it.

yashk2810 commented 5 months ago

Rolling forward with https://github.com/google/jax/pull/21734 and fixed the reported error too. So we should be good to go without any changes on your side.

patrick-kidger commented 5 months ago

Wonderful stuff! Thank you so much :)

yashk2810 commented 5 months ago

Haha.. no problem.