We discovered this issue when attempting to QJIT-compile a circuit implementing Grover's algorithm.
The following PennyLane program applies qml.ctrl to a Z gate, taking control_values from the circuit's input argument. It raises an error only when @qjit is applied:
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

NUM_QUBITS = 2

dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qjit
@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.ctrl(qml.Z(wires[-1]), control=wires[:-1], control_values=basis_state[:-1])
    return qml.state()

basis_state = jnp.array([0.0, 0.0])
state = circuit(basis_state)

Traceback (most recent call last):
  ...
  File ".../venv/lib/python3.12/site-packages/pennylane/ops/op_math/controlled.py", line 144, in ctrl
    return ops_loader.ctrl(op, control, control_values=control_values, work_wires=work_wires)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/quantum_operators.py", line 319, in ctrl
    return res() if isinstance(f, Operator) else res
           ^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/quantum_operators.py", line 532, in __call__
    return create_controlled_op(
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/ops/op_math/controlled.py", line 161, in create_controlled_op
    ctrl_op = _try_wrap_in_custom_ctrl_op(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/ops/op_math/controlled.py", line 323, in _try_wrap_in_custom_ctrl_op
    if custom_key in ops_with_custom_ctrl_ops and all(control_values):
                                                  ^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 712, in __bool__
    return self.aval._bool(self)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 1475, in error
    raise TracerBoolConversionError(arg)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape float64[]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

The issue is that the circuit input, basis_state, is a traced JAX array during compilation, and it ends up in Python control flow: the all(control_values) check in _try_wrap_in_custom_ctrl_op converts each traced element to a Python bool, which is not allowed. Using AutoGraph does not resolve the issue.
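
The same failure mode can be reproduced with plain JAX, without PennyLane or Catalyst. The sketch below is only an illustration: the helper names all_controls_set and all_controls_set_traceable are hypothetical and not part of either library, and the jnp.all variant shows how such a check can stay traceable in general, not a proposed patch.

```python
import jax
import jax.numpy as jnp


def all_controls_set(control_values):
    # Hypothetical stand-in for the check in the traceback:
    # Python's built-in all() calls bool() on every element.
    return all(control_values)


def all_controls_set_traceable(control_values):
    # Hypothetical traceable variant: the reduction stays inside JAX,
    # so no element is ever converted to a Python bool.
    return jnp.all(control_values != 0)


values = jnp.array([1.0, 1.0])

# Outside of tracing the values are concrete, so bool() works.
concrete_result = all_controls_set(values)

# Under jax.jit the elements are tracers, and bool() raises.
try:
    jax.jit(all_controls_set)(values)
    traced_error = None
except jax.errors.TracerBoolConversionError as exc:
    traced_error = type(exc).__name__

# The traceable variant compiles and returns a JAX boolean array.
traced_ok = jax.jit(all_controls_set_traceable)(values)

print(concrete_result, traced_error, traced_ok)
```

Note that the traceable variant returns a traced boolean, so it still cannot be used directly in a Python if statement during compilation. Any branch that depends on it would need a traceable construct as well.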
Catalyst and/or PennyLane should be changed to support controlled gates in QJIT-compiled circuits where an input argument of the parameterized circuit is passed to qml.ctrl (here, as control_values).