paul0403 closed this 3 months ago.
Nice catch @paul0403! I wonder whether this is straightforward to fix; I suspect it needs some digging first.
The culprit: when `value_and_grad` is lowered to MLIR, the shape of the gradient is not computed, and only its element type is passed in: https://github.com/PennyLaneAI/catalyst/blob/main/frontend/catalyst/jax_primitives.py#L537
(Compare this with the `grad` lowering, where the shape is computed: https://github.com/PennyLaneAI/catalyst/blob/main/frontend/catalyst/jax_primitives.py#L470)
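As a rough illustration of the step that is missing, here is a plain-Python sketch of the result shapes a `value_and_grad` lowering would need to compute. This is not Catalyst's lowering code. The helper names `grad_result_shapes` and `value_and_grad_result_shapes` are hypothetical, and the sketch assumes the Jacobian convention (output shape followed by argument shape), which reduces to the argument's own shape when the output is a scalar.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def grad_result_shapes(
    out_shapes: Sequence[Shape], arg_shapes: Sequence[Shape], argnums: Sequence[int]
) -> List[Shape]:
    """Hypothetical helper: shape of d(out)/d(arg) for each (output, differentiated arg) pair.

    Assumes the Jacobian convention out_shape + arg_shape; for a scalar
    output (shape ()) this is just the differentiated argument's shape.
    """
    return [tuple(out) + tuple(arg_shapes[i]) for out in out_shapes for i in argnums]


def value_and_grad_result_shapes(
    out_shapes: Sequence[Shape], arg_shapes: Sequence[Shape], argnums: Sequence[int]
) -> List[Shape]:
    """Hypothetical helper: shapes of the value outputs followed by the gradient outputs.

    The bug described above amounts to emitting the gradient results with
    only an element type, i.e. dropping the shapes computed here.
    """
    values = [tuple(s) for s in out_shapes]
    return values + grad_result_shapes(out_shapes, arg_shapes, argnums)


# Scalar-valued function of a length-3 vector and a 2x2 matrix, differentiated w.r.t. arg 0.
print(value_and_grad_result_shapes([()], [(3,), (2, 2)], [0]))  # [(), (3,)]
```

With a scalar output, the gradient simply inherits the shape of the differentiated argument, which is the information that is lost when only the element type is forwarded.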
I am fixing this now.
Manually closing, since the PR targets the release branch instead of main.
Expected:
but got: