PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] `KerasLayer` has unintended side effects on its QNode #5723

Closed. isaacdevlugt closed this issue 5 months ago.

isaacdevlugt commented 6 months ago

https://app.shortcut.com/xanaduai/story/63723/bug-keraslayer-has-unintended-side-effects-on-its-qnode

Expected behavior

Defining a Keras layer from a QNode has no side effects on the QNode itself.

Actual behavior

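QNodes wrapped in a `KerasLayer` are mutated. After a `KerasLayer` has been created from a QNode, calling a `TorchLayer` built from that same QNode raises a `ValueError`.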
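
The mechanism can be modeled in plain Python. The sketch below is not PennyLane code: `FakeQNode`, `MutatingKerasLayer` and `TorchLikeLayer` are hypothetical stand-ins. It assumes, as the maintainer comment further down indicates, that the Keras wrapper assigns to the wrapped QNode's `interface` attribute. Because every layer holds a reference to the same QNode object, an in-place assignment made by one wrapper is visible to all of them.

```python
class FakeQNode:
    """Hypothetical stand-in for a QNode: it only tracks its interface."""

    def __init__(self, interface="auto"):
        self.interface = interface


class MutatingKerasLayer:
    """Hypothetical wrapper mimicking the reported behaviour."""

    def __init__(self, qnode):
        self.qnode = qnode
        # Side effect: the caller's QNode object is modified in place.
        self.qnode.interface = "tf"


class TorchLikeLayer:
    """Hypothetical wrapper that only reads the QNode's interface."""

    def __init__(self, qnode):
        self.qnode = qnode

    def effective_interface(self):
        return self.qnode.interface


circuit = FakeQNode()
torch_layer = TorchLikeLayer(circuit)     # created first, while interface is "auto"
keras_layer = MutatingKerasLayer(circuit)  # overwrites the shared object's interface

print(circuit.interface)                  # tf
print(torch_layer.effective_interface())  # tf, even though torch_layer was created first
```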

Source code

import pennylane as qml
import torch

dev = qml.device('default.qubit')

@qml.qnode(dev)
def circuit(inputs, weights):
    qml.AmplitudeEmbedding(inputs, wires=[0, 1], normalize=True)
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    return qml.vn_entropy(wires=[1])

weight_shapes = {"weights": (2,)}

qlayer_torch = qml.qnn.TorchLayer(circuit, weight_shapes=weight_shapes)
qlayer_keras = qml.qnn.KerasLayer(circuit, weight_shapes=weight_shapes, output_dim=1)  # side effect: mutates `circuit` (sets its interface to "tf")

inputs = torch.rand(4, requires_grad=False)

clayer = torch.nn.Softmax()
qlayer_torch(clayer(inputs))  # raises ValueError: Tensors contain mixed types
model = torch.nn.Sequential(clayer, qlayer_torch)

model(inputs)

Tracebacks

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 21
     18 inputs = torch.rand(4, requires_grad=False)
     20 clayer = torch.nn.Softmax()
---> 21 qlayer_torch(clayer(inputs))
     22 model = torch.nn.Sequential(clayer, qlayer_torch)
     24 model(inputs)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/qnn/torch.py:402, in TorchLayer.forward(self, inputs)
    399     inputs = torch.reshape(inputs, (-1, inputs.shape[-1]))
    401 # calculate the forward pass as usual
--> 402 results = self._evaluate_qnode(inputs)
    404 if isinstance(results, tuple):
    405     if has_batch_dim:

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/qnn/torch.py:428, in TorchLayer._evaluate_qnode(self, x)
    416 """Evaluates the QNode for a single input datapoint.
    417 
    418 Args:
   (...)
    422     tensor: output datapoint
    423 """
    424 kwargs = {
    425     **{self.input_arg: x},
    426     **{arg: weight.to(x) for arg, weight in self.qnode_weights.items()},
    427 }
--> 428 res = self.qnode(**kwargs)
    430 if isinstance(res, torch.Tensor):
    431     return res.type(x.dtype)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/workflow/qnode.py:1098, in QNode.__call__(self, *args, **kwargs)
   1095 self._update_gradient_fn(shots=override_shots, tape=self._tape)
   1097 try:
-> 1098     res = self._execution_component(args, kwargs, override_shots=override_shots)
   1099 finally:
   1100     if old_interface == "auto":

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/workflow/qnode.py:1052, in QNode._execution_component(self, args, kwargs, override_shots)
   1049 full_transform_program.prune_dynamic_transform()
   1051 # pylint: disable=unexpected-keyword-arg
-> 1052 res = qml.execute(
   1053     (self._tape,),
   1054     device=self.device,
   1055     gradient_fn=self.gradient_fn,
   1056     interface=self.interface,
   1057     transform_program=full_transform_program,
   1058     config=config,
   1059     gradient_kwargs=self.gradient_kwargs,
   1060     override_shots=override_shots,
   1061     **self.execute_kwargs,
   1062 )
   1063 res = res[0]
   1065 # convert result to the interface in case the qfunc has no parameters

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/workflow/execution.py:616, in execute(tapes, device, gradient_fn, interface, transform_program, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp)
    614 # Exiting early if we do not need to deal with an interface boundary
    615 if no_interface_boundary_required:
--> 616     results = inner_execute(tapes)
    617     return post_processing(results)
    619 _grad_on_execution = False

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/workflow/execution.py:297, in _make_inner_execute.<locals>.inner_execute(tapes, **_)
    294 transformed_tapes, transform_post_processing = transform_program(tapes)
    296 if transformed_tapes:
--> 297     results = device_execution(transformed_tapes)
    298 else:
    299     results = ()

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/modifiers/simulator_tracking.py:30, in _track_execute.<locals>.execute(self, circuits, execution_config)
     28 @wraps(untracked_execute)
     29 def execute(self, circuits, execution_config=DefaultExecutionConfig):
---> 30     results = untracked_execute(self, circuits, execution_config)
     31     if isinstance(circuits, QuantumScript):
     32         batch = (circuits,)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/modifiers/single_tape_support.py:32, in _make_execute.<locals>.execute(self, circuits, execution_config)
     30     is_single_circuit = True
     31     circuits = (circuits,)
---> 32 results = batch_execute(self, circuits, execution_config)
     33 return results[0] if is_single_circuit else results

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/default_qubit.py:593, in DefaultQubit.execute(self, circuits, execution_config)
    590 prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
    592 if max_workers is None:
--> 593     return tuple(
    594         _simulate_wrapper(
    595             c,
    596             {
    597                 "rng": self._rng,
    598                 "debugger": self._debugger,
    599                 "interface": interface,
    600                 "state_cache": self._state_cache,
    601                 "prng_key": _key,
    602             },
    603         )
    604         for c, _key in zip(circuits, prng_keys)
    605     )
    607 vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
    608 seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/default_qubit.py:594, in <genexpr>(.0)
    590 prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
    592 if max_workers is None:
    593     return tuple(
--> 594         _simulate_wrapper(
    595             c,
    596             {
    597                 "rng": self._rng,
    598                 "debugger": self._debugger,
    599                 "interface": interface,
    600                 "state_cache": self._state_cache,
    601                 "prng_key": _key,
    602             },
    603         )
    604         for c, _key in zip(circuits, prng_keys)
    605     )
    607 vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
    608 seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/default_qubit.py:841, in _simulate_wrapper(circuit, kwargs)
    840 def _simulate_wrapper(circuit, kwargs):
--> 841     return simulate(circuit, **kwargs)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/qubit/simulate.py:287, in simulate(circuit, debugger, state_cache, **execution_kwargs)
    282     return simulate_one_shot_native_mcm(
    283         circuit, debugger=debugger, rng=rng, prng_key=prng_key, interface=interface
    284     )
    286 ops_key, meas_key = jax_random_split(prng_key)
--> 287 state, is_state_batched = get_final_state(
    288     circuit, debugger=debugger, rng=rng, prng_key=ops_key, interface=interface
    289 )
    290 if state_cache is not None:
    291     state_cache[circuit.hash] = state

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/qubit/simulate.py:150, in get_final_state(circuit, debugger, **execution_kwargs)
    148 if isinstance(op, MidMeasureMP):
    149     prng_key, key = jax_random_split(prng_key)
--> 150 state = apply_operation(
    151     op,
    152     state,
    153     is_state_batched=is_state_batched,
    154     debugger=debugger,
    155     mid_measurements=mid_measurements,
    156     rng=rng,
    157     prng_key=key,
    158 )
    159 # Handle postselection on mid-circuit measurements
    160 if isinstance(op, qml.Projector):

File /opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/qubit/apply_operation.py:206, in apply_operation(op, state, is_state_batched, debugger, **_)
    152 @singledispatch
    153 def apply_operation(
    154     op: qml.operation.Operator,
   (...)
    158     **_,
    159 ):
    160     """Apply and operator to a given state.
    161 
    162     Args:
   (...)
    204 
    205     """
--> 206     return _apply_operation_default(op, state, is_state_batched, debugger)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/qubit/apply_operation.py:216, in _apply_operation_default(op, state, is_state_batched, debugger)
    210 """The default behaviour of apply_operation, accessed through the standard dispatch
    211 of apply_operation, as well as conditionally in other dispatches."""
    212 if (
    213     len(op.wires) < EINSUM_OP_WIRECOUNT_PERF_THRESHOLD
    214     and math.ndim(state) < EINSUM_STATE_WIRECOUNT_PERF_THRESHOLD
    215 ) or (op.batch_size and is_state_batched):
--> 216     return apply_operation_einsum(op, state, is_state_batched=is_state_batched)
    217 return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/qubit/apply_operation.py:102, in apply_operation_einsum(op, state, is_state_batched)
     99         op._batch_size = batch_size  # pylint:disable=protected-access
    100 reshaped_mat = math.reshape(mat, new_mat_shape)
--> 102 return math.einsum(einsum_indices, reshaped_mat, state)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/math/multi_dispatch.py:547, in einsum(indices, like, optimize, *operands)
    501 """Evaluates the Einstein summation convention on the operands.
    502 
    503 Args:
   (...)
    544 array([ 30,  80, 130, 180, 230])
    545 """
    546 if like is None:
--> 547     like = get_interface(*operands)
    548 operands = np.coerce(operands, like=like)
    549 if optimize is None or like == "torch":
    550     # torch einsum doesn't support the optimize keyword argument

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/math/utils.py:221, in get_interface(*values)
    217 interfaces = {_get_interface_of_single_tensor(v) for v in values}
    219 if len(interfaces - {"numpy", "scipy", "autograd"}) > 1:
    220     # contains multiple non-autograd interfaces
--> 221     raise ValueError("Tensors contain mixed types; cannot determine dispatch library")
    223 non_numpy_scipy_interfaces = set(interfaces) - {"numpy", "scipy"}
    225 if len(non_numpy_scipy_interfaces) > 1:
    226     # contains autograd and another interface

ValueError: Tensors contain mixed types; cannot determine dispatch library
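
The error comes from the mixed-framework check in `get_interface`, visible in the last frame of the traceback. The function below is a simplified sketch of that check, not PennyLane's implementation. It takes one framework name per operand instead of real tensors, reproduces only the lines shown in the traceback, and uses an assumed fallback for everything else. In the failing run, the two `einsum` operands (gate matrix and state) apparently belong to different frameworks, Torch and TensorFlow, so the check raises.

```python
def resolve_interface(interfaces):
    """Simplified sketch of the mixed-type check in pennylane.math.get_interface.

    `interfaces` holds one framework name per operand, e.g. "torch" or "tensorflow".
    """
    unique = set(interfaces)
    frameworks = unique - {"numpy", "scipy", "autograd"}
    if len(frameworks) > 1:
        # e.g. a Torch operand combined with a TensorFlow operand
        raise ValueError("Tensors contain mixed types; cannot determine dispatch library")
    if frameworks:
        return frameworks.pop()
    # Assumed fallback: prefer autograd over plain NumPy/SciPy.
    if "autograd" in unique:
        return "autograd"
    return "numpy"


print(resolve_interface(["torch", "numpy"]))  # torch
try:
    resolve_interface(["torch", "tensorflow"])
except ValueError as err:
    print(err)  # Tensors contain mixed types; cannot determine dispatch library
```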

System information

Name: PennyLane
Version: 0.36.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-qiskit, PennyLane_Lightning

Platform info:           macOS-14.5-arm64-arm-64bit
Python version:          3.11.8
Numpy version:           1.26.4
Scipy version:           1.13.0
Installed devices:
- default.clifford (PennyLane-0.36.0)
- default.gaussian (PennyLane-0.36.0)
- default.mixed (PennyLane-0.36.0)
- default.qubit (PennyLane-0.36.0)
- default.qubit.autograd (PennyLane-0.36.0)
- default.qubit.jax (PennyLane-0.36.0)
- default.qubit.legacy (PennyLane-0.36.0)
- default.qubit.tf (PennyLane-0.36.0)
- default.qubit.torch (PennyLane-0.36.0)
- default.qutrit (PennyLane-0.36.0)
- default.qutrit.mixed (PennyLane-0.36.0)
- null.qubit (PennyLane-0.36.0)
- qiskit.aer (PennyLane-qiskit-0.36.0)
- qiskit.basicaer (PennyLane-qiskit-0.36.0)
- qiskit.basicsim (PennyLane-qiskit-0.36.0)
- qiskit.ibmq (PennyLane-qiskit-0.36.0)
- qiskit.ibmq.circuit_runner (PennyLane-qiskit-0.36.0)
- qiskit.ibmq.sampler (PennyLane-qiskit-0.36.0)
- qiskit.remote (PennyLane-qiskit-0.36.0)
- lightning.qubit (PennyLane_Lightning-0.36.0)

albi3ro commented 6 months ago

This is due to:

https://github.com/PennyLaneAI/pennylane/blob/0dcba4495af81890f8d76e6a5715f6994505f560/pennylane/qnn/keras.py#L334

where it sets the QNode's interface to "tf". We could remove this line and replace it with a validation check that ensures the interface is either "auto" or "tf".
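
A minimal sketch of the proposed validation, using a hypothetical `FakeQNode` stand-in and a hypothetical `ValidatingKerasLayer` rather than PennyLane's real classes. The accepted names follow the comment above ("auto" or "tf"); the real `KerasLayer` may need to accept further TensorFlow aliases.

```python
class FakeQNode:
    """Hypothetical stand-in for a QNode: it only tracks its interface."""

    def __init__(self, interface="auto"):
        self.interface = interface


# Names taken from the maintainer's suggestion; other aliases are not handled here.
ALLOWED_KERAS_INTERFACES = frozenset({"auto", "tf"})


class ValidatingKerasLayer:
    """Hypothetical wrapper that checks the interface instead of overwriting it."""

    def __init__(self, qnode):
        if qnode.interface not in ALLOWED_KERAS_INTERFACES:
            raise ValueError(
                f"Invalid interface {qnode.interface!r} for KerasLayer; "
                "expected 'auto' or 'tf'."
            )
        # The QNode is stored but never modified.
        self.qnode = qnode


circuit = FakeQNode()  # interface "auto"
layer = ValidatingKerasLayer(circuit)
print(circuit.interface)  # auto (unchanged)

try:
    ValidatingKerasLayer(FakeQNode(interface="torch"))
except ValueError as err:
    print(err)  # Invalid interface 'torch' for KerasLayer; expected 'auto' or 'tf'.
```

Leaving the interface as "auto" should still work for TensorFlow inputs, because an "auto" QNode resolves its interface from its arguments on each call. The traceback shows `QNode.__call__` checking `old_interface == "auto"` in a `finally` block, which appears to restore it afterwards. Since nothing is written back to the QNode, a `TorchLayer` built from the same QNode would be unaffected.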