PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] Creating a ragged array output in `QubitDevice` errors with `default.qubit.jax` #2129

Closed: antalszava closed this issue 1 year ago

antalszava commented 2 years ago

### Expected behavior

(Note: this is a specific case of ragged output creation with `QubitDevice`. It is meant to serve as an example to keep on file for future consideration.)

A QNode returning multiple probability measurements over different numbers of wires should execute and return its results in some form (note that the JAX interface is not applied here; the QNode runs on the backprop device directly):

```python
import pennylane as qml

dev = qml.device("default.qubit.jax", wires=3, shots=None)

@qml.qnode(dev, interface="jax")
def circuit(a, b):
    qml.RY(a, wires=0)
    qml.RX(b, wires=0)
    return qml.probs(wires=[0]), qml.probs(wires=[1, 2])

circuit(0, 3)
```
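To see why the output is ragged, the two marginal distributions can be computed by hand. The NumPy sketch below is independent of PennyLane: `ry`, `rx` and `circuit_probs` are hypothetical helpers for this illustration, using the standard RY/RX matrices and assuming PennyLane's convention that wire 0 is the most significant qubit. `qml.probs(wires=[0])` has 2^1 = 2 entries while `qml.probs(wires=[1, 2])` has 2^2 = 4, so the two results cannot be stacked into one rectangular array.

```python
import numpy as np

# Standard single-qubit rotation matrices (same definitions as RY / RX).
def ry(theta):
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)

def rx(theta):
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]], dtype=complex)

def circuit_probs(a, b):
    """Hypothetical hand simulation of the circuit in this issue."""
    # Only wire 0 is rotated; wires 1 and 2 stay in |00>.
    q0 = rx(b) @ ry(a) @ np.array([1, 0], dtype=complex)
    rest = np.zeros(4, dtype=complex)
    rest[0] = 1.0
    # Wire 0 first in the Kronecker product = most significant qubit.
    state = np.kron(q0, rest).reshape(2, 2, 2)  # axes: wire 0, wire 1, wire 2
    probs = np.abs(state) ** 2
    p0 = probs.sum(axis=(1, 2))          # marginal on wire 0: 2 entries
    p12 = probs.sum(axis=0).reshape(-1)  # marginal on wires 1, 2: 4 entries
    return p0, p12

p0, p12 = circuit_probs(0.0, 3.0)
print(p0.shape, p12.shape)  # (2,) (4,)
```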

### Actual behavior

```
~/xanadu/pennylane/pennylane/interfaces/batch/__init__.py in fn(tapes, **kwargs)
    118         def fn(tapes, **kwargs):  # pylint: disable=function-redefined
    119             tapes = [expand_fn(tape) for tape in tapes]
--> 120             return original_fn(tapes, **kwargs)
    121 
    122     @wraps(fn)

~/anaconda3/lib/python3.8/contextlib.py in inner(*args, **kwds)
     73         def inner(*args, **kwds):
     74             with self._recreate_cm():
---> 75                 return func(*args, **kwds)
     76         return inner
     77 

~/xanadu/pennylane/pennylane/_qubit_device.py in batch_execute(self, circuits)
    280             self.reset()
    281 
--> 282             res = self.execute(circuit)
    283             results.append(res)
    284 

~/xanadu/pennylane/pennylane/_qubit_device.py in execute(self, circuit, **kwargs)
    232 
    233         if (circuit.all_sampled or not circuit.is_sampled) and not multiple_sampled_jobs:
--> 234             results = self._asarray(results)
    235         else:
    236             results = tuple(self._asarray(r) for r in results)

~/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin, device)
   3614   elif isinstance(object, (list, tuple)):
   3615     if object:
-> 3616       out = stack([asarray(elt, dtype=dtype) for elt in object])
   3617     else:
   3618       out = _np_array([], dtype=dtype)

~/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in stack(arrays, axis, out)
   3375     for a in arrays:
   3376       if shape(a) != shape0:
-> 3377         raise ValueError("All input arrays must have the same shape.")
   3378       new_arrays.append(expand_dims(a, axis))
   3379     return concatenate(new_arrays, axis=axis)

ValueError: All input arrays must have the same shape.
```
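The failure can be reproduced without PennyLane. The traceback shows that `self._asarray` resolves to `jax.numpy.array` for this device, so the minimal sketch below (the `try_stack` helper is hypothetical) calls `jnp.array` directly on a list of per-measurement results: same-shaped results stack fine, while shapes `(2,)` and `(4,)` hit the `ValueError` raised from `stack`.

```python
import jax.numpy as jnp

def try_stack(results):
    """Hypothetical helper: mimic `self._asarray(results)` on default.qubit.jax."""
    try:
        return jnp.array(results)
    except ValueError as err:
        return err

# Results shaped like qml.probs(wires=[0]) and qml.probs(wires=[1, 2]).
ragged = [jnp.zeros(2), jnp.zeros(4)]
print(repr(try_stack(ragged)))

# Equal shapes stack into one rectangular array.
square = [jnp.zeros(2), jnp.zeros(2)]
print(try_stack(square).shape)  # (2, 2)
```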

### Additional information

Solving this issue will likely tie in with reconsidering the output type of multi-measurement QNodes. The change we are considering, returning a tuple of results instead of a single array, would likely address it.
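As an illustration of why a tuple return type sidesteps the problem, the JAX sketch below returns the two differently shaped probability vectors as a tuple. `two_marginals` is a hypothetical stand-in for the QNode (an RX rotation on wire 0 with wires 1 and 2 left in |00>), not PennyLane code. JAX treats a tuple as a pytree, so `jax.jit` and `jax.jacobian` accept it without ever stacking the entries.

```python
import jax
import jax.numpy as jnp

@jax.jit
def two_marginals(x):
    # Hypothetical stand-in for a QNode: RX(x) on wire 0 gives
    # probabilities [cos^2(x/2), sin^2(x/2)]; wires 1 and 2 stay in |00>.
    p0 = jnp.array([jnp.cos(x / 2) ** 2, jnp.sin(x / 2) ** 2])
    p12 = jnp.array([1.0, 0.0, 0.0, 0.0])
    return p0, p12  # a tuple (pytree) of arrays with shapes (2,) and (4,)

p0, p12 = two_marginals(3.0)
jac0, jac12 = jax.jacobian(two_marginals)(3.0)
print(p0.shape, p12.shape)      # (2,) (4,)
print(jac0.shape, jac12.shape)  # (2,) (4,)
```

The Jacobian keeps the same tuple structure as the output, one entry per measurement.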

The problem comes down to the `self._asarray` call in `QubitDevice.execute`:

```
~/xanadu/pennylane/pennylane/_qubit_device.py in execute(self, circuit, **kwargs)
    232 
    233         if (circuit.all_sampled or not circuit.is_sampled) and not multiple_sampled_jobs:
--> 234             results = self._asarray(results)
```
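One possible device-side mitigation is to stack only when every measurement result has the same shape. The sketch below assumes, as the traceback suggests, that `self._asarray` is `jnp.array` for `default.qubit.jax`; `asarray_or_tuple` is a hypothetical helper, not the fix PennyLane adopted.

```python
import jax.numpy as jnp

def asarray_or_tuple(results, asarray=jnp.array):
    """Hypothetical replacement for `results = self._asarray(results)`.

    Stacks the per-measurement results into one array when their shapes
    agree, and otherwise returns a tuple with one array per measurement.
    """
    arrays = [asarray(r) for r in results]
    if all(a.shape == arrays[0].shape for a in arrays):
        return asarray(arrays)
    return tuple(arrays)

same = asarray_or_tuple([jnp.zeros(2), jnp.ones(2)])
ragged = asarray_or_tuple([jnp.zeros(2), jnp.zeros(4)])
print(same.shape)                 # (2, 2)
print([r.shape for r in ragged])  # [(2,), (4,)]
```

A drawback of this design is that the output type depends on the measurement shapes, which is one reason a return type that is always a tuple is the cleaner option.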

Autograd handles the same case, but NumPy (called through autograd's wrapper) emits the following warning:

```
autograd/autograd/numpy/numpy_wrapper.py:77: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  return _np.array(args, *array_args, **array_kwargs)
```

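JAX cannot do the same: `jax.numpy` arrays support only numeric and boolean dtypes, so there is no object-array fallback and `jnp.array` raises the error instead.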
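For comparison, the NumPy sketch below shows what the warning's suggested `dtype=object` produces for the two probability vectors: a 1-D array of length 2 whose elements are the original arrays, not a rectangular numeric array. In NumPy 1.24 and later the deprecation has expired, and omitting `dtype=object` for ragged input raises a `ValueError` rather than warning.

```python
import numpy as np

# Shapes match qml.probs(wires=[0]) and qml.probs(wires=[1, 2]).
ragged = [np.zeros(2), np.zeros(4)]

# With dtype=object, NumPy stores each result as one element of a
# 1-D object array instead of building a numeric 2-D array.
obj = np.array(ragged, dtype=object)
print(obj.shape, obj.dtype)  # (2,) object
print(obj[1].shape)          # (4,)
```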

### Source code

No response

### Tracebacks

No response

### System information

```
Name: PennyLane
Version: 0.21.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/antal/xanadu/pennylane
Requires: numpy, scipy, networkx, retworkx, autograd, toml, appdirs, semantic_version, autoray, cachetools, pennylane-lightning
Required-by: PennyLane-Cirq, amazon-braket-pennylane-plugin, PennyLane-Orquestra, pennylane-qulacs, PennyLane-Honeywell, PennyLane-qiskit, PennyLane-AQT, PennyLane-PQ, PennyLane-Forest, PennyLane-qsharp, PennyLane-Qchem, PennyLane-IonQ, PennyLane-SF, PennyLane-Lightning
Platform info:           Linux-5.13.0-27-generic-x86_64-with-glibc2.10
Python version:          3.8.5
Numpy version:           1.19.2
Scipy version:           1.7.3
Installed devices:
- cirq.mixedsimulator (PennyLane-Cirq-0.19.0)
- cirq.pasqal (PennyLane-Cirq-0.19.0)
- cirq.qsim (PennyLane-Cirq-0.19.0)
- cirq.qsimh (PennyLane-Cirq-0.19.0)
- cirq.simulator (PennyLane-Cirq-0.19.0)
- braket.aws.qubit (amazon-braket-pennylane-plugin-1.5.3)
- braket.local.qubit (amazon-braket-pennylane-plugin-1.5.3)
- orquestra.forest (PennyLane-Orquestra-0.15.0)
- orquestra.ibmq (PennyLane-Orquestra-0.15.0)
- orquestra.qiskit (PennyLane-Orquestra-0.15.0)
- orquestra.qulacs (PennyLane-Orquestra-0.15.0)
- qulacs.simulator (pennylane-qulacs-0.17.0.dev0)
- honeywell.hqs (PennyLane-Honeywell-0.16.0.dev0)
- qiskit.aer (PennyLane-qiskit-0.18.0.dev0)
- qiskit.basicaer (PennyLane-qiskit-0.18.0.dev0)
- qiskit.ibmq (PennyLane-qiskit-0.18.0.dev0)
- aqt.noisy_sim (PennyLane-AQT-0.18.0)
- aqt.sim (PennyLane-AQT-0.18.0)
- projectq.classical (PennyLane-PQ-0.18.0.dev0)
- projectq.ibm (PennyLane-PQ-0.18.0.dev0)
- projectq.simulator (PennyLane-PQ-0.18.0.dev0)
- forest.numpy_wavefunction (PennyLane-Forest-0.18.0.dev0)
- forest.qvm (PennyLane-Forest-0.18.0.dev0)
- forest.wavefunction (PennyLane-Forest-0.18.0.dev0)
- microsoft.QuantumSimulator (PennyLane-qsharp-0.19.0)
- ionq.qpu (PennyLane-IonQ-0.17.0.dev0)
- ionq.simulator (PennyLane-IonQ-0.17.0.dev0)
- strawberryfields.fock (PennyLane-SF-0.20.0.dev0)
- strawberryfields.gaussian (PennyLane-SF-0.20.0.dev0)
- strawberryfields.gbs (PennyLane-SF-0.20.0.dev0)
- strawberryfields.remote (PennyLane-SF-0.20.0.dev0)
- strawberryfields.tf (PennyLane-SF-0.20.0.dev0)
- lightning.qubit (PennyLane-Lightning-0.21.0.dev0)
- default.gaussian (PennyLane-0.21.0.dev0)
- default.mixed (PennyLane-0.21.0.dev0)
- default.qubit (PennyLane-0.21.0.dev0)
- default.qubit.autograd (PennyLane-0.21.0.dev0)
- default.qubit.jax (PennyLane-0.21.0.dev0)
- default.qubit.tf (PennyLane-0.21.0.dev0)
- default.qubit.torch (PennyLane-0.21.0.dev0)
```


### Existing GitHub issues

- [X] I have searched existing GitHub issues to make sure the issue does not already exist.
albi3ro commented 1 year ago

Solved by the new return types specification.