PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] Incompatibility of jax with device default.mixed and qml.QubitDensityMatrix() #5196

Closed: erikrecio closed this issue 8 months ago

erikrecio commented 8 months ago

Expected behavior

jax.jit can be wrapped around a QNode that runs on the default.mixed device and prepares its initial state with qml.QubitDensityMatrix().

Actual behavior

Error: TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].. The error occurred while tracing the function variational_circuit_mixed at :9 for jit. This concrete value was not available in Python because it depends on the value of the argument state_ini.

Additional information

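As you can see in the source code, variational_circuit_qubit() works correctly, while variational_circuit_mixed() fails during jit compilation. The two circuits are the same except for the device and the state-preparation operation, and, as far as I understand, jax should be compatible with the default.mixed device. The traceback points to the unit-trace check in DefaultMixed._apply_density_matrix, where the result of qnp.allclose is used in a Python if statement while the density matrix is still a jit tracer.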
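The failure mode can be reproduced without PennyLane. The sketch below is a hypothetical stand-alone function modeled on the trace check at default_mixed.py lines 540-543 in the traceback (it is not PennyLane's code): a Python `if` on the result of `jnp.allclose` works on concrete arrays but raises `TracerBoolConversionError` when the function is traced by `jax.jit`, because the comparison has no concrete value at trace time.

```python
import jax
import jax.numpy as jnp


def validate_trace(rho, atol=1e-7):
    """Raise if rho does not have unit trace (modeled on the check in the traceback)."""
    # `not <array>` calls bool() on the result of jnp.allclose. With a concrete
    # array this is fine; under jax.jit `rho` is an abstract tracer, so the
    # comparison has no Python-level value and bool() raises.
    if not jnp.allclose(jnp.trace(rho), 1.0, atol=atol):
        raise ValueError("Trace of density matrix is not equal one.")
    return rho


def jit_raises_bool_conversion(rho):
    """Return True if tracing validate_trace under jax.jit fails as in the report."""
    try:
        jax.jit(validate_trace)(rho)
    except jax.errors.TracerBoolConversionError:
        return True
    return False


rho = jnp.array([[1.0, 0.0], [0.0, 0.0]])  # |0><0|
validate_trace(rho)                         # eager call: the check passes
print(jit_raises_bool_conversion(rho))      # True
```

The eager call succeeds and the jitted call fails on the same input, which matches the qubit/mixed contrast in the report: the error depends on tracing, not on the values in state_ini.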

Source code

import pennylane as qml
import numpy as np
import jax

nqubits = 1

# Pure-state simulator: jax.jit works
dev = qml.device("default.qubit", wires=nqubits)
@jax.jit
@qml.qnode(dev, interface="jax")
def variational_circuit_qubit(weights, state_ini):
    # Prepare the input state vector (weights is unused in this minimal example)
    qml.QubitStateVector(state_ini, wires=range(nqubits))
    return qml.state()

# Mixed-state simulator: jax.jit fails with TracerBoolConversionError
dev = qml.device("default.mixed", wires=nqubits)
@jax.jit
@qml.qnode(dev, interface="jax")
def variational_circuit_mixed(weights, state_ini):
    # Prepare the input density matrix
    qml.QubitDensityMatrix(state_ini, wires=range(nqubits))
    return qml.state()

# |0> and its density matrix |0><0| (outer product of a real vector)
state_ini = np.array([1,0])
rho_ini = np.tensordot(state_ini, state_ini, axes=0)

state_out = variational_circuit_qubit(0, state_ini)  # succeeds
rho_out = variational_circuit_mixed(0, rho_ini)      # raises TracerBoolConversionError
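The repro builds rho_ini with np.tensordot(..., axes=0), which is the plain outer product of the vector with itself. That equals the density matrix |psi><psi| only when the amplitudes are real. The sketch below uses a hypothetical helper, pure_state_density_matrix, that conjugates the second factor; for complex amplitudes the unconjugated version can fail the unit-trace check that appears in the traceback.

```python
import numpy as np


def pure_state_density_matrix(psi):
    """Return the density matrix |psi><psi| of a normalized state vector."""
    psi = np.asarray(psi, dtype=complex)
    # np.outer(a, b)[i, j] = a[i] * b[j]; conjugating the second factor gives <psi|
    return np.outer(psi, psi.conj())


state_ini = np.array([1, 0])
rho_outer = pure_state_density_matrix(state_ini)
rho_tensordot = np.tensordot(state_ini, state_ini, axes=0)  # as in the report
print(np.allclose(rho_outer, rho_tensordot))  # True: the amplitudes are real

# For complex amplitudes the missing conjugate matters: |+i> = (|0> + i|1>)/sqrt(2)
psi = np.array([1, 1j]) / np.sqrt(2)
print(np.trace(pure_state_density_matrix(psi)).real)   # approximately 1.0
print(np.trace(np.tensordot(psi, psi, axes=0)).real)   # approximately 0.0
```

For the |0> state used in the report both constructions agree, so this is not the cause of the bug; it only matters if the example is extended to complex input states.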

Tracebacks

---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
File d:\Documents\1. GitHub\qurriculum_learning\Phase Recognition\test.py:27
     24 rho_ini = np.tensordot(state_ini, state_ini, axes=0)
     26 state_out = variational_circuit_qubit(0, state_ini)
---> 27 rho_out = variational_circuit_mixed(0, rho_ini)

    [... skipping hidden 12 frame]

File d:\Documents\1. GitHub\venv\lib\site-packages\pennylane\qnode.py:1027, in QNode.__call__(self, *args, **kwargs)
   1022         full_transform_program._set_all_argnums(
   1023             self, args, kwargs, argnums
   1024         )  # pylint: disable=protected-access
   1026 # pylint: disable=unexpected-keyword-arg
-> 1027 res = qml.execute(
   1028     (self._tape,),
   1029     device=self.device,
   1030     gradient_fn=self.gradient_fn,
   1031     interface=self.interface,
   1032     transform_program=full_transform_program,
   1033     config=config,
   1034     gradient_kwargs=self.gradient_kwargs,
   1035     override_shots=override_shots,
   1036     **self.execute_kwargs,
   1037 )
   1039 res = res[0]
   1041 # convert result to the interface in case the qfunc has no parameters

File d:\Documents\1. GitHub\venv\lib\site-packages\pennylane\interfaces\execution.py:616, in execute(tapes, device, gradient_fn, interface, transform_program, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform)
    614 # Exiting early if we do not need to deal with an interface boundary
    615 if no_interface_boundary_required:
--> 616     results = inner_execute(tapes)
    617     return post_processing(results)
    619 _grad_on_execution = False

File d:\Documents\1. GitHub\venv\lib\site-packages\pennylane\interfaces\execution.py:249, in _make_inner_execute.<locals>.inner_execute(tapes, **_)
    247 if numpy_only:
    248     tapes = tuple(qml.transforms.convert_to_numpy_parameters(t) for t in tapes)
--> 249 return cached_device_execution(tapes)

File d:\Documents\1. GitHub\venv\lib\site-packages\pennylane\interfaces\execution.py:371, in cache_execute.<locals>.wrapper(tapes, **kwargs)
    366         return (res, []) if return_tuple else res
    368 else:
    369     # execute all unique tapes that do not exist in the cache
    370     # convert to list as new device interface returns a tuple
--> 371     res = list(fn(tuple(execution_tapes.values()), **kwargs))
    373 final_res = []
    375 for i, tape in enumerate(tapes):

File C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.10_3.10.3056.0_x64__qbz5n2kfra8p0\lib\contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File d:\Documents\1. GitHub\venv\lib\site-packages\pennylane\_qubit_device.py:460, in QubitDevice.batch_execute(self, circuits)
    455 for circuit in circuits:
    456     # we need to reset the device here, else it will
    457     # not start the next computation in the zero state
    458     self.reset()
--> 460     res = self.execute(circuit)
    461     results.append(res)
    463 if self.tracker.active:

File d:\Documents\1. GitHub\venv\lib\site-packages\pennylane\devices\default_mixed.py:685, in DefaultMixed.execute(self, circuit, **kwargs)
    683         wires_list.append(m.wires)
    684     self.measured_wires = qml.wires.Wires.all_wires(wires_list)
--> 685 return super().execute(circuit, **kwargs)

File d:\Documents\1. GitHub\venv\lib\site-packages\pennylane\_qubit_device.py:279, in QubitDevice.execute(self, circuit, **kwargs)
    276 self.check_validity(circuit.operations, circuit.observables)
    278 # apply all circuit operations
--> 279 self.apply(circuit.operations, rotations=self._get_diagonalizing_gates(circuit), **kwargs)
    281 # generate computational basis samples
    282 if self.shots is not None or circuit.is_sampled:

File d:\Documents\1. GitHub\venv\lib\site-packages\pennylane\devices\default_mixed.py:699, in DefaultMixed.apply(self, operations, rotations, **kwargs)
    693         raise DeviceError(
    694             f"Operation {operation.name} cannot be used after other Operations have already been applied "
    695             f"on a {self.short_name} device."
    696         )
    698 for operation in operations:
--> 699     self._apply_operation(operation)
    701 # store the pre-rotated state
    702 self._pre_rotated_state = self._state

File d:\Documents\1. GitHub\venv\lib\site-packages\pennylane\devices\default_mixed.py:604, in DefaultMixed._apply_operation(self, operation)
    601     return
    603 if isinstance(operation, QubitDensityMatrix):
--> 604     self._apply_density_matrix(operation.parameters[0], wires)
    605     return
    607 if isinstance(operation, Snapshot):

File d:\Documents\1. GitHub\venv\lib\site-packages\pennylane\devices\default_mixed.py:540, in DefaultMixed._apply_density_matrix(self, state, device_wires)
    537 if dm_dim != state.shape[0]:
    538     raise ValueError("Density matrix must be of length (2**wires, 2**wires)")
--> 540 if not qnp.allclose(
    541     qnp.trace(qnp.reshape(state, (state_dim, state_dim))), 1.0, atol=tolerance
    542 ):
    543     raise ValueError("Trace of density matrix is not equal one.")
    545 if len(device_wires) == self.num_wires and sorted(device_wires.labels) == list(
    546     device_wires.labels
    547 ):
    548     # Initialize the entire wires with the state

    [... skipping hidden 1 frame]

File d:\Documents\1. GitHub\venv\lib\site-packages\jax\_src\core.py:1510, in concretization_function_error.<locals>.error(self, arg)
   1509 def error(self, arg):
-> 1510   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function variational_circuit_mixed at <ipython-input-15-745eb18162e5>:16 for jit. This concrete value was not available in Python because it depends on the value of the argument state_ini.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
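One jit-tolerant way to write a check like the one at default_mixed.py line 540 is to run it only when a concrete value exists. The sketch below is a hypothetical helper, not the patch the maintainers merged: it catches jax.errors.ConcretizationTypeError (the base class of TracerBoolConversionError) and skips the check while tracing. PennyLane also ships qml.math.is_abstract for detecting traced values, which serves the same purpose inside the library.

```python
import jax
import jax.numpy as jnp


def validate_trace_jit_safe(rho, atol=1e-7):
    """Hypothetical jit-tolerant variant of the unit-trace check.

    The check runs only when a concrete value is available. Under jax.jit the
    comparison result is abstract, so the check is skipped instead of raising.
    """
    try:
        has_unit_trace = bool(jnp.allclose(jnp.trace(rho), 1.0, atol=atol))
    except jax.errors.ConcretizationTypeError:
        # Tracing: no concrete value exists yet, so skip the Python-level check.
        return rho
    if not has_unit_trace:
        raise ValueError("Trace of density matrix is not equal one.")
    return rho


rho = jnp.array([[1.0, 0.0], [0.0, 0.0]])  # valid: trace 1
bad = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # invalid: trace 2
jitted = jax.jit(validate_trace_jit_safe)
print(jitted(rho))  # compiles and returns rho
print(jitted(bad))  # also returns: the check is skipped while tracing
```

The trade-off is that an invalid density matrix passes silently under jit. If the check must survive compilation, jax.experimental.checkify is the JAX mechanism for runtime assertions in jitted code.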

System information

Name: PennyLane
Version: 0.33.1
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: d:\documents\1. github\venv\lib\site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Lightning

Platform info:           Windows-10-10.0.19045-SP0
Python version:          3.10.11
Numpy version:           1.26.2
Scipy version:           1.11.4
Installed devices:
- default.gaussian (PennyLane-0.33.1)
- default.mixed (PennyLane-0.33.1)
- default.qubit (PennyLane-0.33.1)
- default.qubit.autograd (PennyLane-0.33.1)
- default.qubit.jax (PennyLane-0.33.1)
- default.qubit.legacy (PennyLane-0.33.1)
- default.qubit.tf (PennyLane-0.33.1)
- default.qubit.torch (PennyLane-0.33.1)
- default.qutrit (PennyLane-0.33.1)
- null.qubit (PennyLane-0.33.1)
- lightning.qubit (PennyLane-Lightning-0.33.1)


timmysilv commented 8 months ago

hi @erikrecio, thanks for reporting this! I've opened a bugfix (linked above), and I'll let you know once it's accepted.

erikrecio commented 8 months ago

Do I have to wait until the next PennyLane update, or does this mean my code will run now? I don't know how to download "the latest version", since "pip install --upgrade pennylane" didn't fix the issue.

timmysilv commented 8 months ago

hello again! you're correct, the latest pennylane installed from pip does not yet have the fix. Two things I'd recommend:

  1. Install the latest master branch of PennyLane, which includes the fix, using pip install git+https://github.com/PennyLaneAI/pennylane.git, or set up a local editable install of PennyLane
  2. wait for PL v0.35 to be released, expected on March 5

happy hacking!
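To confirm which build is active after following one of the install options above, qml.about() prints the same report as the "System information" section, and the installed version can also be read from package metadata. The sketch below is a hypothetical helper; the minimum release passed to release_at_least is an assumption for illustration, so use whichever release the maintainers name as containing the fix. A development suffix on the patch component (as master builds typically carry) does not affect the comparison.

```python
from importlib import metadata


def release_at_least(version, minimum):
    """Return True if the major.minor part of `version` is >= `minimum`."""
    # Only the first two components are parsed, so "0.35.0.dev0" counts as 0.35.
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) >= tuple(minimum)


def installed_pennylane_version():
    """Return the installed PennyLane version string, or None if it is absent."""
    try:
        return metadata.version("pennylane")
    except metadata.PackageNotFoundError:
        return None


print(release_at_least("0.33.1", (0, 35)))       # False: the reporter's version
print(release_at_least("0.35.0.dev0", (0, 35)))  # True
print(installed_pennylane_version())
```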

erikrecio commented 8 months ago

This is great, works like a charm, and you solved it really quickly. Thanks @timmysilv!!