Alan-Chen99 closed this 2 days ago.
Thanks for raising this!
Actually, I'm not sure if we should call this a 'bug', since this is an internal API. Is there an issue you're seeing in a public API?
This is one of the main reasons `partial_eval_jaxpr_stateful` exists. The two may be merged someday, but for the moment this is essentially expected behavior for `trace_to_jaxpr_nounits`.
Can you say more about what problem you're trying to solve?
Here's how you would call it, but this is a very internal API so beware:
```python
import jax
from jax.experimental import checkify
from jax._src.interpreters import partial_eval as pe

def fn(x):
    checkify.check(x < 10, "checkify")
    jax.debug.print("callback: {}", x)
    return x

jaxpr = jax.make_jaxpr(fn)(1)
jaxpr_known, jaxpr_unkown, out_unknowns, *_ = pe.partial_eval_jaxpr_stateful(
    jaxpr.jaxpr,
    in_unknowns=[True],
    in_inst=[True],
    ensure_out_unknowns=False,
    ensure_out_inst=False,
    saveable=lambda *_, **__: True,
)
print(jaxpr_known, jaxpr_unkown, out_unknowns)
```
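A quick way to check the claim above is to look at the `.effects` set that every `Jaxpr` carries. The sketch below is a reduced variant of the snippet (only `jax.debug.print`, no `checkify`, so a single effect is in play) and assumes the internal `partial_eval_jaxpr_stateful` signature used above, which may change between releases. With the only input marked unknown, the callback should be staged into the unknown jaxpr, leaving the known one effect-free.

```python
import jax
from jax._src.interpreters import partial_eval as pe  # internal API, may change

def fn(x):
    jax.debug.print("callback: {}", x)
    return x

closed = jax.make_jaxpr(fn)(1)

# Mark the single input as unknown so every equation that depends on it
# is staged out into the second (unknown) jaxpr.
jaxpr_known, jaxpr_staged, out_unknowns, *_ = pe.partial_eval_jaxpr_stateful(
    closed.jaxpr,
    in_unknowns=[True],
    in_inst=[True],
    ensure_out_unknowns=False,
    ensure_out_inst=False,
    saveable=lambda *_, **__: True,
)

# Each Jaxpr records the effects of the equations it contains.
print("original:", closed.jaxpr.effects)
print("known:   ", jaxpr_known.effects)
print("staged:  ", jaxpr_staged.effects)
```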
Thank you!
While I'm not using a public API, I'm doing something similar to `jax.linearize`, which also has this problem:
```python
import jax
from jax import Array
from jax import numpy as jnp
from jax._src import checkify, effects

# effects.custom_derivatives_allowed_effects.add_type(checkify.ErrorEffect)

@jax.custom_jvp
def fn(x: Array):
    return x

@fn.defjvp
def testfn_jvp(primals: tuple[Array, ...], tangents: tuple[Array, ...]):
    (x,) = primals
    (tg,) = tangents
    jax.debug.print("callback: {}", tg)
    # checkify.check(tg < 0, "invalid tangent {}", tg)
    return x, tg

# jvp works
# jax.jvp(fn, [jnp.array(1.0)], [jnp.array(2.0)])

# jax.linearize shows the same problem
val, tang = jax.linearize(fn, jnp.array(1.0))
tang(jnp.array(2.0))
```
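A minimal sketch for comparing the two paths on the same function using only public APIs: the `custom_jvp` setup mirrors the snippet above (without the commented-out `checkify` lines), and `jax.make_jaxpr` is applied to the linearized function so one can see whether a `debug_callback` equation made it into the tangent computation. The names `fn_jvp`, `t_jvp` and `t_lin` are illustrative only.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def fn(x):
    return x

@fn.defjvp
def fn_jvp(primals, tangents):
    (x,), (tg,) = primals, tangents
    jax.debug.print("callback: {}", tg)  # effect that depends on the tangent
    return x, tg

x, t = jnp.array(1.0), jnp.array(2.0)

# Reference path: jax.jvp evaluates primal and tangent together.
_, t_jvp = jax.jvp(fn, (x,), (t,))

# jax.linearize partially evaluates the JVP with primals known and tangents
# unknown; the returned function evaluates only the tangent part.
_, f_lin = jax.linearize(fn, x)
t_lin = f_lin(t)

# Tracing the linearized function shows whether a debug_callback equation
# survived into the tangent computation.
print(jax.make_jaxpr(f_lin)(t))
```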
I think we should mark this as resolved, since it's about internal APIs. But I hope this suggestion worked for you.
Description
prints
I expected jaxpr_unkown to contain the effects.
Possible workaround: wrap the unknown arguments using a `dynamic=False` jaxpr trace. I also had to change how `pjit_p` is handled.
System info (python version, jaxlib version, accelerator, etc.)
```
jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.12.2 (main, Feb 6 2024, 20:19:44) [GCC 13.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
```