Closed paul0403 closed 4 months ago
This could be fixed by changing the if condition here as follows:
    # following is_abstract should ensure `allclose` is not called for jax-jit or tf.function case.
    if not is_abstract(matrix) and qml.math.allclose(coefficient, 0):
        continue
    observables = (
        [(o, w) for w, o in zip(wire_order, pauli_rep) if o != I]
        if hide_identity and not all(t == I for t in pauli_rep)
        else [(o, w) for w, o in zip(wire_order, pauli_rep)]
    )
    if observables:
        coeffs.append(coefficient)
        obs.append(observables)
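The effect of such a guard can be sketched in plain JAX, without PennyLane. This is only an illustration under stated assumptions: `is_traced` is a hypothetical stand-in for `is_abstract` (here it simply checks for `jax.core.Tracer`), and `keep_nonzero` is a made-up helper, not the library's decomposition loop.

```python
import jax
import jax.numpy as jnp


def is_traced(x):
    # Hypothetical stand-in for is_abstract: True while JAX is tracing x.
    return isinstance(x, jax.core.Tracer)


def keep_nonzero(coefficients):
    # Hypothetical helper mirroring the guarded loop above.
    kept = []
    for c in coefficients:
        # Short-circuit: allclose is only evaluated on concrete values,
        # so no traced value is ever converted to a Python bool.
        if not is_traced(c) and jnp.allclose(c, 0):
            continue
        kept.append(c)
    return jnp.stack(kept)


values = jnp.array([0.0, 1.0, 0.0, 2.0])
eager = keep_nonzero(values)            # zero entries are filtered out
jitted = jax.jit(keep_nonzero)(values)  # traces without error, keeps every entry
```

One consequence of this design is that, under jit, zero coefficients are no longer dropped, so the jitted and eager results can differ in length.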
Expected behavior
Pauli decompose works well outside of jax.jit
Actual behavior
but fails when jit-ted:
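A minimal sketch of this class of failure, using plain JAX rather than the PennyLane call from the report: `branch_on_value` is a hypothetical function whose Python `if` depends on the value of its array argument, which is assumed here to be the same pattern as the `allclose` check inside the decomposition.

```python
import jax
import jax.numpy as jnp


def branch_on_value(matrix):
    # A Python `if` needs a concrete bool; under jit, `matrix` is a tracer.
    if jnp.allclose(matrix, 0):
        return jnp.zeros_like(matrix)
    return matrix


# Eager execution works because the array holds concrete values.
eager_result = branch_on_value(jnp.eye(2))

# Under jit the same conditional raises a concretization error.
try:
    jax.jit(branch_on_value)(jnp.eye(2))
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True
```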
Note that the input matrix cannot be passed through static_argnums, since it is a (j)numpy array, which is not hashable. This error is the usual "you are using an abstract traced value as a concrete value" error in JAX, most likely because a conditional depends on the concrete value of the input matrix; however, turning off check_hermitian did not help.
Additional information
No response
Source code
No response
Tracebacks
No response
System information
Existing GitHub issues