PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] JIT-compiled default.qubit.jax device raises TracerArrayConversionError for qml.QubitStateVector #1670

Closed. bonfab closed this issue 3 years ago.

bonfab commented 3 years ago

Expected behavior

The _apply_state_vector method does not seem to have been adapted to work with JAX-compiled (jax.jit) code when a state vector is set with qml.QubitStateVector.

Actual behavior

Specifically, in

if not np.allclose(np.linalg.norm(state, ord=2), 1.0, atol=tolerance):
    raise ValueError("Sum of amplitudes-squared does not equal one.")

the call to np.linalg.norm, which computes the norm of the state vector, raises a jax._src.errors.TracerArrayConversionError when the QNode is compiled with jax.jit.

A solution could be to use the jax.numpy version, jnp.linalg.norm, instead.
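A minimal sketch, outside PennyLane, of why the NumPy call fails while the jax.numpy one does not: under jax.jit the argument is an abstract tracer, and np.linalg.norm tries to convert it into a concrete NumPy array. The function names norm_np and norm_jnp are hypothetical, and the sketch assumes a JAX release that exposes jax.errors.TracerArrayConversionError.

```python
import jax
import jax.numpy as jnp
import numpy as np

state = np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])


@jax.jit
def norm_jnp(x):
    # jnp.linalg.norm operates on tracers, so it can be compiled.
    return jnp.linalg.norm(x, ord=2)


@jax.jit
def norm_np(x):
    # np.linalg.norm calls np.asarray on the tracer, which JAX forbids.
    return np.linalg.norm(x, ord=2)


print(norm_jnp(state))

try:
    norm_np(state)
except jax.errors.TracerArrayConversionError as err:
    print("np.linalg.norm failed under jit:", type(err).__name__)
```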

Source code

import pennylane as qml
import jax
import numpy as np

def circuit(x):
    wires = list(range(2))
    qml.QubitStateVector(x, wires=wires)
    return [qml.expval(qml.PauliX(wires=i)) for i in wires]

dev = qml.device("default.qubit.jax", wires=list(range(2)))

qnode = jax.jit(qml.QNode(circuit, dev, interface="jax"))

state_vector = np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])

f_norm = jax.jit(jax.numpy.linalg.norm)  # works

# f_norm = jax.jit(np.linalg.norm)  # does not work: raises the same error

print(f_norm(state_vector))

qnode(state_vector)  # raises jax._src.errors.TracerArrayConversionError

System information

Name: PennyLane
Version: 0.17.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/fabian/.local/lib/python3.8/site-packages
Requires: appdirs, networkx, semantic-version, numpy, scipy, toml, autoray, autograd
Required-by: pennylane-qulacs, PennyLane-qiskit
Platform info:           Linux-5.11.0-34-generic-x86_64-with-glibc2.29
Python version:          3.8.10
Numpy version:           1.19.2
Scipy version:           1.7.1
Installed devices:
- default.gaussian (PennyLane-0.17.0)
- default.mixed (PennyLane-0.17.0)
- default.qubit (PennyLane-0.17.0)
- default.qubit.autograd (PennyLane-0.17.0)
- default.qubit.jax (PennyLane-0.17.0)
- default.qubit.tf (PennyLane-0.17.0)
- default.tensor (PennyLane-0.17.0)
- default.tensor.tf (PennyLane-0.17.0)
- qulacs.simulator (pennylane-qulacs-0.15.0)
- qiskit.aer (PennyLane-qiskit-0.17.0)
- qiskit.basicaer (PennyLane-qiskit-0.17.0)
- qiskit.ibmq (PennyLane-qiskit-0.17.0)

CatalinaAlbornoz commented 3 years ago

Hi @bonfab! Thank you for reporting this bug. We'll get on it and try to fix it!

josh146 commented 3 years ago

One approach to fix this could be to make sure that the qml.math.linalg.norm and qml.math.allclose functions both work under jax.jit. Once this is the case, we can modify this default.qubit method to use these functions instead.

bonfab commented 3 years ago

I realized this might not be as trivial to fix as I first thought. Even after switching the source code to jax.numpy, one still gets a ConcretizationTypeError. The only workaround that works for me at the moment is to comment out the check completely.
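The ConcretizationTypeError comes from the Python if statement itself: under jax.jit, jnp.allclose returns a traced boolean whose value is unknown at trace time, so Python cannot branch on it. One possible workaround, sketched below with hypothetical helper names (check_state_norm, strict_probs, lenient_probs) and not taken from #1683, is to run the check only when the value is concrete and skip it while tracing. The trade-off is that an unnormalized state passed to a jitted function is no longer rejected.

```python
import jax
import jax.numpy as jnp


def check_state_norm(state, tolerance=1e-10):
    """Hypothetical helper: validate the norm only when the state is concrete."""
    norm = jnp.linalg.norm(state, ord=2)
    try:
        # bool() needs a concrete value. Under jax.jit, `norm` is a tracer, so this
        # raises ConcretizationTypeError (newer JAX raises a subclass of it).
        is_normalized = bool(jnp.allclose(norm, 1.0, atol=tolerance))
    except jax.errors.ConcretizationTypeError:
        # Value unknown at trace time: skip the check instead of failing.
        return
    if not is_normalized:
        raise ValueError("Sum of amplitudes-squared does not equal one.")


def strict_probs(state):
    # Mirrors the original device check: a Python `if` on a traced boolean.
    if not jnp.allclose(jnp.linalg.norm(state, ord=2), 1.0):
        raise ValueError("Sum of amplitudes-squared does not equal one.")
    return jnp.abs(state) ** 2


def lenient_probs(state):
    check_state_norm(state)
    return jnp.abs(state) ** 2


state = jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])

try:
    jax.jit(strict_probs)(state)
except jax.errors.ConcretizationTypeError as err:
    print("strict check under jit:", type(err).__name__)

print(jax.jit(lenient_probs)(state))
```

Outside of jit the helper still rejects bad input, for example check_state_norm(jnp.array([1.0, 1.0])) raises ValueError.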

CatalinaAlbornoz commented 3 years ago

Yes @bonfab, I was also getting an error. I think Josh's approach is a good way to go. If you have any other ideas on how to fix this bug let us know here!
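Another option, not discussed in this thread and not taken from #1683, is JAX's checkify transformation (jax.experimental.checkify), which keeps the normalization check under jit by recording it as an explicit error value instead of a Python branch. This is a sketch assuming a recent JAX release that ships checkify; normalized_probs is a hypothetical stand-in, not PennyLane code.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def normalized_probs(state):
    """Hypothetical stand-in for a device method that validates its input state."""
    norm = jnp.linalg.norm(state, ord=2)
    # checkify.check records a traced predicate instead of branching in Python.
    checkify.check(
        jnp.allclose(norm, 1.0, atol=1e-6),
        "Sum of amplitudes-squared does not equal one.",
    )
    return jnp.abs(state) ** 2


# checkify.checkify makes the check return an Error value alongside the output,
# so the whole function can be wrapped in jax.jit.
checked_probs = jax.jit(checkify.checkify(normalized_probs))

err_ok, probs = checked_probs(jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0]))
err_bad, _ = checked_probs(jnp.array([1.0, 1.0]))

print(err_ok.get())   # None: the state is normalized
print(err_bad.get())  # a string containing the error message
# err_bad.throw() would raise the recorded error as a Python exception.
```

The caller decides when to inspect or raise the error, which is a different contract from the eager ValueError in the original device code.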

antalszava commented 3 years ago

Hi @bonfab, with #1683 merged, this should be resolved in the master branch.