PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Bug with batching rule for `gather` not implemented #894

Open erick-xanadu opened 5 days ago

erick-xanadu commented 5 days ago

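Running:

import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml
import pytest
from catalyst import accelerate, jacobian, pure_callback


@pytest.mark.parametrize("arg", [jnp.array([[0.1, 0.2], [0.3, 0.4]])])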
@pytest.mark.parametrize("order", ["good", "bad"])
def test_vjp_as_residual(arg, order):
    """See https://github.com/PennyLaneAI/catalyst/issues/852"""

    def jax_callback(fn, result_type):

        @pure_callback
        def callback_fn(*args) -> result_type:
            return fn(*args)

        @callback_fn.fwd
        def callback_fn_fwd(*args):
            ans, vjp_func = accelerate(lambda *x: jax.vjp(fn, *x))(*args)
            return ans, vjp_func

        @callback_fn.bwd
        def callback_fn_bwd(vjp_func, dy):
            return accelerate(vjp_func)(dy)

        return callback_fn

    @qml.qjit
    @jacobian
    def hypothesis(x):
        expm = jax_callback(jax.scipy.linalg.expm, jax.ShapeDtypeStruct((2, 2), jnp.float64))
        return expm(x)

    @jax.jacobian
    def ground_truth(x):
        return jax.scipy.linalg.expm(x)

    obs = hypothesis(arg)
    exp = ground_truth(arg)
    assert np.allclose(obs, exp)

raises the following error:

E     NotImplementedError: Batching rule for 'gather' not implemented

../env/lib/python3.10/site-packages/jax/_src/api.py:945: NotImplementedError
====================================================================================== warnings summary ======================================================================================
frontend/catalyst/jax_extras/tracing.py:92
  /home/ubuntu/code/catalyst/frontend/catalyst/jax_extras/tracing.py:92: DeprecationWarning: jax.linear_util.transformation_with_aux is deprecated. Use jax.extend.linear_util.transformation_with_aux instead.
    from jax.linear_util import transformation_with_aux, wrap_init

frontend/catalyst/jax_extras/tracing.py:92
  /home/ubuntu/code/catalyst/frontend/catalyst/jax_extras/tracing.py:92: DeprecationWarning: jax.linear_util.wrap_init is deprecated. Use jax.extend.linear_util.wrap_init instead.
    from jax.linear_util import transformation_with_aux, wrap_init

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================== short test summary info ===================================================================================
FAILED frontend/test/pytest/test_callback.py::test_vjp_as_residual[good-arg0] - NotImplementedError: Batching rule for 'gather' not implemented
FAILED frontend/test/pytest/test_callback.py::test_vjp_as_residual[bad-arg0] - NotImplementedError: Batching rule for 'gather' not implemented

But swapping the last two calls, so that ground_truth runs before hypothesis:

    exp = ground_truth(arg)
    obs = hypothesis(arg)

makes the test pass.

Possibly related to #733

erick-xanadu commented 5 days ago

The following patch appears to solve the issue, but I don't like it:

diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py
index eeb7334f..8cfdf098 100644
--- a/frontend/catalyst/jax_extras/tracing.py
+++ b/frontend/catalyst/jax_extras/tracing.py
@@ -504,11 +504,8 @@ def make_jaxpr2(
     )
     register_lowering(gather2_p, _gather_lower)

-    primitive_batchers2 = jax._src.interpreters.batching.primitive_batchers.copy()
-    for primitive in jax._src.interpreters.batching.primitive_batchers.keys():
-        if primitive.name == "gather":
-            gather_batching_rule = jax._src.interpreters.batching.primitive_batchers[primitive]
-            primitive_batchers2[gather2_p] = gather_batching_rule
+    from jax._src.lax.slicing import _gather_batching_rule
+    jax._src.interpreters.batching.primitive_batchers[gather2_p] = _gather_batching_rule

     @wraps(fun)
     def make_jaxpr_f(*args, **kwargs):
@@ -516,7 +513,6 @@ def make_jaxpr2(
         with Patcher(
             (jax._src.interpreters.partial_eval, "get_aval", get_aval2),
             (jax._src.lax.slicing, "gather_p", gather2_p),
-            (jax._src.interpreters.batching, "primitive_batchers", primitive_batchers2),
         ), ExitStack():
             f = wrap_init(fun)
             if static_argnums:

dime10 commented 5 days ago

So the scope in which we patch the batching rules is not sufficient for the example in this issue?
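
To make the scope question concrete, here is a toy model in plain Python. It is not JAX or Catalyst code: `batching`, `Primitive`, `batch`, and `patcher` are hypothetical stand-ins for JAX's batching module, its primitive class, a vmap-style rule lookup, and Catalyst's `Patcher`. It assumes that the replacement primitive can still be looked up after the patched scope has exited.

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Hypothetical stand-in for a module that holds a global batching-rule table.
batching = SimpleNamespace(primitive_batchers={})


class Primitive:
    """Toy primitive: hashed by identity, so two objects may share one name."""

    def __init__(self, name):
        self.name = name


def batch(prim):
    """Look up a batching rule the way a vmap-style transform might."""
    try:
        return batching.primitive_batchers[prim]
    except KeyError:
        raise NotImplementedError(
            f"Batching rule for '{prim.name}' not implemented"
        ) from None


@contextmanager
def patcher(obj, attr, value):
    """Temporarily replace obj.attr and restore the old value on exit."""
    saved = getattr(obj, attr)
    setattr(obj, attr, value)
    try:
        yield
    finally:
        setattr(obj, attr, saved)


gather_p = Primitive("gather")
gather2_p = Primitive("gather")  # distinct object, same name
batching.primitive_batchers[gather_p] = lambda: "batched gather"

# Scoped approach: a copied table with one extra entry, visible only inside `with`.
patched = dict(batching.primitive_batchers)
patched[gather2_p] = patched[gather_p]

with patcher(batching, "primitive_batchers", patched):
    inside = batch(gather2_p)()

# Outside the scope the original table is back and gather2_p has no rule.
try:
    batch(gather2_p)
    outside = "ok"
except NotImplementedError as err:
    outside = str(err)

# Global registration, as in the proposed patch: the rule survives every scope.
batching.primitive_batchers[gather2_p] = batching.primitive_batchers[gather_p]
after_global = batch(gather2_p)()

print(inside)
print(outside)
print(after_global)
```

In this model the error message names `gather` even though the missing entry belongs to the second primitive object, because both objects carry the same name.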

erick-xanadu commented 5 days ago

I think what is happening is that `register_lowering` modifies global state in JAX. I'm not sure exactly under what conditions the code path that raises `NotImplementedError: Batching rule for 'gather' not implemented` is triggered, but the `gather` in question is our `gather2_p`, and it doesn't have a batching rule. If there were an `unregister_lowering` or `unregister_primitive`, maybe we could avoid this, but I haven't found one.
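
As a sketch of what an "unregister" step could look like, the snippet below registers an entry in a global table for the duration of a `with` block and removes it afterwards. Everything here is hypothetical: `temporarily_registered` and the `lowerings` dict are illustrative names, and nothing in this thread confirms that JAX exposes an equivalent hook or that its rule tables can safely be modified this way.

```python
from contextlib import contextmanager

_MISSING = object()  # sentinel: distinguishes "no previous entry" from None


@contextmanager
def temporarily_registered(registry, key, rule):
    """Register `rule` under `key`, then restore the previous state on exit."""
    previous = registry.get(key, _MISSING)
    registry[key] = rule
    try:
        yield registry
    finally:
        if previous is _MISSING:
            # Nothing was registered before, so remove our entry entirely.
            registry.pop(key, None)
        else:
            # Something was registered before, so put it back.
            registry[key] = previous


# Toy stand-in for a global rule table (an assumption, not JAX's real table).
lowerings = {}

with temporarily_registered(lowerings, "gather2", "lower_gather2"):
    during = dict(lowerings)

after = dict(lowerings)

print(during)
print(after)
```

The trade-off is the same as with the scoped patch discussed above: anything that refers to the temporary entry after the block exits will find it missing.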