PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Dynamic one shot does not work with ZNE #929

Open dime10 opened 1 month ago

dime10 commented 1 month ago
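```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import mitigate_with_zne, qjit

# Assumption: the report does not show how `dev` is defined; one-shot
# mid-circuit measurement handling needs a device with finite shots.
dev = qml.device("lightning.qubit", wires=2, shots=100)

@qml.qnode(dev, mcm_method="one-shot")
def circuit():
    qml.Hadamard(wires=0)
    qml.CNOT(wires=[0, 1])
    qml.Hadamard(wires=0)
    qml.CNOT(wires=[0, 1])
    qml.Hadamard(wires=0)
    return qml.expval(qml.PauliY(0))

@qjit
def mitigated_circuit():
    s = jnp.array([1, 2])  # noise scale factors
    return mitigate_with_zne(circuit, scale_factors=s)()
```

This fails with the following error (traceback excerpt):

```
  File "catalyst/frontend/catalyst/jax_primitives.py", line 739, in _zne_lowering
    _func_lowering(ctx, *args, call_jaxpr=jaxpr.eqns[0].params["call_jaxpr"], fn=fn, call=False)
                                          ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
KeyError: 'call_jaxpr'
```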
dime10 commented 1 month ago

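I briefly looked into this issue. The root cause is that the ZNE wrapper (`mitigate_with_zne`) only works when applied directly to a QNode, but the `dynamic_one_shot` transform replaces the QNode with a classical function that invokes a QNode, which breaks that assumption. This matches the traceback: `_zne_lowering` expects the first equation of the wrapped function's jaxpr (`jaxpr.eqns[0]`) to be the QNode call carrying a `call_jaxpr` parameter, and raises a `KeyError` when it is not.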
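
The mismatch can be sketched in plain Python with a toy model. `ToyQNode`, `toy_zne` and `toy_one_shot` are hypothetical stand-ins, not Catalyst or PennyLane APIs, and the toy wrapper fails with an explicit type check, whereas the real one fails during lowering with the `KeyError` above. The point is only structural: a wrapper that requires a QNode cannot accept the classical function that a one-shot transform returns.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence


@dataclass
class ToyQNode:
    """Hypothetical stand-in for a QNode: a quantum function evaluated at a noise scale."""
    qfunc: Callable[[float], float]

    def __call__(self, scale: float = 1.0) -> float:
        return self.qfunc(scale)


def toy_zne(fn, scale_factors: Sequence[float]) -> Callable[[], List[float]]:
    """Hypothetical ZNE wrapper that, like the one described above, only
    accepts a QNode and re-runs it at each noise scale factor."""
    if not isinstance(fn, ToyQNode):
        raise TypeError(f"ZNE wrapper expects a QNode, got {type(fn).__name__}")

    def mitigated() -> List[float]:
        # Extrapolation is omitted; just collect the noise-scaled results.
        return [fn(s) for s in scale_factors]

    return mitigated


def toy_one_shot(qnode: ToyQNode, shots: int = 4) -> Callable[..., float]:
    """Hypothetical one-shot transform: returns a *classical* function that
    calls the QNode once per shot and averages the results."""
    def classical_wrapper(scale: float = 1.0) -> float:
        return sum(qnode(scale) for _ in range(shots)) / shots

    return classical_wrapper


noisy = ToyQNode(lambda scale: 2.0 * scale)
print(toy_zne(noisy, [1, 2])())  # [2.0, 4.0]

try:
    toy_zne(toy_one_shot(noisy), [1, 2])
except TypeError as err:
    print(err)  # ZNE wrapper expects a QNode, got function
```

Wrapping the QNode directly works; wrapping the output of the one-shot transform does not, because that output is no longer a QNode.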

I think one solution might be to rewrite ZNE as part of the QNode transform program; that way, it would be propagated properly by the `dynamic_one_shot` method. I haven't verified whether this would actually work, though.