PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] `qml.exp` does not work with `jax.jit` and workflows that require decomposition, since `Exp.decomposition` is not traceable #5993

Closed: dime10 closed this issue 2 weeks ago

dime10 commented 1 month ago

Expected behavior

Originally reported in https://github.com/PennyLaneAI/catalyst/issues/923

Actual behavior

The following circuit raises an error:

```python
import jax
import jax.numpy as jnp
import pennylane as qml

@jax.jit
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(theta):
    qml.exp(qml.X(0), coeff=-1j * jnp.cos(theta))
    return qml.expval(qml.Z(0))

circuit(jnp.pi / 3)
```

End of the traceback:

```
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pennylane/ops/op_math/exp.py", line 473, in has_generator
    return self.base.is_hermitian and not np.real(self.coeff)
                                      ^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape float64[]..
```
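The failure mode is generic to JAX rather than specific to PennyLane: any Python `not`/`if` applied to a value derived from a jitted argument fails the same way. A minimal sketch in plain JAX, where `has_generator_like` is a hypothetical stand-in that mirrors the `not np.real(self.coeff)` check from the traceback:

```python
import jax
import jax.numpy as jnp


def has_generator_like(coeff):
    # Hypothetical stand-in for the check in the traceback:
    # Python's `not` forces the value into a concrete bool.
    return not jnp.real(coeff)


# Without jit, coeff is a concrete array, so the bool conversion works.
print(has_generator_like(-1j * jnp.cos(jnp.pi / 3)))  # True

# Under jit, coeff is an abstract tracer and the conversion fails.
try:
    jax.jit(has_generator_like)(-1j * jnp.cos(jnp.pi / 3))
except jax.errors.TracerBoolConversionError as err:
    print(type(err).__name__)  # TracerBoolConversionError
```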

trbromley commented 1 month ago

Thanks @dime10! How urgent would you say this is to fix?

josh146 commented 1 month ago

Hey @trbromley, I would label this P0/high priority, as it is blocking an external researcher.

Qottmann commented 1 month ago

Might this be a lightning bug? It works with default.qubit:

```python
dev = qml.device("default.qubit", wires=1)

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit1(theta):
    qml.exp(qml.X(0), -1j * jnp.cos(theta))
    return qml.expval(qml.Z(0))

circuit1(jnp.pi / 3.)
```

```
Array(0.54030231, dtype=float64)
```
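As a sanity check on that value: at θ = π/3 the coefficient is −i cos(π/3) = −i/2, so the operator is an X rotation by angle 1 (using the convention R_X(φ) = exp(−iφX/2)), and ⟨Z⟩ on |0⟩ is cos 1:

```latex
\exp\!\left(-\tfrac{i}{2} X\right) = \cos\tfrac{1}{2}\, I - i \sin\tfrac{1}{2}\, X = R_X(1),
\qquad
\langle 0 | R_X(1)^\dagger \, Z \, R_X(1) | 0 \rangle = \cos^2\tfrac{1}{2} - \sin^2\tfrac{1}{2} = \cos 1 \approx 0.5403 .
```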
dime10 commented 1 month ago

In that case, the difference could be in how the gate is executed. Probably the matrix method is traceable but the decomposition method is not; I believe lightning favours the latter, whereas default.qubit favours the former.
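To illustrate why a matrix-based path can survive `jax.jit`: building the matrix of exp(coeff·X) is pure array arithmetic, with no Python branching on the coefficient. This sketch uses `jax.scipy.linalg.expm` for illustration only; it is an assumption, not a description of how default.qubit actually builds the matrix, and `exp_matrix`/`expval_z` are hypothetical names:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Pauli-X and Pauli-Z as complex matrices.
X = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)


@jax.jit
def exp_matrix(theta):
    # Same coefficient as in the issue. expm only does array math,
    # so theta can remain an abstract tracer throughout.
    coeff = -1j * jnp.cos(theta)
    return expm(coeff * X)


def expval_z(theta):
    # Apply the matrix to |0> and compute <Z> = <psi|Z|psi>.
    state = exp_matrix(theta) @ jnp.array([1.0, 0.0], dtype=jnp.complex64)
    return jnp.real(jnp.conj(state) @ (Z @ state))


print(expval_z(jnp.pi / 3))  # approximately cos(1) ~ 0.5403
```

The decomposition path, by contrast, has to decide at trace time which gates to emit, which is where a Python `bool` on the coefficient sneaks in.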

josh146 commented 1 month ago

@Qottmann @dime10 I can confirm that the issue only occurs when combining jax.jit with workflows that trigger Exp.decomposition.

As far as I know, the underlying issue is in the `Exp._recursive_decomposition` and `Exp.has_generator` methods.

In both methods, Python boolean expressions are built from `qml.math.real(coeff)`. This works fine without JIT, but under JIT `qml.math.real(coeff)` is a traced (abstract) value, and Python cannot convert it into a `bool`.
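A sketch of the two usual ways around a Python `bool` on a traced value, in plain JAX: (1) only make the Python-level decision when the value is concrete, and report "unknown" under tracing; (2) move a value-level choice into `jnp.where` so it stays inside the traced computation. `real_part_is_zero`, `rx_angle`, and `traced_angle` are hypothetical helpers; this is not the fix PennyLane adopted. Note that a decomposition that must emit *different gates* depending on the value still cannot be chosen this way at trace time.

```python
import jax
import jax.numpy as jnp


def real_part_is_zero(coeff):
    """Hypothetical helper: a Python bool for concrete input, None under jit.

    Under tracing the real part is abstract, so the caller must pick a
    code path that does not depend on its value.
    """
    try:
        return not jnp.real(coeff)
    except jax.errors.TracerBoolConversionError:
        return None


def rx_angle(coeff):
    """Hypothetical helper: angle phi with exp(coeff * X) == RX(phi).

    Valid only for purely imaginary coeff (returns NaN otherwise);
    jnp.where keeps the data-dependent check inside the traced computation.
    """
    return jnp.where(jnp.real(coeff) == 0, -2 * jnp.imag(coeff), jnp.nan)


@jax.jit
def traced_angle(theta):
    coeff = -1j * jnp.cos(theta)
    # The Python-level check cannot decide while tracing...
    assert real_part_is_zero(coeff) is None
    # ...but the array-level selection traces fine.
    return rx_angle(coeff)


print(real_part_is_zero(-1j * jnp.cos(jnp.pi / 3)))  # True
print(traced_angle(jnp.pi / 3))  # approximately 1.0
```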

A couple of notes: