Closed by dime10 1 year ago
I don't think this is an MLIR bug, but rather a front-end bug. EDIT: Maybe both, since the signature calculation in the front-end is based on the MLIR algorithm.
I think this will fix it:
diff --git a/frontend/catalyst/utils/calculate_grad_shape.py b/frontend/catalyst/utils/calculate_grad_shape.py
index 6bbcbad..86cfb1a 100644
--- a/frontend/catalyst/utils/calculate_grad_shape.py
+++ b/frontend/catalyst/utils/calculate_grad_shape.py
@@ -108,7 +108,7 @@ def calculate_grad_shape(signature, indices):
                 diff_arg_shape.append(axis)

         for y in signature.get_results():
-            grad_res_shape = diff_arg_shape
+            grad_res_shape = diff_arg_shape.copy()
             if Signature.is_tensor(y):
                 for axis in y.shape:
                     grad_res_shape.append(axis)
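To illustrate why the `.copy()` matters, here is a minimal standalone sketch of the aliasing problem. It is not the Catalyst code: the function names `shapes_aliased` and `shapes_copied` and the example shapes are hypothetical, and plain lists stand in for the real signature types. Without the copy, `grad_res_shape` is just another name for the argument-shape list, so the axes appended for one result leak into the shape computed for the next.

```python
def shapes_aliased(arg_shape, result_shapes):
    """Pattern before the patch: each iteration appends to the same list object."""
    shapes = []
    for res_shape in result_shapes:
        grad_res_shape = arg_shape  # alias, not a copy
        for axis in res_shape:
            grad_res_shape.append(axis)
        shapes.append(list(grad_res_shape))  # snapshot of the current state
    return shapes


def shapes_copied(arg_shape, result_shapes):
    """Pattern after the patch: each result starts from a fresh copy."""
    shapes = []
    for res_shape in result_shapes:
        grad_res_shape = arg_shape.copy()
        for axis in res_shape:
            grad_res_shape.append(axis)
        shapes.append(grad_res_shape)
    return shapes


# One differentiable argument of shape [2], two results of shapes [3] and [4].
aliased = shapes_aliased([2], [[3], [4]])
copied = shapes_copied([2], [[3], [4]])
print(aliased)  # [[2, 3], [2, 3, 4]]
print(copied)   # [[2, 3], [2, 4]]
```

In the aliased version the second gradient shape picks up the first result's axis, which would be consistent with the problem only appearing for functions with more than one result.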
The following example fails compilation:
with the error message: