PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

[Bug] Static argnum is not correctly making QNode arguments compile-time values #902

Status: Open. Opened by josh146 3 days ago.


I think I have found an edge case in which static_argnums is used, but the underlying parameter still becomes a tracer once it is passed into a QNode:

import pennylane as qml
from jax import numpy as jnp
from catalyst import qjit, measure

dev = qml.device('lightning.qubit', wires=1)

@qjit(static_argnums=(1,))
def f(x, c):
    print(c)

    @qml.qnode(dev)
    def circuit(x, c):
        print(c)
        qml.RY(c, 0)
        qml.RX(x, 0)
        return qml.expval(qml.PauliZ(0))

    return circuit(x, c)

f(0.4, 0.5)

Output:
0.5
Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=2/1)>
array(0.80830707)
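The returned number itself looks correct; only the type of c inside the QNode is surprising. Assuming the call was f(0.4, 0.5) (c = 0.5 matches the first printed line, and x = 0.4 is inferred from the result rather than stated in the report), the expectation value of PauliZ after RY(c) and then RX(x) on |0> can be checked via the Bloch vector:

```latex
\begin{aligned}
R_Y(c)\,|0\rangle &:\quad \vec r = (\sin c,\; 0,\; \cos c) \\
R_X(x) &:\quad (r_x, r_y, r_z) \mapsto (r_x,\; r_y \cos x - r_z \sin x,\; r_y \sin x + r_z \cos x) \\
\langle Z \rangle &= r_y \sin x + r_z \cos x = \cos c \cos x \\
&= \cos(0.5)\cos(0.4) \approx 0.87758 \times 0.92106 \approx 0.80831
\end{aligned}
```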

My expectation is that c should remain a concrete value even inside the QNode, since it is declared static on the enclosing qjit function.
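As an analogy only (plain JAX, not Catalyst), the sketch below contrasts the expected behavior with the reported one. Under jax.jit with static_argnums=(1,), c arrives in the outer function as an ordinary Python float. Passing it as an argument to a nested jax.jit function abstracts it into a tracer, which resembles what the output above shows at the QNode boundary, whereas a plain Python call or a closure leaves it concrete. The helper names are hypothetical, and whether capturing c in a closure also keeps it static inside a Catalyst QNode is an untested assumption here.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Records, per function, whether `c` was a concrete Python float at trace time.
# All helper names are hypothetical and exist only for this illustration.
concrete_c = {}


def plain_helper(x, c):
    # Ordinary Python call: nothing re-traces `c`, so it stays a float.
    concrete_c["plain_helper"] = isinstance(c, float)
    return jnp.cos(c) * jnp.cos(x)


@jax.jit
def traced_helper(x, c):
    # Nested jit without static_argnums: `c` is abstracted into a tracer,
    # analogous to what the issue reports for the QNode argument.
    concrete_c["traced_helper"] = isinstance(c, float)
    return jnp.cos(c) * jnp.cos(x)


@partial(jax.jit, static_argnums=(1,))
def f(x, c):
    # `c` is static here, so it is the Python float passed by the caller.
    concrete_c["outer"] = isinstance(c, float)

    # Capturing `c` in a closure instead of passing it as an argument
    # keeps it concrete even inside a nested jit.
    @jax.jit
    def closure_helper(x):
        concrete_c["closure_helper"] = isinstance(c, float)
        return jnp.cos(c) * jnp.cos(x)

    return plain_helper(x, c) + traced_helper(x, c) + closure_helper(x)


result = f(0.4, 0.5)
print(concrete_c)
```

Printing concrete_c should report False only for traced_helper, the one place where the static value is handed across a fresh tracing boundary as an argument.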