josh146 opened 3 days ago
I think I have found an edge case where `static_argnums` is used, but the underlying parameter still becomes a tracer when it enters a QNode:
```python
import pennylane as qml
from jax import numpy as jnp
from catalyst import qjit, measure

dev = qml.device('lightning.qubit', wires=1)

@qjit(static_argnums=(1,))
def f(x, c):
    print(c)

    @qml.qnode(dev)
    def circuit(x, c):
        print(c)
        qml.RY(c, 0)
        qml.RX(x, 0)
        return qml.expval(qml.PauliZ(0))

    return circuit(x, c)
```
Output:

```
0.5
Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=2/1)>
array(0.80830707)
```
My expectation is that `c` should remain a concrete value (the first `print` shows `0.5`), even inside the QNode.
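For comparison, here is a minimal sketch of the same pattern in plain JAX, without Catalyst. It assumes (unverified) that a QNode called inside `qjit` acts like a nested `jax.jit` boundary that traces all of its arguments. The function names `outer`, `inner_as_arg` and `inner_closure` are hypothetical. It shows that a static argument is concrete in the outer function, becomes a tracer once it is passed as an argument to an inner traced function, and stays concrete if the inner function captures it by closure instead.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Records whether `c` was a concrete Python float at trace time.
seen = {}

@jax.jit  # no static_argnums here, so every argument is traced
def inner_as_arg(x, c):
    seen["as_arg"] = isinstance(c, float)
    return jnp.cos(c) * jnp.cos(x)

@partial(jax.jit, static_argnums=(1,))
def outer(x, c):
    # `c` is static, so it is still the Python float passed by the caller.
    seen["outer"] = isinstance(c, float)

    # Passing `c` as an argument across another jit boundary turns it into a tracer.
    a = inner_as_arg(x, c)

    # Capturing `c` by closure keeps it concrete inside the inner function.
    @jax.jit
    def inner_closure(x):
        seen["closure"] = isinstance(c, float)
        return jnp.cos(c) * jnp.cos(x)

    b = inner_closure(x)
    return a, b

a, b = outer(0.4, 0.5)
print(seen)  # {'outer': True, 'as_arg': False, 'closure': True}
```

If the QNode really behaves like `inner_as_arg`, closing over `c` instead of passing it as a QNode argument might be a possible workaround, but that has not been checked against Catalyst.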