PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

Gradient Issue in PennyLane with PyTorch: Handling sqrt at Zero for Quantum Parameters Fails #4845

Open anthonysmaldone opened 11 months ago

anthonysmaldone commented 11 months ago

Expected behavior

When post-processing the output of a quantum circuit wrapped in a torch layer, the gradient of the torch.sqrt() function at the point 0 should be defined by continuity rather than producing nan. According to the PyTorch docs:

If the function is defined, define the gradient at the current point by continuity (note that inf is possible here, for example for sqrt(0)). If multiple values are possible, pick one arbitrarily.
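
For reference, this continuity behavior is easy to check in plain PyTorch, independent of PennyLane (a minimal sketch):

import torch

# d(sqrt)/dx = 1/(2*sqrt(x)) -> inf at x = 0, defined by continuity
x = torch.tensor(0.0, requires_grad=True)
torch.sqrt(x).backward()
print(x.grad)  # tensor(inf)

The inf itself is expected; the problem reported below is that it turns into nan in the quantum weights.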

Actual behavior

While PyTorch correctly sets the gradients of the classical weights, the gradients of the quantum weights become nan.

Additional information

Here is an arbitrary PennyLane/PyTorch circuit that demonstrates the issue. The circuit measures the Pauli Z expectation value of a qubit placed in an equal superposition by a Hadamard gate (thus guaranteeing an expectation value of 0).

import torch
import pennylane as qml
from pennylane.qnn import TorchLayer
import numpy as np

def SimpleQuantumLayer():
    n_qubits = 2
    dev = qml.device("default.qubit", wires=n_qubits)

    # `inputs` is unused here but required by TorchLayer's calling convention
    def circuit(inputs, weights):
        qml.Hadamard(wires=0)             # equal superposition on wire 0 -> <Z> = 0
        qml.RY(weights[0], wires=1)       # trainable quantum weight
        return qml.expval(qml.PauliZ(0))

    qlayer = qml.QNode(circuit, dev, interface="torch")
    weight_shapes = {"weights": (1,)}

    return TorchLayer(qlayer, weight_shapes)

# Create the quantum layer
quantum_layer = SimpleQuantumLayer()

# Training loop
steps = 100
target = torch.tensor([np.cos(np.pi / 4)], dtype=torch.float32)

# Set up the optimizer
optimizer = torch.optim.Adam(quantum_layer.parameters(), lr=0.1)

for step in range(steps):
    optimizer.zero_grad()

    # Forward pass
    output = quantum_layer(torch.tensor([1.0], dtype=torch.float32))

    # take the square root of an output that is exactly 0 -- this is where
    # the backward pass goes wrong
    output = torch.sqrt(output)
    prediction = output

    # Compute the loss (mean squared error)
    loss = torch.nn.functional.mse_loss(prediction, target)

    # Backward pass
    loss.backward()

    # Update the weights
    optimizer.step()

    if step % 10 == 0:
        print(f"Step {step}: loss = {loss.item()}")

# Print the learned angle
print(f"Learned angle: {list(quantum_layer.parameters())[0].item()}")

This outputs:

Step 0: loss = 0.4999999701976776
Step 10: loss = nan
Step 20: loss = nan
Step 30: loss = nan
Step 40: loss = nan
Step 50: loss = nan
Step 60: loss = nan
Step 70: loss = nan
Step 80: loss = nan
Step 90: loss = nan
Learned angle: nan

Instead of taking the square root directly, applying torch.sqrt only to the strictly positive entries via a mask fixes the issue while producing identical values (a non-mutating alternative is sketched after the output below):

# apply sqrt only where the output is strictly positive; zeros pass through
mask = output > 0
output[mask] = torch.sqrt(output[mask])

Fixed circuit output:

Step 0: loss = 0.4999999701976776
Step 10: loss = 0.4999999701976776
Step 20: loss = 0.4999999701976776
Step 30: loss = 0.4999999701976776
Step 40: loss = 0.4999999701976776
Step 50: loss = 0.4999999701976776
Step 60: loss = 0.4999999701976776
Step 70: loss = 0.4999999701976776
Step 80: loss = 0.4999999701976776
Step 90: loss = 0.4999999701976776
Learned angle: 6.073381423950195
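
A non-mutating variant of the same workaround (my sketch; the epsilon is an arbitrary choice, not part of the original report):

# clamp's gradient is exactly 0 below its threshold, so the inf from
# sqrt's backward at 0 never enters the chain rule; exact zeros change
# only negligibly (sqrt(1e-12) = 1e-6)
output = torch.sqrt(torch.clamp(output, min=1e-12))

Note that a single torch.where(output > 0, torch.sqrt(output), output) would not help: autograd differentiates both branches, so the inf from sqrt at the zero entries is still multiplied by the mask's 0, giving nan.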

Source code

No response

Tracebacks

No response

System information

Name: PennyLane
Version: 0.33.1
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Lightning

Platform info:           Linux-5.15.120+-x86_64-with-glibc2.35
Python version:          3.10.12
Numpy version:           1.23.5
Scipy version:           1.11.3
Installed devices:

Existing GitHub issues

albi3ro commented 11 months ago

Thanks for opening this issue, @anthonysmaldone. We'll take a look and get back to you soon.

timmysilv commented 10 months ago

Hi @anthonysmaldone! I've been looking at this bug, and I must admit it's a strange one. I don't quite have the answer, but I've made some progress. torch offers a function called torch.autograd.set_detect_anomaly, and it pointed me in the right direction. The issue appears to arise when computing the gradient of a function that raises tensors to a power (which we do here when generating probabilities). With anomaly detection set to True, I get this error:

RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.

I've made a minimal example of what's happening without any PennyLane code, because the fundamental problem seems to lie in torch's own assumptions.

import torch
import numpy as np

torch.autograd.set_detect_anomaly(True)

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.floor(x[0]) ** 2  # some function that uses pow and will return 0

# Create the classical layer
classical_layer = Net()

# Training loop
steps = 100
target = torch.tensor(np.cos(np.pi / 4), dtype=torch.float32)

# Set up the optimizer (the tensor here is a dummy parameter; the nan shows
# up in the gradient of the layer's input, not of this parameter)
optimizer = torch.optim.Adam([torch.tensor([2.9348])], lr=0.1)

for step in range(steps):
    optimizer.zero_grad()
    output = classical_layer(torch.tensor([0.5], dtype=torch.float32, requires_grad=True))
    prediction = torch.sqrt(output)
    loss = torch.nn.functional.mse_loss(prediction, target)
    loss.backward()
    optimizer.step()
    if step % 10 == 0:
        print(f"Step {step}: loss = {loss.item()}")

That said, if you set anomaly detection to False in the above example, you won't get NaN values. This does suggest that there could be a fix in PennyLane to minimize the severity of this issue, but the root of it is torch having a hard time computing the gradient with PowBackward0. Some googling shows that people have had similar issues, but not quite the same afaik. Most people recommend making things slightly greater than zero - your masking does that in a more elegant way imo, so you've already got the idea 🕺
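
The 0 * inf = nan interaction itself can be reproduced in a few lines (my own minimal sketch, with no PennyLane and no anomaly detection):

import torch

# sqrt's backward contributes an inf at 0; PowBackward0 then multiplies it
# by d(x**2)/dx = 2x = 0, and 0 * inf = nan
x = torch.tensor(0.0, requires_grad=True)
torch.sqrt(x ** 2).backward()
print(x.grad)  # tensor(nan)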

I'll keep digging into this bug, although I'm not sure what I'll try next. Hopefully my findings thus far help reveal some stuff. Another thing I'd note is that a different measurement function that doesn't involve powers (e.g. qml.purity) won't have this issue.
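
For instance, swapping only the measurement in the reproducer above (an untested sketch of that suggestion):

def circuit(inputs, weights):
    qml.Hadamard(wires=0)
    qml.RY(weights[0], wires=1)
    return qml.purity(wires=[0])  # purity instead of expval(PauliZ)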