google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

partial eval silently skips effects #21713

Closed: Alan-Chen99 closed this issue 2 days ago.

Alan-Chen99 commented 1 month ago

Description

```python
import jax
from jax.experimental import checkify
from jax.interpreters import partial_eval

def fn(x):
    checkify.check(x < 10, "checkify")
    jax.debug.print("callback: {}", x)
    return x

jaxpr = jax.make_jaxpr(fn)(1)
# The fourth return value holds the avals of the residuals that the known
# jaxpr passes to the unknown jaxpr.
jaxpr_known, jaxpr_unknown, out_unknowns, res_avals = (
    partial_eval.partial_eval_jaxpr_nounits(jaxpr, unknowns=[True], instantiate=False)
)
print(jaxpr_known, jaxpr_unknown, out_unknowns, res_avals)
```

This prints:

```
{ lambda ; . let _:i32[] = select_n False 1 -1 in () } { lambda ; a:i32[]. let  in (a,) } [True] []
```

I expected jaxpr_unknown to contain the effects (the checkify check and the debug print), but neither half of the split contains them.
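A minimal sketch for checking this directly, assuming only that a ClosedJaxpr exposes its effects through an effects attribute: compare the effects of the traced jaxpr with those of the two halves. It leaves out checkify so that only jax.debug.print is involved, and it imports the internal jax._src.interpreters.partial_eval module, whose contents may change between releases.

```python
import jax
from jax._src.interpreters import partial_eval as pe

def fn(x):
    jax.debug.print("callback: {}", x)
    return x

closed = jax.make_jaxpr(fn)(1)
known, unknown, out_unknowns, res_avals = pe.partial_eval_jaxpr_nounits(
    closed, unknowns=[True], instantiate=False
)

# The traced jaxpr records the debug-callback effect.
print("original:", closed.effects)
# Per the report (jax 0.4.28), neither half carries it after the split.
print("known:   ", known.effects)
print("unknown: ", unknown.effects)
print("out_unknowns:", out_unknowns)
```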

A possible workaround is to wrap the unknown arguments using a dynamic=False jaxpr trace. I also had to change how pjit_p is handled.

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.12.2 (main, Feb 6 2024, 20:19:44) [GCC 13.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
```

mattjj commented 1 month ago

Thanks for raising this!

Actually, I'm not sure if we should call this a 'bug', since this is an internal API. Is there an issue you're seeing in a public API?

This is one of the main reasons partial_eval_jaxpr_stateful exists. The two may be merged someday, but for the moment this is essentially expected behavior for trace_to_jaxpr_nounits.
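For contrast, a sketch of the case partial_eval_jaxpr_nounits is designed for: on an effect-free function (the hypothetical f below) it splits the work by data dependence, running what depends only on known inputs in the known jaxpr and handing the results to the unknown jaxpr as residuals. It again uses the internal jax._src module, so treat it as illustrative rather than stable API usage.

```python
import jax
from jax import lax
from jax._src.interpreters import partial_eval as pe

def f(x, y):
    # sin depends only on x; mul depends on both x and y.
    return lax.mul(lax.sin(x), y)

closed = jax.make_jaxpr(f)(1.0, 2.0)

# Treat x as known and y as unknown.
known, unknown, out_unknowns, res_avals = pe.partial_eval_jaxpr_nounits(
    closed, unknowns=[False, True], instantiate=False
)

known_prims = [eqn.primitive.name for eqn in known.jaxpr.eqns]
unknown_prims = [eqn.primitive.name for eqn in unknown.jaxpr.eqns]
print("known:", known_prims)      # sin can run as soon as x is available
print("unknown:", unknown_prims)  # mul has to wait for y
print("residuals:", res_avals)    # sin(x) is passed across as a residual
```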

Can you say more about what problem you're trying to solve?

mattjj commented 1 month ago

Here's how you would call it, but this is a very internal API so beware:

```python
import jax
from jax.experimental import checkify
from jax._src.interpreters import partial_eval as pe

def fn(x):
    checkify.check(x < 10, "checkify")
    jax.debug.print("callback: {}", x)
    return x

jaxpr = jax.make_jaxpr(fn)(1)
jaxpr_known, jaxpr_unknown, out_unknowns, *_ = pe.partial_eval_jaxpr_stateful(
    jaxpr.jaxpr,
    in_unknowns=[True],
    in_inst=[True],
    ensure_out_unknowns=False,
    ensure_out_inst=False,
    saveable=lambda *_, **__: True,  # policy: every value may be saved as a residual
)
print(jaxpr_known, jaxpr_unknown, out_unknowns)
```

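A sketch for confirming that the stateful variant keeps the callback, listing the primitives that end up in each half. It assumes the same internal call and arguments as the snippet above, leaves out checkify so only jax.debug.print is involved, and primitive_names is a hypothetical helper introduced here for illustration.

```python
import jax
from jax._src.interpreters import partial_eval as pe

def fn(x):
    jax.debug.print("callback: {}", x)
    return x

closed = jax.make_jaxpr(fn)(1)
jaxpr_known, jaxpr_staged, out_unknowns, *_ = pe.partial_eval_jaxpr_stateful(
    closed.jaxpr,
    in_unknowns=[True],
    in_inst=[True],
    ensure_out_unknowns=False,
    ensure_out_inst=False,
    saveable=lambda *_, **__: True,
)

def primitive_names(jaxpr):
    # Hypothetical helper: names of the top-level primitives in an open Jaxpr.
    return [eqn.primitive.name for eqn in jaxpr.eqns]

print("known: ", primitive_names(jaxpr_known))
print("staged:", primitive_names(jaxpr_staged))  # expected to include debug_callback
print("staged effects:", jaxpr_staged.effects)
```

</imports>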
Alan-Chen99 commented 1 month ago

Thank you!

While I'm not using a public API, I'm doing something similar to jax.linearize, which also has this problem:

```python
import jax
from jax import Array
from jax import numpy as jnp
from jax._src import checkify, effects

# effects.custom_derivatives_allowed_effects.add_type(checkify.ErrorEffect)

@jax.custom_jvp
def fn(x: Array):
    return x

@fn.defjvp
def testfn_jvp(primals: tuple[Array, ...], tangents: tuple[Array, ...]):
    (x,) = primals
    (tg,) = tangents

    jax.debug.print("callback: {}", tg)
    # checkify.check(tg < 0, "invalid tangent {}", tg)

    return x, tg

# jvp works (the callback prints the tangent)
# jax.jvp(fn, [jnp.array(1.0)], [jnp.array(2.0)])

# linearize has the same problem: the callback in the JVP rule is dropped
val, tang = jax.linearize(fn, jnp.array(1.0))
tang(jnp.array(2.0))
```

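A variant of this reproduction that is easier to check automatically, sketched under the assumption that recording into a Python list is an acceptable stand-in for printing: jax.debug.callback appends each tangent it sees, so a dropped effect shows up as a missing entry rather than as missing console output. The names f, f_jvp, and calls are illustrative.

```python
import jax

calls = []  # tangents observed by the callback, in call order

@jax.custom_jvp
def f(x):
    return x

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Record the tangent instead of printing it.
    jax.debug.callback(lambda v: calls.append(float(v)), t)
    return x, t

# Eager jvp: the rule runs with a concrete tangent, so the callback fires.
jax.jvp(f, (1.0,), (2.0,))
jax.effects_barrier()
after_jvp = len(calls)

# Linearize, then apply the linear map. Per the report, the callback can be
# lost on this path, in which case `calls` does not grow.
y, f_lin = jax.linearize(f, 1.0)
f_lin(2.0)
jax.effects_barrier()
print("calls after jvp:", after_jvp, "after linearize:", len(calls))
```

If the reported behavior reproduces, both counts printed at the end are equal.
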
mattjj commented 2 days ago

I think we should mark this as resolved, since it's about internal APIs. But I hope this suggestion worked for you.