PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

How to support Prod observables in custom devices? #5707

Open cvjjm opened 4 months ago

cvjjm commented 4 months ago

Feature details

This is not really a feature request, but rather a report of an unexpected breaking change for device developers that I do not know how to fix...

Up until v0.35.1, when a custom device was asked to compute the expectation value of a tensor product of observables, its .expval(self, observable, wires, par) method would simply receive the observable names as a list in the observable argument. I understand that this was not particularly elegant, but it allowed the simple implementation in the example code below, which runs fine with v0.35.1 (up to a strange difference in the shape of the returned expectation value that cannot be fixed even by explicitly returning a float from expval()).

import pennylane as qml
from pennylane import Device
from pennylane import numpy as np

class MyDevice(Device):
    author = "Christian Gogolin"
    pennylane_requires = ">=0.29.1"
    name = "MyDevice"
    short_name = "my.device"
    version = "0.1.0"

    _capabilities = {
        "tensor_observables": True,
    }

    def __init__(self, wires, *args, **kwargs):
        if wires != 2:
            raise NotImplementedError()
        self._state = None
        super().__init__(wires, *args, **kwargs)

    def apply(self, operation, wires, par):
        raise NotImplementedError(f"operation={operation}")

    def expval(self, observable, wires, par):
        # Legacy tensor-product convention: the observable names arrive as a
        # list of strings, with `wires` the matching list of wire groups.
        if isinstance(observable, list):
            observables = observable
            for observable in observables:
                if observable != "PauliZ":
                    raise NotImplementedError()
            return np.prod([self.expval(observable, wire, par) for observable, wire in zip(observables, wires)])

        if observable != "PauliZ":
            raise NotImplementedError()
        # Single observable: <psi| Z(wires) |psi> on the fixed two-wire state.
        return float(np.real(np.dot(self._state.conj().T, qml.matrix(qml.PauliZ(wires), wire_order=qml.wires.Wires([0, 1])).dot(self._state))).item())

    @property
    def observables(self):
        return ["PauliZ"]

    @property
    def operations(self):
        return []

    def reset(self):
        self._state = np.array([1., 0., 0., 0.])

results = list()
for dev in [qml.device("default.qubit", wires=2), MyDevice(wires=2)]:

    @qml.qnode(dev)
    def qnode():
        return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

    print(dev)
    print(qml.draw(qnode)())
    res = qnode()
    print("res:", res, type(res))
    results.append(res)

assert np.allclose(*results)

With 0.36.0 I get:

pennylane._device.DeviceError: Observable Prod not supported on device my.device

First, I don't understand why my device is asked to evaluate a Prod observable even though my QNode contains a Tensor observable, which probably could have been handled correctly by the code in _device.py:995.
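
For what it's worth, the rejection can be reproduced in miniature without PennyLane. My understanding is that the legacy interface decides support by looking the observable's name up in the set of names the device advertises; the following is an illustrative, PennyLane-free sketch of that check (hypothetical classes, not PennyLane's actual code):

```python
# Illustrative sketch only (not PennyLane's actual implementation): the
# legacy device interface checks support by looking the observable's *name*
# up in the set of names the device advertises. A Prod has the single name
# "Prod", so a device that only advertises "PauliZ" rejects it outright.

class DeviceError(Exception):
    pass

class LegacyDeviceSketch:
    observables = {"PauliZ"}  # names this sketch device claims to support

    def check_observable(self, name):
        if name not in self.observables:
            raise DeviceError(f"Observable {name} not supported on device")

dev = LegacyDeviceSketch()
dev.check_observable("PauliZ")  # passes silently
try:
    dev.check_observable("Prod")
except DeviceError as err:
    print(err)  # Observable Prod not supported on device
```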

Leaving this aside (tensor products are also products....): How am I supposed to modify my device code to support (tensor) product observables? Thanks!

Implementation

No response

How important would you say this feature is?

3: Very important! Blocking work.

Additional information

Name: PennyLane
Version: 0.36.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /fs/home/cvjjm/.conda/envs/hfak/lib/python3.9/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane_Lightning

Platform info: Linux-4.18.0-372.32.1.el8_6.x86_64-x86_64-with-glibc2.28
Python version: 3.9.0
Numpy version: 1.23.5
Scipy version: 1.13.0
Installed devices:

albi3ro commented 4 months ago

I have a fairly hacky patch, but it might be sufficient for your purposes. A better fix might involve copying and updating Device.execute to change the call signature of expval.

class ProdPatch(qml.operation.Observable):

    def __init__(self, prod_op):
        self.prod_op = prod_op

    @property
    def name(self):
        return [op.name for op in self.prod_op]

    @property
    def parameters(self):
        return self.prod_op.parameters

    @property
    def wires(self):
        return self.prod_op.wires

class MyDevice(Device):
    author = "Christian Gogolin"
    pennylane_requires = ">=0.29.1"
    name = "MyDevice"
    short_name = "my.device"
    version = "0.1.0"

    _capabilities = {
        "tensor_observables": True,
    }

    @property
    def observables(self):
        return ["PauliZ", "Prod"]

    def supports_observable(self, observable):
        if isinstance(observable, list):
            return all(self.supports_observable(ob) for ob in observable)
        return super().supports_observable(observable)

    def execute(self, queue, observables, parameters=None, **kwargs):
        # Re-wrap any Prod observables so the legacy expval() branch sees the
        # old Tensor-style interface (name as a list of strings).
        new_measurements = []
        for mp in observables:
            if mp.obs and isinstance(mp.obs, qml.ops.Prod):
                new_measurements.append(type(mp)(obs=ProdPatch(mp.obs)))
            else:
                new_measurements.append(mp)
        return super().execute(queue, new_measurements, parameters=parameters, **kwargs)

    def __init__(self, wires, *args, **kwargs):
        if wires != 2:
            raise NotImplementedError()
        self._state = None
        super().__init__(wires, *args, **kwargs)

    def apply(self, operation, wires, par):
        raise NotImplementedError(f"operation={operation}")

    def expval(self, observable, wires, par):
        if isinstance(observable, list):
            observables = observable
            for observable in observables:
                if observable != "PauliZ":
                    raise NotImplementedError()
            return np.prod([self.expval(observable, wire, par) for observable, wire in zip(observables, wires)])

        if observable != "PauliZ":
            raise NotImplementedError()
        return float(np.real(np.dot(self._state.conj().T, qml.matrix(qml.PauliZ(wires), wire_order=qml.wires.Wires([0, 1])).dot(self._state))).item())

    @property
    def operations(self):
        return []

    def reset(self):
        self._state = np.array([1., 0., 0., 0.])
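
The core idea of the ProdPatch above can be seen in isolation. Here is a PennyLane-free sketch (hypothetical stand-in classes, assuming the legacy expval branch keys off a list-valued name as in the device code above):

```python
# PennyLane-free sketch of the ProdPatch idea: wrap a product's factors so
# the wrapper exposes the legacy Tensor convention that the old
# expval(observable, wires, par) branch expects: .name as a list of factor
# names and .wires as the concatenation of the factors' wires.

class FakeOp:
    """Stand-in for a single-wire observable (hypothetical, for illustration)."""
    def __init__(self, name, wires):
        self.name = name
        self.wires = wires

class ProdWrapper:
    def __init__(self, operands):
        self.operands = operands

    @property
    def name(self):   # a list of strings, mimicking the legacy Tensor.name
        return [op.name for op in self.operands]

    @property
    def wires(self):  # concatenation of the factors' wires, in order
        return [w for op in self.operands for w in op.wires]

prod = ProdWrapper([FakeOp("PauliZ", [0]), FakeOp("PauliZ", [1])])
print(prod.name)   # ['PauliZ', 'PauliZ']
print(prod.wires)  # [0, 1]
```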

I think this case actually helps demonstrate why we went through all the effort both to change operator arithmetic and to move to a new device interface.

1) Tensor.name broke the interface set out in Operator. Every other operator has a string name, but Tensor's name is a list of strings. This exception can be difficult to anticipate and account for.

2) A name or type by itself can be insufficient to determine whether something is supported. For example, consider a Hamiltonian containing a Hermitian: the device's set of supported names might include Hamiltonian, but the device still cannot handle a Hamiltonian whose terms include a Hermitian.

3) Legacy devices (qml.Device) rely on behaviour hardcoded into an "abstract base class", which makes them rather difficult to extend and customize.
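
Points 1) and 2) can be made concrete with a small, self-contained sketch (hypothetical classes, not PennyLane's actual ones): a subclass that returns a list where the base class promises a string, and a name-set support check that a container observable passes even though one of its terms is unsupported.

```python
# Illustration of points 1) and 2) with hypothetical stand-in classes.

class Op:
    name = "Op"            # contract: name is a string

class TensorLike(Op):
    def __init__(self, factors):
        self.factors = factors

    @property
    def name(self):        # breaks the contract: list instead of str
        return [f.name for f in self.factors]

class PauliZ(Op):
    name = "PauliZ"

class Hermitian(Op):
    name = "Hermitian"

class HamiltonianLike(Op):
    name = "Hamiltonian"
    def __init__(self, terms):
        self.terms = terms

supported = {"PauliZ", "Hamiltonian"}

# 1) code that treats .name as a string silently misbehaves for one class:
t = TensorLike([PauliZ(), PauliZ()])
print(isinstance(t.name, str))  # False, unlike every other operator

# 2) the top-level name passes the check even though a term is unsupported:
h = HamiltonianLike([PauliZ(), Hermitian()])
print(h.name in supported)                              # True
print(all(term.name in supported for term in h.terms))  # False
```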

We can add some patches to qml.Device if that would help, but those wouldn't be released until our next release in July.

cvjjm commented 4 months ago

Thanks for going through the trouble of proposing a workaround!!!

I agree that it is rather ugly, but it may save me from wasting too much time on code that I plan to port to the new device API anyway...

I reported this here just because it was a bit annoying to run into breaking changes in a part of the codebase that is deprecated, and which I had therefore hoped would not change; but I understand that this is due to the coupling between the legacy devices and the new operator arithmetic.

CatalinaAlbornoz commented 4 months ago

Thank you for reporting this here @cvjjm! It does help us a lot to receive your feedback, both the good and bad experiences.