PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] Infinite run with JAX-JIT and `qml.expval` #2405

Open ankit27kh opened 2 years ago

ankit27kh commented 2 years ago

Expected behavior

This code should terminate with a result:

import jax
import pennylane as qml

dev = qml.device('default.qubit.jax', wires=2, shots=10)

@jax.jit
@qml.qnode(dev, interface='jax')
def circ(x):
    qml.PauliZ(wires=0)
    qml.RY(x, wires=0)
    return qml.expval(qml.PauliZ(0))

print(circ(1))

Actual behavior

The code does not terminate.

Additional information

The code terminates when jit is not used:

import pennylane as qml

dev = qml.device('default.qubit.jax', wires=2, shots=10)

@qml.qnode(dev, interface='jax')
def circ(x):
    qml.PauliZ(wires=0)
    qml.RY(x, wires=0)
    return qml.expval(qml.PauliZ(0))

print(circ(1))
# prints 0.4
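For reference, this circuit has a closed-form answer. PauliZ leaves |0> unchanged, and RY(x)|0> = cos(x/2)|0> + sin(x/2)|1>, so <Z> = cos^2(x/2) - sin^2(x/2) = cos(x), which is about 0.540 for x = 1. With shots=10 the device returns the mean of ten +1/-1 outcomes, so every result is a multiple of 0.2 and 0.4 is a plausible sample. The sketch below reproduces both numbers in plain JAX under `jax.jit`. It is a hand-written single-qubit state-vector simulation (an assumption for illustration, not how `default.qubit.jax` is implemented), and it models only wire 0 because wire 1 is never touched.

```python
import jax
import jax.numpy as jnp

SHOTS = 10  # matches shots=10 in the report
PAULI_Z = jnp.array([[1.0, 0.0], [0.0, -1.0]])
KET_0 = jnp.array([1.0, 0.0])


def ry(theta):
    """RY(theta) = exp(-i * theta * Y / 2), which is a real matrix."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])


def final_state(x):
    """State of wire 0 after PauliZ then RY(x); wire 1 stays in |0>."""
    return ry(x) @ (PAULI_Z @ KET_0)


@jax.jit
def exact_expval(x):
    """Analytic <Z> = cos^2(x/2) - sin^2(x/2) = cos(x)."""
    state = final_state(x)
    return state @ PAULI_Z @ state  # amplitudes are real, so no conjugate


@jax.jit
def sampled_expval(x, key):
    """Mean of SHOTS simulated Z measurements (eigenvalues +1 and -1)."""
    probs = final_state(x) ** 2  # Born rule for real amplitudes
    outcomes = jax.random.choice(key, 2, shape=(SHOTS,), p=probs)
    return jnp.mean(1.0 - 2.0 * outcomes)  # |0> -> +1, |1> -> -1


print(exact_expval(1.0))                           # about 0.5403
print(sampled_expval(1.0, jax.random.PRNGKey(0)))  # a multiple of 0.2
```

The sampled value depends on the PRNG key, so it will not necessarily match the 0.4 reported above; only the spacing of possible values (steps of 0.2) is fixed by the shot count.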

Source code

No response

Tracebacks

No response

System information

Name: PennyLane
Version: 0.22.1
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: lib/python3.9/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, retworkx, scipy, semantic-version, toml
Required-by: PennyLane-Lightning, PennyLane-qiskit
Platform info:           Linux-5.13.0-39-generic-x86_64-with-glibc2.31
Python version:          3.9.7
Numpy version:           1.22.3
Scipy version:           1.8.0
Installed devices:
- default.gaussian (PennyLane-0.22.1)
- default.mixed (PennyLane-0.22.1)
- default.qubit (PennyLane-0.22.1)
- default.qubit.autograd (PennyLane-0.22.1)
- default.qubit.jax (PennyLane-0.22.1)
- default.qubit.tf (PennyLane-0.22.1)
- default.qubit.torch (PennyLane-0.22.1)
- qiskit.aer (PennyLane-qiskit-0.20.0)
- qiskit.basicaer (PennyLane-qiskit-0.20.0)
- qiskit.ibmq (PennyLane-qiskit-0.20.0)
- lightning.qubit (PennyLane-Lightning-0.22.0)

Existing GitHub issues

antalszava commented 2 years ago

Hi @ankit27kh, which versions of JAX and jaxlib does the error arise with?

I'm not able to reproduce on Python 3.9 with the following packages:

-> pip freeze | grep 'numpy\|PennyLane\|jax\|scipy'
jax==0.3.4
jaxlib==0.3.2
numpy==1.22.3
PennyLane==0.22.1
PennyLane-Cirq==0.17.1
PennyLane-Lightning==0.22.1
scipy==1.8.0
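As a cross-platform alternative to the `pip freeze | grep` pipeline, a small script along these lines can report the same versions. It is a sketch that uses only the standard library's `importlib.metadata` (Python 3.8+); the package list is simply the one compared in this thread. The full version string includes any local suffix such as `+cuda11.cudnn805`, which distinguishes a CUDA build of jaxlib from a CPU-only one.

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions whose versions are compared in this thread.
PACKAGES = ["jax", "jaxlib", "numpy", "PennyLane", "PennyLane-Lightning", "scipy"]


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```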
ankit27kh commented 2 years ago

Hi @antalszava, here are my module versions:

jax==0.2.26
jaxlib==0.1.75+cuda11.cudnn805
numpy==1.22.3
PennyLane==0.22.1
PennyLane-Lightning==0.22.0
PennyLane-qiskit==0.20.0
scipy==1.8.0

I think this is the latest version available for CUDA. I tried updating, but got the same result.

CatalinaAlbornoz commented 2 years ago

Hi @ankit27kh, I see that jaxlib v0.3.2 supports your version of CUDA, and I'd guess that jax v0.3.4 will also work. Are you able to update jax and jaxlib, or do you run into problems?

ankit27kh commented 2 years ago

Hello @CatalinaAlbornoz and @antalszava, I updated my libraries. The issue is still there. Here are my latest versions:

jax==0.3.4
jaxlib==0.3.2+cuda11.cudnn82
numpy==1.22.3
PennyLane==0.22.2
PennyLane-Lightning==0.22.0
PennyLane-qiskit==0.20.0
scipy==1.8.0

With jit, the code does not terminate on GPU. If I force the CPU backend with `config.update('jax_platform_name', 'cpu')`, I get a result. Without jit, it works on both CPU and GPU.
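A minimal sketch of that CPU workaround, assuming a JAX version where `jax.config.update` accepts the `jax_platform_name` option used above. `z_expval` is a hypothetical stand-in for the jitted QNode so the snippet runs without PennyLane; in the report, the `@jax.jit`-decorated `circ` would take its place. `jax.default_backend()` shows which platform is active, which is a quick way to confirm the switch took effect.

```python
import jax
import jax.numpy as jnp

# Pin JAX to the CPU backend. Run this before creating any arrays: once a
# backend has been initialised, changing the option may be ignored.
jax.config.update("jax_platform_name", "cpu")


@jax.jit
def z_expval(x):
    # Hypothetical stand-in for the jitted QNode: <Z> after RY(x) on |0> is cos(x).
    return jnp.cos(x)


print(jax.default_backend(), z_expval(1.0))
```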

CatalinaAlbornoz commented 2 years ago

We're looking into this. Thanks for reporting it, @ankit27kh.

puzzleshark commented 2 years ago

Hi @ankit27kh, thanks for the additional information. We are able to reproduce the bug with a CUDA-enabled jaxlib and will work on a solution. Just wondering: is this issue currently a blocker for you? We're trying to determine the priority of the fix.

ankit27kh commented 2 years ago

Hey @puzzleshark, this was just something I encountered; it's not blocking my work.