Open · erick-xanadu opened this issue 5 days ago
The following patch appears to solve the issue, but I don't like it:
diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py
index eeb7334f..8cfdf098 100644
--- a/frontend/catalyst/jax_extras/tracing.py
+++ b/frontend/catalyst/jax_extras/tracing.py
@@ -504,11 +504,8 @@ def make_jaxpr2(
)
register_lowering(gather2_p, _gather_lower)
- primitive_batchers2 = jax._src.interpreters.batching.primitive_batchers.copy()
- for primitive in jax._src.interpreters.batching.primitive_batchers.keys():
- if primitive.name == "gather":
- gather_batching_rule = jax._src.interpreters.batching.primitive_batchers[primitive]
- primitive_batchers2[gather2_p] = gather_batching_rule
+ from jax._src.lax.slicing import _gather_batching_rule
+ jax._src.interpreters.batching.primitive_batchers[gather2_p] = _gather_batching_rule
@wraps(fun)
def make_jaxpr_f(*args, **kwargs):
@@ -516,7 +513,6 @@ def make_jaxpr2(
with Patcher(
(jax._src.interpreters.partial_eval, "get_aval", get_aval2),
(jax._src.lax.slicing, "gather_p", gather2_p),
- (jax._src.interpreters.batching, "primitive_batchers", primitive_batchers2),
), ExitStack():
f = wrap_init(fun)
if static_argnums:
The following patch appears to solve the issue, but I don't like it:
diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py index eeb7334f..8cfdf098 100644 --- a/frontend/catalyst/jax_extras/tracing.py +++ b/frontend/catalyst/jax_extras/tracing.py @@ -504,11 +504,8 @@ def make_jaxpr2( ) register_lowering(gather2_p, _gather_lower) - primitive_batchers2 = jax._src.interpreters.batching.primitive_batchers.copy() - for primitive in jax._src.interpreters.batching.primitive_batchers.keys(): - if primitive.name == "gather": - gather_batching_rule = jax._src.interpreters.batching.primitive_batchers[primitive] - primitive_batchers2[gather2_p] = gather_batching_rule + from jax._src.lax.slicing import _gather_batching_rule + jax._src.interpreters.batching.primitive_batchers[gather2_p] = _gather_batching_rule @wraps(fun) def make_jaxpr_f(*args, **kwargs): @@ -516,7 +513,6 @@ def make_jaxpr2( with Patcher( (jax._src.interpreters.partial_eval, "get_aval", get_aval2), (jax._src.lax.slicing, "gather_p", gather2_p), - (jax._src.interpreters.batching, "primitive_batchers", primitive_batchers2), ), ExitStack(): f = wrap_init(fun) if static_argnums:
So the scope in which we patch the batching rules is not sufficient for the example in this issue?
I think what is happening is that `register_lowering` modifies a global variable in JAX. I'm not sure exactly under what conditions the code path that raises `NotImplementedError: Batching rule for 'gather' not implemented` is triggered. However, the `gather` in that error is our `gather` (`gather2_p`), and it doesn't have a batching rule registered. If there were an `unregister_lowering` or `unregister_primitive` function, maybe we could avoid this, but I haven't found one.
Running:
raises the following error:
But changing the order to
passes the test.
Possibly related to #733