ankit27kh opened this issue 2 years ago.
Hi @ankit27kh, which versions of JAX and jaxlib does the error occur with?
I'm not able to reproduce on Python 3.9 with the following packages:
-> pip freeze | grep 'numpy\|PennyLane\|jax\|scipy'
jax==0.3.4
jaxlib==0.3.2
numpy==1.22.3
PennyLane==0.22.1
PennyLane-Cirq==0.17.1
PennyLane-Lightning==0.22.1
scipy==1.8.0
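As an alternative to the pip freeze | grep pipeline, here is a minimal sketch that reports installed versions from inside Python using the standard library's importlib.metadata. The package list is an assumption drawn from the versions posted in this thread, and installed_versions is a hypothetical helper, not part of any library.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed list of distributions, taken from the versions posted in this thread.
PACKAGES = ["jax", "jaxlib", "numpy", "PennyLane", "PennyLane-Lightning", "scipy"]


def installed_versions(names):
    """Hypothetical helper: map each distribution name to its installed
    version string, or to None if it is not installed."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # not installed in this environment
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name}=={ver}" if ver else f"{name}: not installed")
```

Unlike grep, this only reports the names listed, so plugins such as PennyLane-Cirq or PennyLane-qiskit would need to be added explicitly.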
Hi @antalszava, here are my module versions:
jax==0.2.26
jaxlib==0.1.75+cuda11.cudnn805
numpy==1.22.3
PennyLane==0.22.1
PennyLane-Lightning==0.22.0
PennyLane-qiskit==0.20.0
scipy==1.8.0
I think this is the most recent version available for CUDA. I tried updating, but got the same result.
Hi @ankit27kh, I see here that jaxlib v0.3.2 works with your version of CUDA. I'd guess that jax v0.3.4 will also work. Are you able to update jax and jaxlib, or do you run into problems?
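The CUDA build of jaxlib can be told apart from the CPU-only build by the local version label after the "+" (for example 0.3.2+cuda11.cudnn82 versus plain 0.3.2). The sketch below splits such a string into its parts; cuda_build_info is a hypothetical helper, and it simply reports the raw tags without interpreting them further.

```python
def cuda_build_info(jaxlib_version):
    """Hypothetical helper: split a jaxlib version string into its public
    version and the raw CUDA/cuDNN tags from the local label after '+'."""
    public, _, local = jaxlib_version.partition("+")
    info = {"version": public, "cuda": None, "cudnn": None}
    for part in local.split(".") if local else []:
        # Check "cudnn" first; the two prefixes do not overlap, but this is explicit.
        if part.startswith("cudnn"):
            info["cudnn"] = part[len("cudnn"):]
        elif part.startswith("cuda"):
            info["cuda"] = part[len("cuda"):]
    return info


for v in ["0.3.2", "0.3.2+cuda11.cudnn82", "0.1.75+cuda11.cudnn805"]:
    print(v, cuda_build_info(v))
```

Applied to the versions in this thread, the jaxlib==0.3.2 used in the first reproduction attempt carries no CUDA tag, which may explain why the hang did not reproduce there.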
Hello @CatalinaAlbornoz and @antalszava, I updated my libraries. The issue is still there. Here are my latest versions:
jax==0.3.4
jaxlib==0.3.2+cuda11.cudnn82
numpy==1.22.3
PennyLane==0.22.2
PennyLane-Lightning==0.22.0
PennyLane-qiskit==0.20.0
scipy==1.8.0
With jit, the code does not terminate on GPU. If I use config.update('jax_platform_name', 'cpu'), I get a result. Without jit, it works on both CPU and GPU.
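The CPU workaround described above can be sketched as follows. The reporter's actual PennyLane code is not included in the issue, so cost below is a hypothetical stand-in; the sketch only shows pinning JAX to the CPU backend before any computation and comparing the jitted and non-jitted results.

```python
import jax
import jax.numpy as jnp

# Workaround from the thread: pin JAX to the CPU backend.
# This must run before any JAX computation initializes a backend.
jax.config.update("jax_platform_name", "cpu")


def cost(x):
    # Hypothetical stand-in for the reporter's PennyLane cost function.
    return jnp.sum(jnp.sin(x) ** 2)


x = jnp.arange(3.0)
eager_result = cost(x)          # evaluated op by op, without jit
jit_result = jax.jit(cost)(x)   # compiled with XLA on the selected backend

print(jax.default_backend())    # should report "cpu" after the config update
print(eager_result, jit_result)
```

Since this stand-in does not involve PennyLane, it illustrates the configuration step only and is not expected to reproduce the hang itself.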
We're looking into this. Thanks for reporting it @ankit27kh .
Hi @ankit27kh, thanks for the additional information. We are able to reproduce the bug with CUDA-enabled jaxlib, and will work on a solution. Just wondering: is this issue currently a blocker for you? We're trying to determine the priority of the fix.
Hey @puzzleshark, this was just something I encountered. Not blocking my work.
Expected behavior
This code should terminate with a result:
Actual behavior
The code does not terminate.
Additional information
The code terminates when jit is not used:
Source code
No response
Tracebacks
No response
System information
Existing GitHub issues