PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] Metric tensor with only some of multiple QNode argument does not work in Torch #1991

Open dwierichs opened 2 years ago

dwierichs commented 2 years ago

Expected behavior

In the following QNode, we have two arguments:

import pennylane as qml
import torch

dev = qml.device("default.qubit", wires=4)

@qml.qnode(dev, interface="torch")
def circuit(x, y):
    qml.Rot(*x[0], wires=0)
    qml.Rot(*x[1], wires=1)
    qml.Rot(*x[2], wires=2)
    qml.CNOT(wires=[0, 1])
    qml.CNOT(wires=[1, 2])
    qml.CNOT(wires=[2, 0])
    qml.RY(y[0], wires=0)
    qml.RY(y[1], wires=1)
    qml.RY(y[0], wires=2)
    return qml.expval(qml.PauliZ(0)@qml.PauliZ(2))

x = torch.tensor([[0.2, 0.4, -0.1], [-2.1, 0.5, -0.2], [0.1, 0.7, -0.6]], dtype=torch.float64)
y = torch.tensor([1.3, 0.2], requires_grad=True, dtype=torch.float64)

Since only y is trainable, the expected output of metric_tensor is a single tensor: the metric tensor with respect to y, namely

>>> qml.metric_tensor(circuit)(x, y)
tensor([[ 0.2550, -0.0709],
        [-0.0709,  0.2495]], dtype=torch.float64)

Actual behavior

The computation raises an error (see below). As far as I can tell, this is because in Torch the classical Jacobian is computed with respect to all QNode arguments. Without argnum as a keyword argument, the QNode wrapper does not know which of the classical Jacobians to contract with the tape metric tensor(s). Another reason to implement #1880 soonish.

Indeed, when setting hybrid=False in metric_tensor, we get the correct metric tensor for the three gates that are controlled by entries of the QNode argument y:

>>> qml.metric_tensor(circuit, hybrid=False)(x, y)
tensor([[ 0.2498,  0.0032, -0.1224],
        [ 0.0032,  0.2495, -0.0741],
        [-0.1224, -0.0741,  0.2500]], dtype=torch.float64)

Additional information

No response

Source code

No response

Tracebacks

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_37953/3753152100.py in <module>
     19 # y = pnp.array([1.3, 0.2], requires_grad=True)
     20 # print(qml.adjoint_metric_tensor(circuit)(y))
---> 21 print(qml.metric_tensor(circuit)(x, y))
     22 # circuit(x, y)
     23 # circuit.qtape.get_parameters()

~/venvs/xanadu/lib/python3.8/site-packages/PennyLane-0.20.0.dev0-py3.8.egg/pennylane/transforms/metric_tensor.py in wrapper(*args, **kwargs)
    333                 for c in cjac:
    334                     if c is not None:
--> 335                         _mt = qml.math.tensordot(mt, c, axes=[[-1], [0]])
    336                         _mt = qml.math.tensordot(c, _mt, axes=[[0], [0]])
    337                         metric_tensors.append(_mt)

~/venvs/xanadu/lib/python3.8/site-packages/PennyLane-0.20.0.dev0-py3.8.egg/pennylane/math/multi_dispatch.py in tensordot(tensor1, tensor2, axes)
    280     """
    281     interface = _multi_dispatch([tensor1, tensor2])
--> 282     return np.tensordot(tensor1, tensor2, axes=axes, like=interface)
    283 
    284 

~/venvs/xanadu/lib/python3.8/site-packages/autoray/autoray.py in do(fn, like, *args, **kwargs)
     82         backend = infer_backend(like)
     83 
---> 84     return get_lib_fn(backend, fn)(*args, **kwargs)
     85 
     86 

~/venvs/xanadu/lib/python3.8/site-packages/PennyLane-0.20.0.dev0-py3.8.egg/pennylane/math/single_dispatch.py in _tensordot_torch(tensor1, tensor2, axes)
    461     if not semantic_version.match(">=1.10.0", torch.__version__) and axes == 0:
    462         return torch.outer(tensor1, tensor2)
--> 463     return torch.tensordot(tensor1, tensor2, axes)
    464 
    465 

~/venvs/xanadu/lib/python3.8/site-packages/torch/functional.py in tensordot(a, b, dims, out)
    930 
    931     if out is None:
--> 932         return _VF.tensordot(a, b, dims_a, dims_b)  # type: ignore[attr-defined]
    933     else:
    934         return _VF.tensordot(a, b, dims_a, dims_b, out=out)  # type: ignore[attr-defined]

RuntimeError: contracted dimensions need to match, but first has size 3 in dim -1 and second has size 12 in dim 0
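The sizes in this error can be reproduced outside PennyLane with the sketch below. It rests on an assumption (my reading of the traceback, not confirmed from the PennyLane source): the wrapper's classical Jacobian has one row per gate parameter, 12 in total (9 from the Rot gates, 3 from the RY gates), and is computed for both x and y, while the tape metric tensor only covers the 3 trainable RY parameters. `gate_params` is a hypothetical stand-in for the QNode's classical processing.

```python
import torch

def gate_params(x, y):
    """Hypothetical stand-in: map the QNode arguments to the 12 gate parameters."""
    return torch.cat([x.reshape(-1), torch.stack([y[0], y[1], y[0]])])

x = torch.tensor([[0.2, 0.4, -0.1], [-2.1, 0.5, -0.2], [0.1, 0.7, -0.6]], dtype=torch.float64)
y = torch.tensor([1.3, 0.2], dtype=torch.float64)

# Jacobian with respect to every argument, as when no argnum is given.
cjac = torch.autograd.functional.jacobian(gate_params, (x, y))
print([c.shape for c in cjac])  # [torch.Size([12, 3, 3]), torch.Size([12, 2])]

# Placeholder tape metric tensor: only its shape (3 trainable parameters) matters here.
mt_tape = torch.eye(3, dtype=torch.float64)
try:
    # Size 3 (last axis of mt_tape) vs size 12 (axis 0 of the Jacobian): raises.
    torch.tensordot(mt_tape, cjac[1], dims=([1], [0]))
except RuntimeError as err:
    print("RuntimeError:", err)

# Keeping only y's Jacobian and the rows of the trainable RY parameters (9-11)
# gives compatible shapes.
c_y = cjac[1][9:12]  # shape (3, 2)
mt_y = torch.tensordot(c_y, torch.tensordot(mt_tape, c_y, dims=([1], [0])), dims=([0], [0]))
print(mt_y.shape)  # torch.Size([2, 2])
```

Selecting y's Jacobian explicitly is roughly the information an argnum keyword would supply to the wrapper.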

System information

Name: PennyLane
Version: 0.20.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: None
Author-email: None
License: Apache License 2.0
Location: /home/david/venvs/xanadu/lib/python3.8/site-packages/PennyLane-0.20.0.dev0-py3.8.egg
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, scipy, semantic-version, toml
Required-by: pennylane-qulacs, PennyLane-Qchem, PennyLane-Lightning
Platform info:           Linux-5.4.0-91-generic-x86_64-with-glibc2.29
Python version:          3.8.10
Numpy version:           1.21.0
Scipy version:           1.7.1
Installed devices:
- qulacs.simulator (pennylane-qulacs-0.16.0)
- lightning.qubit (PennyLane-Lightning-0.18.0)
- default.gaussian (PennyLane-0.20.0.dev0)
- default.mixed (PennyLane-0.20.0.dev0)
- default.qubit (PennyLane-0.20.0.dev0)
- default.qubit.autograd (PennyLane-0.20.0.dev0)
- default.qubit.jax (PennyLane-0.20.0.dev0)
- default.qubit.tf (PennyLane-0.20.0.dev0)
- default.qubit.torch (PennyLane-0.20.0.dev0)
- default.tensor (PennyLane-0.20.0.dev0)
- default.tensor.tf (PennyLane-0.20.0.dev0)

josh146 commented 2 years ago

@dwierichs thanks for finding this! I'm just getting caught up after vacation, and I'm a little out of the loop, so just two quick questions:

dwierichs commented 2 years ago

Hi @josh146

  1. Which fix do you mean here? I think we don't have one for JAX (as the forward passes don't work without argnum...)
  2. Yes, I think this would be the cleanest and simplest way to get a consistent behaviour. I don't see much reason to go for any other hotfix.
josh146 commented 2 years ago

Which fix do you mean here? I think we don't have one for JAX (as the forward passes don't work without argnum...)

Oh, I had a recollection that classical_jacobian was patched for JAX to only support non-trainable parameters? I could be wrong though!

dwierichs commented 2 years ago

Ah, yes, kind of :D By now this has been patched again, and both trainable_only=True and trainable_only=False are supported. In any case, I feel like adding argnum is the better way forward here, because it avoids a patchwork treatment of the metric tensor behaviour across interfaces. On the side, I noticed that the same bug exists for JAX.

josh146 commented 2 years ago

On the side, I noticed that the same bug exists for JAX.

oof 😬

I am in agreement, let's prioritize argnum.

dwierichs commented 1 year ago

This bug is fixed via argnum for JAX, but not for Torch.