PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

New `_sin_lowering` and `_cos_lowering` from JAX fail with dynamic shapes #972

Open tzunghanjuang opened 1 month ago

tzunghanjuang commented 1 month ago

Issue description

After updating the JAX and MLIR dependency chain to v0.4.28 (PR #931), we pick up JAX's new `_sin_lowering` and `_cos_lowering`, which fail with dynamic shapes.

In the following code from `jax._src.lax.lax`, the call to `mlir.lower_fun` (taken for complex inputs) triggers the error. To work around this, we temporarily patch these lowerings with the old-version function, `_nary_lower_hlo`.

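```python
# Excerpt from jax/_src/lax/lax.py (JAX v0.4.28).
# `_sin_complex` and `_nary_lower_hlo` are defined in the same module.
def _sin_lowering(ctx, x):
  if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating):
    # Complex inputs: lower the Python helper `_sin_complex` via
    # mlir.lower_fun. This is the path that fails with dynamic shapes.
    sine = mlir.lower_fun(_sin_complex, multiple_results=False)
    return sine(ctx, x)
  # Real inputs: lower directly to the elementwise StableHLO sine op.
  return _nary_lower_hlo(hlo.sine, ctx, x)
```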
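A minimal sketch of what such a temporary patch could look like, assuming the private JAX v0.4.28 modules that appear in the traceback (`jax._src.interpreters.mlir`, `jax._src.lax.lax`); it is not necessarily how Catalyst applies the patch. It re-registers the lowering rules for `lax.sin_p` and `lax.cos_p` so that every dtype, complex included, goes through `_nary_lower_hlo` again instead of `mlir.lower_fun`.

```python
from functools import partial

from jax import lax
from jax._src.interpreters import mlir
from jax._src.lax.lax import _nary_lower_hlo
from jax._src.lib.mlir.dialects import hlo

# Assumption: private JAX v0.4.28 APIs, which may change between releases.
# Restore the old behaviour: emit the elementwise StableHLO op directly,
# without re-tracing a Python helper through mlir.lower_fun, so dynamic
# dimensions are never rebuilt into a ShapedArray.
mlir.register_lowering(lax.sin_p, partial(_nary_lower_hlo, hlo.sine))
mlir.register_lowering(lax.cos_p, partial(_nary_lower_hlo, hlo.cosine))
```

Because `register_lowering` replaces the rule globally, the patch has to run before anything is lowered, and it also drops the new complex-dtype lowering for programs whose shapes are static.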

Relevant JAX commit:

https://github.com/google/jax/commit/6d8b3e4cff97d966e56670e70957334885439b76

Source code and tracebacks

Example: https://github.com/PennyLaneAI/catalyst/blob/5fa4b21922ab1e7beb8f83cd1a8daf4b0c298c95/frontend/test/pytest/test_jax_dynamic_api.py#L140-L157

Trace:

FAILED    [ 50%]
frontend/test/pytest/test_jax_dynamic_api.py:139 (test_classical_tracing_unary_ops[sin])
>   assert_array_and_dtype_equal(f(shape), op(jnp.ones(shape, dtype)))

test_jax_dynamic_api.py:157: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/pennylane/logging/decorators.py:61: in wrapper_entry
    return func(*args, **kwargs)
../../catalyst/jit.py:454: in __call__
    requires_promotion = self.jit_compile(args)
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/pennylane/logging/decorators.py:61: in wrapper_entry
    return func(*args, **kwargs)
../../catalyst/jit.py:525: in jit_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args)
../../catalyst/debug/instruments.py:143: in wrapper
    return fn(*args, **kwargs)
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/pennylane/logging/decorators.py:61: in wrapper_entry
    return func(*args, **kwargs)
../../catalyst/jit.py:587: in capture
    jaxpr, out_type, treedef = trace_to_jaxpr(
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/pennylane/logging/decorators.py:61: in wrapper_entry
    return func(*args, **kwargs)
../../catalyst/jax_tracer.py:531: in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
../../catalyst/jax_extras/tracing.py:555: in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

>   return op(jnp.ones(s, dtype))
E   jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: Shapes must be 1D sequences of integer scalars, got (Var(id=133767151424768):int32[], Var(id=133767151424832):int32[])
E   
E   The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
E   
E   --------------------

test_jax_dynamic_api.py:155: JaxStackTraceBeforeTransformation

The above exception was the direct cause of the following exception:

op = <PjitFunction of <function jax.numpy.sin at 0x79a9267618a0>>

    @pytest.mark.parametrize(
        "op",
        [
            jnp.sin,
            jnp.abs,
        ],
    )
    def test_classical_tracing_unary_ops(op):
        """Test that tensor primitives work with basic unary operations"""

        shape = (3, 4)
        dtype = complex

        @qjit
        def f(s):
            return op(jnp.ones(s, dtype))

>       assert_array_and_dtype_equal(f(shape), op(jnp.ones(shape, dtype)))

test_jax_dynamic_api.py:157: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/pennylane/logging/decorators.py:61: in wrapper_entry
    return func(*args, **kwargs)
../../catalyst/jit.py:454: in __call__
    requires_promotion = self.jit_compile(args)
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/pennylane/logging/decorators.py:61: in wrapper_entry
    return func(*args, **kwargs)
../../catalyst/jit.py:527: in jit_compile
    self.mlir_module, self.mlir = self.generate_ir()
../../catalyst/debug/instruments.py:143: in wrapper
    return fn(*args, **kwargs)
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/pennylane/logging/decorators.py:61: in wrapper_entry
    return func(*args, **kwargs)
../../catalyst/jit.py:602: in generate_ir
    mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/pennylane/logging/decorators.py:61: in wrapper_entry
    return func(*args, **kwargs)
../../catalyst/jax_tracer.py:558: in lower_jaxpr_to_mlir
    mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/pennylane/logging/decorators.py:61: in wrapper_entry
    return func(*args, **kwargs)
../../catalyst/jax_extras/lowering.py:72: in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/pennylane/logging/decorators.py:61: in wrapper_entry
    return func(*args, **kwargs)
../../catalyst/jax_extras/lowering.py:140: in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py:1438: in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py:1622: in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py:1730: in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/jax/_src/lax/lax.py:1886: in _sin_lowering
    return sine(ctx, x)
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py:1817: in f_lowered
    out, tokens = jaxpr_subcomp(
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py:1622: in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py:1730: in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/jax/_src/lax/lax.py:2379: in _compare_lower_hlo
    x, y = mlir.multi_broadcast_in_dim(ctx, (x, y), avals_in, aval_out.shape)
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py:1944: in multi_broadcast_in_dim
    core.ShapedArray(out_shape, op_aval.dtype),  # type: ignore
../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/jax/_src/core.py:1685: in __init__
    self.shape = canonicalize_shape(shape)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

shape = (Var(id=133767151424768):int32[], Var(id=133767151424832):int32[])
context = ''

    def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
      """Canonicalizes and checks for errors in a user-provided shape value.

      Args:
        shape: a Python value that represents a shape.

      Returns:
        A tuple of canonical dimension values.
      """
      try:
        return tuple(unsafe_map(_canonicalize_dimension, shape))
      except TypeError:
        pass
>     raise _invalid_shape_error(shape, context)
E     TypeError: Shapes must be 1D sequences of integer scalars, got (Var(id=133767151424768):int32[], Var(id=133767151424832):int32[])

../../../../../.conda/envs/xanadu-update/lib/python3.12/site-packages/jax/_src/core.py:1647: TypeError
PASSED [100%]
Process finished with exit code 1

josh146 commented 1 month ago

Is this something that can only be fixed upstream in JAX?

tzunghanjuang commented 1 month ago

The new lowerings require the shape information to be static, so the error is triggered whenever the shape arguments are passed without `static_argnums`. To allow dynamic shapes, a case has to be added to the new lowerings so that the old `_nary_lower_hlo` function (which handles dynamic shapes) is used instead. So an upstream fix in JAX is still required.