PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

qjit fails when multiple tapes share the same loop object #1012

Open paul0403 opened 1 month ago

paul0403 commented 1 month ago
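from typing import Callable, Sequence

import catalyst
import pennylane as qml
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=2)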

def my_quantum_transform(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.QuantumTape], Callable):
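    # tape1 is the original tape; tape2 is a new tape built from the same
    # operation and measurement objects, so both tapes share the loop op.
    tape1 = tape
    tape2 = qml.tape.QuantumTape(tape.operations, tape.measurements)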
    def post_processing_fn(results):
        return results[0] + results[1]
    return [tape1, tape2], post_processing_fn

dispatched_transform = qml.transform(my_quantum_transform)

@qml.qnode(dev)
def circuit():
    @catalyst.for_loop(0, 1, 1)
    def loop0(_, yy):
        qml.RX(3.14, wires=0)
        return yy + 2
    loop0(0)
    return qml.expval(qml.X(0))

circuit = dispatched_transform(circuit)
circuit = qjit(circuit)
print("qjit results: ", circuit())

>>>
Traceback (most recent call last):
  File "/home/paul.wang/catalyst_new/catalyst/multi_tape.py", line 144, in <module>
    circuit = qjit(circuit)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 376, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 443, in __init__
    self.aot_compile()
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 481, in aot_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 606, in capture
    jaxpr, out_type, treedef = trace_to_jaxpr(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 531, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 604, in fn_with_transform_named_sequence
    return self.user_function(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 584, in closure
    return QFunc.__call__(qnode, *args, **dict(params, **kwargs))
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/qfunc.py", line 165, in __call__
    res_flat = func_p.bind(flattened_fun, *args_flat, fn=self)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/qfunc.py", line 143, in _eval_quantum
    closed_jaxpr, out_type, out_tree, out_tree_exp = trace_quantum_function(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 1162, in trace_quantum_function
    qrp_out = trace_quantum_operations(tape, device, qreg_in, ctx, trace, mcm_config)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 655, in trace_quantum_operations
    qrp2 = op.trace_quantum(ctx, device, trace, qrp, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/api_extensions/control_flow.py", line 1209, in trace_quantum
    op.bind_overwrite_classical_tracers(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 460, in bind_overwrite_classical_tracers
    out_quantum_tracer = self.binder(*in_expanded_tracers, **kwargs)[-1]
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/tracing.py", line 969, in bind
    source_info = jax_current()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: AssertionError: ({ lambda ; a:i64[] b:i64[] c:AbstractQreg() d:AbstractQreg(). let
    e:i64[] = add b 2
    f:AbstractQbit() = qextract c 0
    g:AbstractQbit() = qinst[
      adjoint=False
      ctrl_len=0
      op=RX
      params_len=1
      qubits_len=1
    ] f 3.14
    _:AbstractQreg() = qinsert c 0 g
    h:AbstractQbit() = qextract d 0
    i:AbstractQbit() = qinst[
      adjoint=False
      ctrl_len=0
      op=RX
      params_len=1
      qubits_len=1
    ] h 3.14
    j:AbstractQreg() = qinsert d 0 i
  in (e, j) }, ([<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x722b91bb37f0>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb0b70>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb25b0>]))

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/paul.wang/catalyst_new/catalyst/multi_tape.py", line 144, in <module>
    circuit = qjit(circuit)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 376, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 443, in __init__
    self.aot_compile()
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 486, in aot_compile
    self.mlir_module, self.mlir = self.generate_ir()
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 621, in generate_ir
    mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 558, in lower_jaxpr_to_mlir
    mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/lowering.py", line 72, in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/lowering.py", line 140, in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_primitives.py", line 608, in _func_lowering
    func_op = _func_def_lowering(ctx.module_context, fn, call_jaxpr, name_stack=ctx.name_stack)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_primitives.py", line 566, in _func_def_lowering
    func_op = mlir.lower_jaxpr_to_fun(ctx, fn.__name__, call_jaxpr, tuple(), name_stack=name_stack)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_primitives.py", line 2022, in _for_loop_lowering
    out, _ = mlir.jaxpr_subcomp(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1574, in jaxpr_subcomp
    assert len(args) == len(jaxpr.invars), (jaxpr, args)
AssertionError: ({ lambda ; a:i64[] b:i64[] c:AbstractQreg() d:AbstractQreg(). let
    e:i64[] = add b 2
    f:AbstractQbit() = qextract c 0
    g:AbstractQbit() = qinst[
      adjoint=False
      ctrl_len=0
      op=RX
      params_len=1
      qubits_len=1
    ] f 3.14
    _:AbstractQreg() = qinsert c 0 g
    h:AbstractQbit() = qextract d 0
    i:AbstractQbit() = qinst[
      adjoint=False
      ctrl_len=0
      op=RX
      params_len=1
      qubits_len=1
    ] h 3.14
    j:AbstractQreg() = qinsert d 0 i
  in (e, j) }, ([<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x722b91bb37f0>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb0b70>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb25b0>]))
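
For illustration only, here is a minimal pure-Python sketch of the sharing in the transform above: building a second tape from `tape.operations` creates a new container, but the operation objects inside it are the same ones. `ToyLoopOp` and `ToyTape` are hypothetical stand-ins, not Catalyst or PennyLane classes, and the idea that the loop op accumulates state each time it is traced is an assumption that merely fits the duplicated loop body in the jaxpr above.

```python
from dataclasses import dataclass


@dataclass
class ToyLoopOp:
    """Hypothetical stand-in for a loop op that records state when traced."""

    trace_count: int = 0

    def trace(self) -> int:
        # Assumption: tracing mutates the op object itself.
        self.trace_count += 1
        return self.trace_count


@dataclass
class ToyTape:
    """Hypothetical stand-in for a tape: just a list of operations."""

    operations: list


original = ToyTape([ToyLoopOp()])
tape1 = original
# New tape object, but it holds the very same operation objects.
tape2 = ToyTape(list(original.operations))

shares_loop = tape2.operations[0] is tape1.operations[0]

# Tracing both tapes therefore touches the single shared loop op twice.
for tape in (tape1, tape2):
    for op in tape.operations:
        op.trace()

final_count = original.operations[0].trace_count
print(shares_loop, final_count)
```

In this toy model the one loop object is visited once per tape, so any per-object tracing state is applied twice.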

josh146 commented 3 weeks ago

Thanks @paul0403! I just wanted to check with @erick-xanadu, is this expected with our current support for quantum transforms?

If I recall correctly, we integrated PL tape transforms assuming that the QNode being transformed was a straight line program only (e.g., that no for loops or conditionals were present).

dime10 commented 3 weeks ago

> Thanks @paul0403! I just wanted to check with @erick-xanadu, is this expected with our current support for quantum transforms?
>
> If I recall correctly, we integrated PL tape transforms assuming that the QNode being transformed was a straight line program only (e.g., that no for loops or conditionals were present).

We already disallow MCMs with transforms that produce multiple tapes, but I think we are not restrictive enough. Disallowing all hybrid ops for the moment should at least prevent the issue.

josh146 commented 3 weeks ago

So the fix here would simply be additional validation? Sounds good!

erick-xanadu commented 3 weeks ago

@josh146 yes, the assumption is that the QNode being transformed is a straight-line program only. I think there was some initial validation, but perhaps it did not cover all possible cases.
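
As a rough sketch of the additional validation discussed above, the check below rejects any hybrid op once a transform has produced more than one tape. Every name here (`ToyHybridOp`, `ToyForLoop`, `ToyGate`, `ToyTape`, `validate_transformed_tapes`) is a hypothetical stand-in written for this example; it is not Catalyst's actual validation code, and where such a check would live in the frontend is left open.

```python
from dataclasses import dataclass
from typing import Sequence


class ToyHybridOp:
    """Hypothetical marker base class for ops with nested control flow."""


class ToyForLoop(ToyHybridOp):
    """Hypothetical stand-in for a for-loop op."""


class ToyGate:
    """Hypothetical stand-in for a plain quantum gate."""


@dataclass
class ToyTape:
    """Hypothetical stand-in for a tape: just a list of operations."""

    operations: list


def validate_transformed_tapes(tapes: Sequence[ToyTape]) -> None:
    """Raise if a multi-tape transform result contains any hybrid op."""
    if len(tapes) <= 1:
        # A single tape cannot share ops with a sibling tape.
        return
    for index, tape in enumerate(tapes):
        for op in tape.operations:
            if isinstance(op, ToyHybridOp):
                raise ValueError(
                    f"Tape {index} contains {type(op).__name__}; transforms "
                    "that produce multiple tapes only support straight-line "
                    "programs."
                )


# A single tape with a loop passes; two tapes containing a loop are rejected.
validate_transformed_tapes([ToyTape([ToyForLoop()])])
loop = ToyForLoop()
try:
    validate_transformed_tapes([ToyTape([loop]), ToyTape([loop])])
    rejected = False
except ValueError:
    rejected = True
print(rejected)
```

Failing early like this would replace the low-level `AssertionError` from lowering with a message that names the unsupported op.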