PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] `qml.pauli_decompose` fails when being jit-ted #5817

Closed paul0403 closed 4 months ago

paul0403 commented 5 months ago

Expected behavior

`qml.pauli_decompose` works correctly outside of `jax.jit`:

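import jax
import jax.numpy as jnp
import pennylane as qml

# complex128 output (as below) requires JAX's 64-bit mode
jax.config.update("jax_enable_x64", True)

dev = qml.device("lightning.qubit", wires=1)

# @jax.jit  (disabled: f runs eagerly here)
def f(mat):
    coeffs, unitaries = qml.pauli_decompose(mat, check_hermitian=False).terms()
    return coeffs, unitaries

x = jnp.array([[1, 2], [3, 4]])
print(f(x))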

>>>
(Array([ 2.5+0.j ,  2.5+0.j , -0. -0.5j, -1.5+0.j ], dtype=complex128), [I(0), X(0), Y(0), Z(0)])
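
For reference, the expected coefficients can be checked by hand: for an n-qubit matrix M, the coefficient of a Pauli word P is Tr(P M) / 2^n. The NumPy sketch below is independent of PennyLane and is only meant to confirm the four values printed above for n = 1 (order I, X, Y, Z).

```python
import numpy as np

# Single-qubit Pauli matrices, in the same order as the output above: I, X, Y, Z
paulis = {
    "I": np.eye(2, dtype=complex),
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}

M = np.array([[1, 2], [3, 4]], dtype=complex)
n_qubits = 1

# Coefficient of each Pauli word P is Tr(P @ M) / 2**n
coeffs = {name: np.trace(P @ M) / 2**n_qubits for name, P in paulis.items()}
print(coeffs)

# Summing c_P * P must reproduce M exactly (M need not be Hermitian)
reconstructed = sum(c * paulis[name] for name, c in coeffs.items())
assert np.allclose(reconstructed, M)
```

Because M is not Hermitian, the Y coefficient is imaginary, which is why the output above has a complex dtype.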

Actual behavior

The same function fails when it is JIT-compiled with `jax.jit`:

import jax
import jax.numpy as jnp
import pennylane as qml

jax.config.update("jax_enable_x64", True)

dev = qml.device("lightning.qubit", wires=1)

@jax.jit
def f(mat):
    coeffs, unitaries = qml.pauli_decompose(mat, check_hermitian=False).terms()
    return coeffs, unitaries

x = jnp.array([[1, 2], [3, 4]])
print(f(x))

>>>
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/paul.wang/small_playgrounds_dump/qsvt_demo/pauli_decompose.py", line 20, in <module>
    print(f(x))
  File "/home/paul.wang/small_playgrounds_dump/qsvt_demo/pauli_decompose.py", line 16, in f
    coeffs, unitaries = qml.pauli_decompose(mat, check_hermitian=False).terms()
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/pauli/conversion.py", line 323, in pauli_decompose
    coeffs, obs = _generalized_pauli_decompose(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/pauli/conversion.py", line 209, in _generalized_pauli_decompose
    if not qml.math.allclose(coefficient, 0):
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /home/paul.wang/small_playgrounds_dump/qsvt_demo/pauli_decompose.py:14 for jit. This concrete value was not available in Python because it depends on the value of the argument mat.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Note that the input matrix cannot be marked static with `static_argnums`, because it is a NumPy or JAX array, and such arrays are not hashable.
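
The hashing constraint is easy to confirm with plain NumPy: `jax.jit` requires static arguments to be hashable, and `hash()` raises `TypeError` on an array. A nested tuple holds the same values and is hashable, so it could in principle be passed as a static argument and converted back to an array inside the function. This is a hypothetical workaround, not something tested against `qml.pauli_decompose`; the helper `is_hashable` is illustrative only.

```python
import numpy as np


def is_hashable(obj):
    # jax.jit's static_argnums needs hashable values; hash() raises TypeError otherwise
    try:
        hash(obj)
    except TypeError:
        return False
    return True


mat = np.array([[1, 2], [3, 4]])
print("ndarray hashable:", is_hashable(mat))

# Hypothetical workaround: a nested tuple carries the same values and is hashable.
# Each distinct matrix passed this way would trigger a fresh compilation.
mat_as_tuple = tuple(map(tuple, mat.tolist()))
print("tuple hashable:", is_hashable(mat_as_tuple))

# The tuple converts back to an identical array inside the function body
assert np.array_equal(np.array(mat_as_tuple), mat)
```

The recompilation cost makes this a poor fit when the matrix changes on every call, which is the usual case in an optimization loop.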

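This is JAX's usual error for using an abstract traced value where a concrete value is required. The traceback points to the Python conditional `if not qml.math.allclose(coefficient, 0):` in `_generalized_pauli_decompose`, whose outcome depends on the values of the input matrix. Setting `check_hermitian=False` does not help, because the Hermiticity check is not the conditional that fails.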
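
The failure mode does not depend on PennyLane. A minimal sketch in plain JAX (both function names are hypothetical) shows that a Python `if` on a value computed from a traced argument, like the `allclose` check in the traceback, raises `TracerBoolConversionError` under `jax.jit`, while `jnp.where` expresses a similar decision without leaving the trace.

```python
import jax
import jax.numpy as jnp


@jax.jit
def drop_small_python_branch(c):
    # A Python `if` needs a concrete bool, but `c` is a tracer under jit
    if not jnp.allclose(c, 0):
        return c
    return jnp.zeros_like(c)


@jax.jit
def drop_small_traced(c):
    # jnp.where keeps the decision inside the traced computation (shape is unchanged)
    return jnp.where(jnp.isclose(c, 0), jnp.zeros_like(c), c)


try:
    drop_small_python_branch(jnp.array(1.5))
    raised = False
except jax.errors.TracerBoolConversionError:
    raised = True

print("Python branch raised TracerBoolConversionError:", raised)
print(drop_small_traced(jnp.array([2.5, 1e-12, -1.5])))
```

Note that `jnp.where` zeroes small entries but cannot remove them: under `jax.jit` the output shape has to be known before the values are.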

System information

Name: PennyLane
Version: 0.37.0.dev0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/paul.wang/.local/lib/python3.10/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: amazon-braket-pennylane-plugin, PennyLane-Catalyst, PennyLane-Lightning-Kokkos, PennyLane_Lightning

Platform info:           Linux-6.5.0-35-generic-x86_64-with-glibc2.35
Python version:          3.10.12
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- default.clifford (PennyLane-0.37.0.dev0)
- default.gaussian (PennyLane-0.37.0.dev0)
- default.mixed (PennyLane-0.37.0.dev0)
- default.qubit (PennyLane-0.37.0.dev0)
- default.qubit.autograd (PennyLane-0.37.0.dev0)
- default.qubit.jax (PennyLane-0.37.0.dev0)
- default.qubit.legacy (PennyLane-0.37.0.dev0)
- default.qubit.tf (PennyLane-0.37.0.dev0)
- default.qubit.torch (PennyLane-0.37.0.dev0)
- default.qutrit (PennyLane-0.37.0.dev0)
- default.qutrit.mixed (PennyLane-0.37.0.dev0)
- default.tensor (PennyLane-0.37.0.dev0)
- null.qubit (PennyLane-0.37.0.dev0)
- nvidia.custatevec (PennyLane-Catalyst-0.7.0.dev0)
- nvidia.cutensornet (PennyLane-Catalyst-0.7.0.dev0)
- oqc.cloud (PennyLane-Catalyst-0.7.0.dev0)
- softwareq.qpp (PennyLane-Catalyst-0.7.0.dev0)
- lightning.qubit (PennyLane_Lightning-0.36.0)
- braket.aws.ahs (amazon-braket-pennylane-plugin-1.27.0)
- braket.aws.qubit (amazon-braket-pennylane-plugin-1.27.0)
- braket.local.ahs (amazon-braket-pennylane-plugin-1.27.0)
- braket.local.qubit (amazon-braket-pennylane-plugin-1.27.0)
- lightning.kokkos (PennyLane-Lightning-Kokkos-0.32.0)

obliviateandsurrender commented 4 months ago

This could be fixed by changing the `if` condition in `_generalized_pauli_decompose` so that `allclose` is skipped when the input is abstract (traced):

        # The is_abstract check ensures `allclose` is not called on traced values (jax.jit or tf.function).
        if not is_abstract(matrix) and qml.math.allclose(coefficient, 0):
            continue

        observables = (
            [(o, w) for w, o in zip(wire_order, pauli_rep) if o != I]
            if hide_identity and not all(t == I for t in pauli_rep)
            else [(o, w) for w, o in zip(wire_order, pauli_rep)]
        )

        if observables:
            coeffs.append(coefficient)
            obs.append(observables)
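
With the condition above, a traced input keeps every term, including those whose coefficient is zero. That matches the general constraint under `jax.jit`: the number of returned terms must not depend on the data. The plain-JAX sketch below illustrates the same idea in isolation; `pauli_basis` and `pauli_coefficients` are hypothetical helpers, not PennyLane's implementation. It computes all 4^n coefficients with a fixed output shape inside `jit` and prunes near-zero terms afterwards, outside the traced function.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np

# Single-qubit Pauli matrices
_PAULIS = {
    "I": np.eye(2, dtype=complex),
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}


def pauli_basis(n_qubits):
    """Return word labels and a stacked (4**n, 2**n, 2**n) array of Pauli words."""
    labels = ["".join(word) for word in itertools.product("IXYZ", repeat=n_qubits)]
    mats = []
    for word in labels:
        m = np.ones((1, 1), dtype=complex)
        for letter in word:
            m = np.kron(m, _PAULIS[letter])
        mats.append(m)
    return labels, np.stack(mats)


@jax.jit
def pauli_coefficients(mat, basis):
    # No value-dependent Python branches: the output always has shape (4**n,)
    mat = jnp.asarray(mat, dtype=basis.dtype)
    dim = mat.shape[0]  # shapes are static under jit, so this is a Python int
    # Tr(P_k @ M) / 2**n == sum_ij P_k[i, j] * M[j, i] / 2**n
    return jnp.einsum("kij,ji->k", basis, mat) / dim


labels, basis = pauli_basis(1)
coeffs = pauli_coefficients(jnp.array([[1, 2], [3, 4]]), jnp.asarray(basis))

# Pruning near-zero terms happens outside jit, where the values are concrete
terms = [(c, w) for c, w in zip(np.asarray(coeffs), labels) if not np.isclose(c, 0)]
print(terms)
```

The trade-off is that the traced path returns all 4^n terms; for a Hermitian input such as `[[1, 2], [2, 4]]` the Y coefficient is exactly zero and is only dropped by the pruning step outside `jit`.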