PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

Problems with the parameter shift hessian #1258

Closed: cvjjm closed this issue 3 years ago

cvjjm commented 3 years ago

The parameter-shift Hessian does not seem to work reliably in all cases. It is difficult to come up with a truly minimal example because I do not yet fully understand what is going wrong, but I managed to condense it down to the following "minimal" example, which demonstrates the main bug and exposes a few minor problems. Hopefully someone (e.g., @josh146) can take it from there and get to the bottom of this.

The main problem is that the Autograd Hessian of a function involving a parameter-shift QNode agrees neither with a finite-difference approximation of the Hessian nor with the Hessian obtained using the "best" differentiation method (the latter two agree to the expected accuracy).
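For orientation, the two quantities being compared can be written out explicitly. The first line is the forward-difference formula that `finite_diff_hessian` in the example below implements (h = `shift`, e_i = the i-th unit vector). The second is the standard two-term parameter-shift Hessian, which is exact only when each parameterized gate's generator has the two eigenvalues ±1/2 (as for `RY`). This is a sketch of the textbook formulas, not a statement of exactly what PennyLane computes internally.

```latex
% Forward difference, as in finite_diff_hessian (h = shift, e_i = i-th unit vector)
H_{ij} \approx \frac{E(\theta + h e_i + h e_j) - E(\theta + h e_i) - E(\theta + h e_j) + E(\theta)}{h^2}

% Two-term parameter-shift Hessian (exact for generators with eigenvalues \pm 1/2)
H_{ij} = \frac{1}{4}\Big[E\big(\theta + \tfrac{\pi}{2}(e_i + e_j)\big) - E\big(\theta + \tfrac{\pi}{2}(e_i - e_j)\big)
       - E\big(\theta - \tfrac{\pi}{2}(e_i - e_j)\big) + E\big(\theta - \tfrac{\pi}{2}(e_i + e_j)\big)\Big]
```

For i = j the second formula reduces to [E(θ + π e_i) − 2E(θ) + E(θ − π e_i)]/4.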

There are a few additional minor issues that can be triggered by modifying the code according to the comments. The first is an IndexError when a gate is commented out of the ansatz, which makes the result independent of one of the parameters. The second is an error that occurs when trying to draw the circuit during the Hessian computation.

Here is the example:

import pennylane as qml
import numpy as np
import copy
from pennylane.utils import _flatten, unflatten

def finite_diff_hessian(energy, params, shift=1e-4, **kwargs):
    """Computes a finite difference approximation to the Hessian"""
    flat_params = copy.deepcopy(np.array(list(_flatten(params))))
    hessian = np.zeros((len(flat_params), len(flat_params)))
    for idx1 in range(len(flat_params)):
        params_shift1 = np.array([shift if idx==idx1 else 0. for idx in range(len(flat_params))])
        for idx2 in range(len(flat_params)):
            params_shift2 = np.array([shift if idx==idx2 else 0. for idx in range(len(flat_params))])

            energy11 = energy(unflatten(flat_params+params_shift1+params_shift2, params))
            energy01 = energy(unflatten(flat_params+params_shift1, params))
            energy10 = energy(unflatten(flat_params+params_shift2, params))
            energy00 = energy(params)
            hessian[idx1, idx2] = (energy11 - energy01 - energy10 + energy00)/(shift**2)
    #hessian = 0.5*(hessian + hessian.T) # symmetrization is not necessary
    return hessian

nwires = 8
wires = range(nwires)
dev = qml.device('default.qubit', wires=nwires)

def make_cost(diff_method='parameter-shift'):
    @qml.template
    def ansatz(params, wires=wires):
        qml.RY(1.5, wires=[6])
        qml.CNOT(wires=[6, 4])
        qml.CNOT(wires=[6, 7])
        qml.CNOT(wires=[4, 5])
        qml.SWAP(wires=[2, 4])
        qml.SWAP(wires=[3, 5])
        qml.CRY(params[0], wires=[2, 4])
        qml.SWAP(wires=[0, 2])
        qml.SWAP(wires=[1, 3])
        qml.CRY(params[1], wires=[0, 2]) # commenting this produces: "IndexError: too many indices for array: array is 0-dimensional, but 1 were indexed"

    def run_circuit(params, ob, wires=wires):
        @qml.qnode(dev, interface="autograd", diff_method=diff_method) #without diff_method='parameter-shift' this works!
        def circ(params, wires=wires):
            ansatz(params, wires=wires)
            return qml.expval(ob)
        ret = circ(params, wires=wires)
        #print(circ.draw()) # produces "TypeError: unsupported format string passed to ArrayBox.__format__" Maybe another bug?
        return ret

    def circuit(params, wires=wires):
        return run_circuit(params, qml.PauliX(0) @ qml.PauliX(1) @ qml.PauliZ(2) \
                           @ qml.PauliZ(3) @ qml.PauliZ(4) @ qml.PauliZ(5) @ qml.PauliX(6) @ qml.PauliX(7), wires=wires)

    return circuit

def energy(params):
    return make_cost(diff_method='parameter-shift')(params)

def energy2(params):
    return make_cost(diff_method='best')(params)

params = np.array([-0.02, 0.4])

fd_hessian = finite_diff_hessian(energy, params)
pl_hessian = qml.jacobian(qml.grad(energy))(params)
pl_hessian2 = qml.jacobian(qml.grad(energy2))(params)

assert np.allclose(fd_hessian.ravel(), pl_hessian2.ravel(), rtol=0, atol=1e-5), "best method Hessian did not match.\nfd_hessian:\n{}\npl_hessian2:\n{}\ndifference:\n{}".format(fd_hessian, pl_hessian2, fd_hessian.ravel() - pl_hessian2.ravel())
print("best method Hessians did match")
assert np.allclose(fd_hessian.ravel(), pl_hessian.ravel(), rtol=0, atol=1e-5), "parameter-shift Hessian did not match.\nfd_hessian:\n{}\npl_hessian:\n{}\ndifference:\n{}".format(fd_hessian, pl_hessian, fd_hessian.ravel() - pl_hessian.ravel())
print("parameter-shift Hessians did match")

The expected output is

best method Hessians did match
parameter-shift Hessians did match

instead I am getting

best method Hessians did match
AssertionError                            Traceback (most recent call last)
<ipython-input-1-d59f46ef7659> in <module>
     70 assert np.allclose(fd_hessian.ravel(), pl_hessian2.ravel(), rtol=0, atol=1e-5), "best method Hessian did not match.\nfd_hessian:\n{}\npl_hessian:\n{}\ndifference:\n{}".format(fd_hessian, pl_hessian, fd_hessian.ravel() - pl_hessian.ravel())
     71 print("best method Hessians did match")
---> 72 assert np.allclose(fd_hessian.ravel(), pl_hessian.ravel(), rtol=0, atol=1e-5), "parameter-shift Hessian did not match.\nfd_hessian:\n{}\npl_hessian:\n{}\ndifference:\n{}".format(fd_hessian, pl_hessian, fd_hessian.ravel() - pl_hessian.ravel())
     73 print("parameter-shift Hessians did match")

AssertionError: parameter-shift Hessian did not match.
fd_hessian:
[[-0.24439083 -0.00049429]
 [-0.00049429 -0.24438817]]
pl_hessian:
[[-4.83893333e-01 -4.95420897e-04]
 [-4.95420897e-04 -5.87862185e-01]]
difference:
[2.39502503e-01 1.12740152e-06 1.12740152e-06 3.43474020e-01]

Maybe this has to do with the non standard grad recipe of CRY?

josh146 commented 3 years ago

Hi @cvjjm -- good catch.

> Maybe this has to do with the non standard grad recipe of CRY?

Yes, this is most likely it; the parameter-shift rule for Hessians was adapted from "Estimating the gradient and higher-order derivatives on quantum hardware", which predates the 4-term rule!
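The effect of using the wrong shift rule can be reproduced without PennyLane. The following is a minimal pure-NumPy sketch; the toy circuit (a Hadamard on the control followed by CRY(θ)), the observable X ⊗ I and all helper names are illustrative assumptions, not code from PennyLane. For this circuit the expectation value is cos(θ/2), a frequency that the standard two-term rule is not built for. Nesting the two-term rule overestimates the second derivative, while nesting the published four-term rule for controlled rotations (shifts ±π/2 and ±3π/2) recovers the exact value.

```python
import numpy as np

# Two-qubit toy model: wire 0 is the control, wire 1 the target.
# Basis index = 2 * q0 + q1 (np.kron ordering).
H = np.array([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2)
X = np.array([[0.0, 1.0], [1.0, 0.0]])
I2 = np.eye(2)
OBS = np.kron(X, I2)  # X on the control, identity on the target


def ry(theta):
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])


def cry(theta):
    """Controlled-RY: identity when the control is |0>, RY(theta) when it is |1>."""
    u = np.eye(4)
    u[2:, 2:] = ry(theta)
    return u


def f(theta):
    """<X (x) I> after H on the control and CRY(theta); analytically cos(theta / 2)."""
    psi0 = np.kron(H @ np.array([1.0, 0.0]), np.array([1.0, 0.0]))
    psi = cry(theta) @ psi0
    return psi @ OBS @ psi  # all amplitudes are real, so no conjugation is needed


def two_term(fun):
    """Standard shift rule; exact only for generators with eigenvalues +-1/2."""
    return lambda t: (fun(t + np.pi / 2) - fun(t - np.pi / 2)) / 2


# Four-term rule coefficients and shifts for controlled rotations.
C1 = (np.sqrt(2) + 1) / (4 * np.sqrt(2))
C2 = (np.sqrt(2) - 1) / (4 * np.sqrt(2))
A, B = np.pi / 2, 3 * np.pi / 2


def four_term(fun):
    """Four-term shift rule, exact for the frequencies 1/2 and 1 of a CRY circuit."""
    return lambda t: (C1 * (fun(t + A) - fun(t - A))
                      - C2 * (fun(t + B) - fun(t - B)))


theta = 0.4
exact = -np.cos(theta / 2) / 4           # second derivative of cos(theta / 2)
h = 1e-3
finite_diff = (f(theta + h) - 2 * f(theta) + f(theta - h)) / h**2
naive = two_term(two_term(f))(theta)     # two-term rule applied twice
four = four_term(four_term(f))(theta)    # four-term rule applied twice

print(f"exact        {exact:.6f}")
print(f"finite diff  {finite_diff:.6f}")
print(f"two-term     {naive:.6f}")       # twice the exact value for this toy model
print(f"four-term    {four:.6f}")
```

In this toy model the nested two-term result is exactly twice the true second derivative, because the frequency-1/2 component is shifted by the wrong amount. The real circuit in the issue mixes several frequencies, so the discrepancy there need not be a clean factor.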

We are about to release a bugfix version, 0.14.1, which includes fixes for #1242. Perhaps the best plan is therefore:

cvjjm commented 3 years ago

Replacing qml.CRY with qml.CRY.decomposition in the above code works around the problem, which corroborates the suspicion that the non-standard shift rules are the culprit. It would be good to add a test that covers this case.
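Why the decomposition workaround helps can be checked in a small pure-NumPy sketch. It assumes the common decomposition CRY(θ) = CNOT · RY(−θ/2) · CNOT · RY(θ/2) on the target (assumed, not verified here, to match what qml.CRY.decomposition returns), and reuses an illustrative toy circuit. Each RY angle has the usual ±1/2 generator, so the two-term rule is exact per angle; the Hessian in θ then follows from the chain rule through the classical map θ → (θ/2, −θ/2). All names below are hypothetical helpers.

```python
import numpy as np

I2 = np.eye(2)
H = np.array([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2)
X = np.array([[0.0, 1.0], [1.0, 0.0]])
# CNOT with control wire 0 and target wire 1 (basis index = 2 * q0 + q1).
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=float)
OBS = np.kron(X, I2)
PSI0 = np.kron(H @ np.array([1.0, 0.0]), np.array([1.0, 0.0]))


def ry(theta):
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])


def cry(theta):
    u = np.eye(4)
    u[2:, 2:] = ry(theta)
    return u


def decomposed(a, b):
    """RY(a) on target, CNOT, RY(b) on target, CNOT (matrices multiply right to left)."""
    return CNOT @ np.kron(I2, ry(b)) @ CNOT @ np.kron(I2, ry(a))


def F(x):
    """Expectation of the decomposed circuit as a function of both RY angles."""
    psi = decomposed(x[0], x[1]) @ PSI0
    return psi @ OBS @ psi


def shift_derivative(fun, i):
    """Two-term shift rule in argument i; exact since each RY angle has frequency 1."""
    def g(x):
        e = np.zeros_like(x)
        e[i] = np.pi / 2
        return (fun(x + e) - fun(x - e)) / 2
    return g


theta = 0.4
# The decomposition reproduces the CRY matrix exactly.
assert np.allclose(decomposed(theta / 2, -theta / 2), cry(theta))

# Hessian of F in the gate angles, then the chain rule through (theta/2, -theta/2).
x = np.array([theta / 2, -theta / 2])
hess_F = np.array([[shift_derivative(shift_derivative(F, j), i)(x) for j in range(2)]
                   for i in range(2)])
jac = np.array([0.5, -0.5])  # d(a, b) / d(theta); the map is linear
d2f = jac @ hess_F @ jac
print(d2f, -np.cos(theta / 2) / 4)  # the two numbers should agree
```

The price is the classical Jacobian: the shift rule now acts on the gate angles rather than on θ directly, which is the kind of extra bookkeeping that has to be handled when expanding the tape.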

Also please have a look at the other two issues I pointed out above.

josh146 commented 3 years ago

> It would be good to add a test that covers this case.

Definitely, will do this ASAP.

> Also please have a look at the other two issues I pointed out above.

Will investigate this further!

josh146 commented 3 years ago

Hi @cvjjm! I've had some time to explore this more deeply; here are my findings:

In the specific case of the controlled rotation gates, the implementation of the parameter-shift Hessian is hardcoded to follow the results from https://arxiv.org/abs/2008.06517. One short-term solution could be to expand the tape around the controlled rotation gates inside the parameter-shift Hessian logic. Unfortunately, the controlled rotation decompositions are not 1-1 and involve classical processing of the gate parameter. I attempted to implement this, but the complexity quickly spiraled. Likely the best solution is to:

I've implemented the better error handling in #1260, ready for a small bugfix release.

> # commenting this produces: "IndexError: too many indices for array: array is 0-dimensional, but 1 were indexed"

This was a great edge case find. It turns out that the vector-Hessian product that was being returned to Autograd had the wrong number of dimensions in the edge case where there were unused QNode parameters. I've fixed this in #1260.
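The general failure mode, an array silently losing a dimension and then being indexed, is easy to reproduce in plain NumPy. This sketch is illustrative only and does not reproduce PennyLane's internal code path; the (1, 1) block is an assumed stand-in for a Hessian with a single contributing parameter. `np.squeeze` drops every length-1 axis, so indexing the result raises the same kind of IndexError, while reshaping to the known parameter shape (or `np.atleast_1d`) keeps the dimension.

```python
import numpy as np

# Assumed example: a (1, 1) Hessian block, e.g. when only one parameter contributes.
block = np.array([[-0.25]])

collapsed = np.squeeze(block)   # removes *all* length-1 axes
print(collapsed.ndim)           # 0

try:
    collapsed[0]                # indexing a 0-d array fails
except IndexError as err:
    print("IndexError:", err)

# Shape-preserving alternatives: reshape to the known number of parameters,
# or promote to at least one dimension before indexing.
num_params = 1
kept = np.reshape(block, (num_params,))
print(kept.shape, kept[0])      # (1,) -0.25
print(np.atleast_1d(collapsed)[0])
```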

> #print(circ.draw()) # produces "TypeError: unsupported format string passed to ArrayBox.__format__" Maybe another bug?

I investigated this as well, but it turns out that the issue here is due to qnode.draw() being deprecated. If you use the new drawing function,

qml.draw(circ)(params)

this works as expected. Recently, Python changed how deprecation warnings are displayed for end-users, so we're working to make sure our deprecation warnings are fully accessible again in #1211!