PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Incorrect gradient return type deduction in finite-diff method #83

Closed: dime10 closed this issue 1 year ago

dime10 commented 1 year ago

The following example fails to compile:

import pennylane as qml
from catalyst import grad, qjit

@qml.qnode(qml.device("lightning.qubit", wires=1))
def func(p: float):
    x = qml.probs()
    y = p**2
    return x, y

qjit(grad(func, method="fd"))(0.1)

with the error message:

/tmp/tmpgwi9cyw0/func.nohlo.mlir:4:12: error: 'gradient.grad' op invalid result type: grad result at position 1 must be 'tensor<f64>' but got 'tensor<2xf64>'
    %0:2 = "gradient.grad"(%arg0) {callee = @func, diffArgIndices = dense<0> : tensor<1xi64>, finiteDiffParam = 9.9999999999999995E-8 : f64, method = "fd"} : (tensor<f64>) -> (tensor<2xf64>, tensor<2xf64>)
           ^
/tmp/tmpgwi9cyw0/func.nohlo.mlir:4:12: note: see current operation: %0:2 = "gradient.grad"(%arg0) {callee = @func, diffArgIndices = dense<0> : tensor<1xi64>, finiteDiffParam = 9.9999999999999995E-8 : f64, method = "fd"} : (tensor<f64>) -> (tensor<2xf64>, tensor<2xf64>)
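For reference, the gradient of each result should have the shape of the differentiated argument followed by the shape of that result, which is the rule the front-end loop in `calculate_grad_shape` (shown in the diff further down) follows. The minimal sketch below applies that rule to `func`. The helpers `expected_grad_shapes` and `to_tensor_type` are hypothetical illustrations, not Catalyst APIs. With a scalar `p`, the results `qml.probs()` on one wire (shape `(2,)`) and `p**2` (a scalar) should give the types `tensor<2xf64>` and `tensor<f64>`, which matches what the verifier expects at position 1.

```python
# Hypothetical helpers (not part of Catalyst) illustrating the expected shape rule:
# gradient shape of each result = differentiated-argument shape + result shape.

def expected_grad_shapes(diff_arg_shape, result_shapes):
    """Return the gradient shape for each result, building a fresh tuple per result."""
    return [tuple(diff_arg_shape) + tuple(res) for res in result_shapes]


def to_tensor_type(shape, elem="f64"):
    """Format a shape as an MLIR ranked tensor type string, e.g. tensor<2xf64>."""
    dims = "".join(f"{d}x" for d in shape)
    return f"tensor<{dims}{elem}>"


# func(p) takes a scalar p and returns probs on 1 wire, shape (2,), and p**2, a scalar.
shapes = expected_grad_shapes((), [(2,), ()])
print([to_tensor_type(s) for s in shapes])  # ['tensor<2xf64>', 'tensor<f64>']
```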
erick-xanadu commented 1 year ago

I don't think this is an MLIR bug, but rather a front-end bug. EDIT: Maybe both, since the front-end's signature calculation is based on the MLIR algorithm.

erick-xanadu commented 1 year ago

I think this will fix it:

diff --git a/frontend/catalyst/utils/calculate_grad_shape.py b/frontend/catalyst/utils/calculate_grad_shape.py
index 6bbcbad..86cfb1a 100644
--- a/frontend/catalyst/utils/calculate_grad_shape.py
+++ b/frontend/catalyst/utils/calculate_grad_shape.py
@@ -108,7 +108,7 @@ def calculate_grad_shape(signature, indices):
                 diff_arg_shape.append(axis)

         for y in signature.get_results():
-            grad_res_shape = diff_arg_shape
+            grad_res_shape = diff_arg_shape.copy()
             if Signature.is_tensor(y):
                 for axis in y.shape:
                     grad_res_shape.append(axis)
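The one-line change matters because `grad_res_shape = diff_arg_shape` binds a second name to the same list instead of creating a new one. Appending the axes of the first result (`qml.probs()`, shape `(2,)`) therefore also grows `diff_arg_shape`, and the scalar second result inherits that stale `2`. This reproduces the `tensor<2xf64>` reported for position 1. The sketch below is a simplified stand-in for the loop, assuming plain lists and tuples in place of Catalyst's `Signature` objects; `grad_shapes` and its `copy` flag are hypothetical names used only for illustration.

```python
# Simplified stand-in for the loop in calculate_grad_shape (assumption: shapes
# are plain lists/tuples here rather than Catalyst's Signature objects).

def grad_shapes(diff_arg_shape, result_shapes, copy):
    shapes = []
    for res_shape in result_shapes:
        # Without .copy(), grad_res_shape is the same list object as diff_arg_shape,
        # so appending result axes also mutates diff_arg_shape for later results.
        grad_res_shape = diff_arg_shape.copy() if copy else diff_arg_shape
        for axis in res_shape:
            grad_res_shape.append(axis)
        shapes.append(tuple(grad_res_shape))
    return shapes


# Scalar argument p; results: qml.probs() on one wire (2,) and p**2 ().
print(grad_shapes([], [(2,), ()], copy=False))  # [(2,), (2,)] -> the reported bug
print(grad_shapes([], [(2,), ()], copy=True))   # [(2,), ()]   -> the expected types
```

Copying gives each result its own shape list, so no result's axes leak into the gradient shape of the next.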