PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Catalyst does not support QJIT-compiling a parameterized circuit with `qml.ctrl` #1266

Open joeycarter opened 4 weeks ago


We discovered this issue when attempting to QJIT-compile a circuit implementing Grover's algorithm.

Consider the following PennyLane program, which applies `qml.ctrl` to a Z gate. It raises an error only when `@qjit` is applied:

```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

NUM_QUBITS = 2

dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qjit
@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.ctrl(qml.Z(wires[-1]), control=wires[:-1], control_values=basis_state[:-1])
    return qml.state()

basis_state = jnp.array([0.0, 0.0])
state = circuit(basis_state)
```
```
Traceback (most recent call last):
...
  File ".../venv/lib/python3.12/site-packages/pennylane/ops/op_math/controlled.py", line 144, in ctrl
    return ops_loader.ctrl(op, control, control_values=control_values, work_wires=work_wires)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/quantum_operators.py", line 319, in ctrl
    return res() if isinstance(f, Operator) else res
           ^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/quantum_operators.py", line 532, in __call__
    return create_controlled_op(
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/ops/op_math/controlled.py", line 161, in create_controlled_op
    ctrl_op = _try_wrap_in_custom_ctrl_op(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/ops/op_math/controlled.py", line 323, in _try_wrap_in_custom_ctrl_op
    if custom_key in ops_with_custom_ctrl_ops and all(control_values):
                                                  ^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 712, in __bool__
    return self.aval._bool(self)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 1475, in error
    raise TracerBoolConversionError(arg)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape float64[]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```

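The issue is that the circuit input `basis_state`, which is a traced JAX array during compilation, is used in Python control flow: `_try_wrap_in_custom_ctrl_op` in PennyLane evaluates `all(control_values)`, which calls `bool()` on traced array elements. This is not allowed under tracing. Using AutoGraph does not resolve the issue.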
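The failure can be reproduced in plain JAX, without PennyLane or Catalyst. The sketch below is a minimal illustration that assumes only what the traceback shows, namely that PennyLane's check reduces to Python's built-in `all()` over the traced control values; the function names `python_all` and `traced_all` are hypothetical.

```python
import jax
import jax.numpy as jnp


def python_all(control_values):
    # Mirrors `all(control_values)` in PennyLane's _try_wrap_in_custom_ctrl_op:
    # the built-in all() iterates the array and calls bool() on each element.
    return all(control_values)


def traced_all(control_values):
    # jnp.all stays inside the traced computation and returns a traced
    # boolean array instead of a concrete Python bool.
    return jnp.all(control_values != 0)


control_values = jnp.array([0.0])  # same value as basis_state[:-1] in the issue

# Without tracing, bool() on a concrete array element works.
eager_result = python_all(control_values)

# Under jax.jit the elements are tracers, so bool() raises.
jit_failed = False
try:
    jax.jit(python_all)(control_values)
except jax.errors.TracerBoolConversionError:
    jit_failed = True

# The jnp version traces fine, but inside a compiled function its result is
# itself a tracer, so it could not drive a Python `if` either.
jit_result = jax.jit(traced_all)(control_values)

print(eager_result, jit_failed, jit_result)
```

This is why AutoGraph alone does not help here: the offending `all()` call lives inside PennyLane's operator construction, not in the user's circuit code that AutoGraph rewrites.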

Catalyst and/or PennyLane should be updated to support controlled gates in QJIT-compiled circuits where an argument of the parameterized circuit is passed to `qml.ctrl`, for example as `control_values`.
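One possible workaround, sketched here under assumptions and not confirmed by the Catalyst or PennyLane maintainers, is to avoid passing traced values as `control_values`: rotate each control wire by `RX(pi * (1 - v))`, apply the gate controlled on |1>, then undo the rotation. Because `RX(pi) = -iX`, this flips the control only when `v == 0`, and the angle is ordinary traced arithmetic, so no Python branching is needed. The pure-JAX check below verifies the identity on the 2-qubit matrices from the example; the helpers `rx`, `ctrl_z_reference`, and `ctrl_z_via_rx` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Single-qubit matrices (complex64 to match JAX's default precision).
I2 = jnp.eye(2, dtype=jnp.complex64)
Z = jnp.diag(jnp.array([1, -1], dtype=jnp.complex64))
P0 = jnp.diag(jnp.array([1, 0], dtype=jnp.complex64))  # |0><0|
P1 = jnp.diag(jnp.array([0, 1], dtype=jnp.complex64))  # |1><1|


def rx(theta):
    # Standard RX(theta) = exp(-i * theta * X / 2); works on traced theta.
    c = jnp.cos(theta / 2)
    s = jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]], dtype=jnp.complex64)


def ctrl_z_reference(v):
    # Concrete reference (v is a Python int): Z on wire 1 iff wire 0 is |v>.
    if v == 1:
        return jnp.kron(P0, I2) + jnp.kron(P1, Z)
    return jnp.kron(P0, Z) + jnp.kron(P1, I2)


def ctrl_z_via_rx(v):
    # Hypothetical workaround: v may be traced; there is no Python branch on it.
    # theta = 0 when v == 1 (no flip), theta = pi when v == 0 (RX(pi) = -iX).
    theta = jnp.pi * (1.0 - v)
    flip = jnp.kron(rx(theta), I2)
    unflip = jnp.kron(rx(-theta), I2)
    cz = jnp.kron(P0, I2) + jnp.kron(P1, Z)  # controlled on |1>
    # Circuit order is flip -> cz -> unflip, so the matrix product is reversed.
    return unflip @ cz @ flip


ctrl_z_jit = jax.jit(ctrl_z_via_rx)
for v in (0, 1):
    matches = jnp.allclose(ctrl_z_jit(jnp.float32(v)), ctrl_z_reference(v), atol=1e-5)
    print(f"control value {v}: matches reference = {bool(matches)}")
```

The phases from `RX(theta)` and `RX(-theta)` cancel exactly, so the result equals the controlled gate with no extra global phase. Whether the equivalent `qml.RX`/`qml.ctrl` circuit compiles under `@qjit` has not been checked here; with more than one control wire, the same rotation would be applied to each control wire.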