Open · anthonysmaldone opened this issue 11 months ago
Thanks for opening this issue @anthonysmaldone. We'll take a look and get back to you soon.
Hi @anthonysmaldone! I've been looking at this bug, and I must admit it's a strange one. I don't quite have the answer, but I've made some progress. torch offers a function called set_detect_anomaly, and it pointed me in the right direction. The issue appears to arise when computing the gradient of a function that raises tensors to a power (which we do here when generating probabilities). With anomaly detection set to True, I get this error:
RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.
I've made a minimal example of what's happening without any PennyLane code, because the fundamental problem seems to lie in torch's own gradient handling:
```python
import torch
import numpy as np

torch.autograd.set_detect_anomaly(True)

class Net(torch.nn.Module):
    def forward(self, x):
        # some function that uses pow and will return 0
        return torch.floor(x[0]) ** 2

# Create the classical layer
classical_layer = Net()

# Training loop
steps = 100
target = torch.tensor(np.cos(np.pi / 4), dtype=torch.float32)

# Set up the optimizer (placeholder parameter; the backward error is raised regardless)
optimizer = torch.optim.Adam([torch.tensor([2.9348])], lr=0.1)

for step in range(steps):
    optimizer.zero_grad()
    output = classical_layer(torch.tensor([0.5], dtype=torch.float32, requires_grad=True))
    prediction = torch.sqrt(output)
    loss = torch.nn.functional.mse_loss(prediction, target)
    loss.backward()
    optimizer.step()
    if step % 10 == 0:
        print(f"Step {step}: loss = {loss.item()}")
```
That said, if you set anomaly detection to False in the above example, you won't get NaN values. This does suggest that there could be a fix in PennyLane to minimize the severity of this issue, but the root of it is torch having a hard time computing the gradient with PowBackward0: the upstream gradient flowing out of torch.sqrt at 0 is inf, and PowBackward0 multiplies it by the local derivative 2x = 0, producing 0 * inf = nan. Some googling shows that people have had similar issues, but not quite the same afaik. Most people recommend making things slightly greater than zero (a sketch of that is below) - your masking does that in a more elegant way imo, so you've already got the idea 🕺
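For reference, a minimal sketch of that "slightly greater than zero" workaround (the `clamped_sqrt` name and the epsilon value are my own choices, not from the original report):

```python
import torch

def clamped_sqrt(x, eps=1e-12):
    # clamping keeps the input to sqrt strictly positive, so sqrt's
    # backward pass never produces inf at 0
    return torch.sqrt(torch.clamp(x, min=eps))

x = torch.tensor([0.0], requires_grad=True)
clamped_sqrt(x).sum().backward()
print(x.grad)  # tensor([0.]): clamp blocks the gradient below eps, so no nan
```

The trade-off is a small bias for values in (0, eps); the masking approach discussed below avoids that.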
I'll keep digging into this bug, although I'm not sure what I'll try next. Hopefully my findings thus far help reveal some stuff. Another thing I'd note is that a different measurement function that doesn't involve powers (e.g. qml.purity) won't have this issue.
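For illustration, a minimal sketch of what that could look like (the device and gate choices here are mine, not from the thread):

```python
import pennylane as qml
import torch

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, interface="torch")
def circuit(weights):
    qml.RY(weights[0], wires=0)
    qml.Hadamard(wires=0)
    # a measurement that, per the comment above, avoids the pow-based
    # probability post-processing
    return qml.purity(wires=0)

weights = torch.tensor([0.5], requires_grad=True)
circuit(weights).backward()
print(weights.grad)  # finite (0 for a pure state), not nan
```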
Expected behavior
When post-processing the output of a quantum circuit wrapped in a torch layer, the gradient of the torch.sqrt() function at the point 0 should be defined as 0. According to the PyTorch docs:

Actual behavior
While PyTorch correctly sets the gradients of the classical weights, the gradients of the quantum weights become nan.

Additional information
Here is an arbitrary PennyLane/PyTorch circuit that demonstrates the issue. The circuit measures, in the Pauli Z basis, a qubit that has been placed in maximal superposition with a Hadamard gate (thus guaranteeing an expectation value of 0).
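A minimal sketch of the kind of circuit described (the layer structure, parameter names, and loss are illustrative assumptions, not the reporter's exact code):

```python
import torch
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, interface="torch")
def circuit(inputs, weights):
    qml.RY(inputs[0], wires=0)
    qml.Hadamard(wires=0)            # maximal superposition, so <Z> = 0
    qml.RZ(weights[0], wires=0)      # the "quantum weight"
    return qml.expval(qml.PauliZ(0))

qlayer = qml.qnn.TorchLayer(circuit, weight_shapes={"weights": 1})

output = qlayer(torch.tensor([0.0]))
prediction = torch.sqrt(output)      # sqrt of an expectation value that is exactly 0
loss = ((prediction - 1.0) ** 2).mean()
loss.backward()

for name, param in qlayer.named_parameters():
    print(name, param.grad)          # nan for the quantum weight
```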
This outputs nan gradients for the quantum weights.
Instead of directly taking the square root, if you mask non-zero values, the issue is fixed while maintaining identical values:
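A sketch of that masking idea (the `masked_sqrt` helper and its name are mine, not the reporter's exact code):

```python
import torch

def masked_sqrt(x):
    mask = x > 0
    out = torch.zeros_like(x)
    # only strictly positive entries enter sqrt's backward pass, so the
    # inf at 0 never appears; zero entries stay 0 with gradient 0
    out[mask] = torch.sqrt(x[mask])
    return out

x = torch.tensor([0.0, 0.25], requires_grad=True)
masked_sqrt(x).sum().backward()
print(x.grad)  # tensor([0., 1.]): finite everywhere, same forward values as sqrt
```

Swapping this in for torch.sqrt keeps the forward values identical while keeping the quantum weight gradients finite.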
The fixed circuit produces identical output values, and the quantum weight gradients remain finite.
Source code
No response
Tracebacks
No response
System information
Existing GitHub issues