PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Adding support for calling `value_and_grad` on non-scalar inputs #859

Closed: paul0403 closed this pull request 5 days ago

paul0403 commented 1 week ago



Context: Bug fix. Currently, `catalyst.value_and_grad` (introduced in #804) only works on scalar inputs. This PR extends it to arguments of arbitrary pytree shape.

Description of the Change:

Benefits: `catalyst.value_and_grad` can now differentiate functions whose numerical arguments are not scalars (for example arrays, or pytrees such as dicts of arrays).
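To illustrate the kind of input this enables, the sketch below uses plain JAX (`jax.value_and_grad`) rather than Catalyst, as a reference for the expected behaviour: when the argument is a pytree (here a dict holding an array and a scalar), the value is a scalar and the gradient comes back with the same structure as the argument. The function name `loss` and the `params` layout are hypothetical, chosen only for illustration; the Catalyst version would additionally run under `qjit`.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function of a pytree argument:
# a dict holding a length-3 array "w" and a scalar "b".
def loss(params):
    w = params["w"]
    b = params["b"]
    return jnp.sum(w ** 2) + b

params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(0.5)}

# value is the scalar loss; grads mirrors the dict structure of params.
value, grads = jax.value_and_grad(loss)(params)
# value    -> 14.5           (1 + 4 + 9 + 0.5)
# grads["w"] -> [2., 4., 6.] (d/dw of sum(w**2) is 2*w)
# grads["b"] -> 1.0
```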

Possible Drawbacks:

Related GitHub Issues: closes #841 [sc-66764]

josh146 commented 1 week ago

Thanks @paul0403! This PR should be made to the v0.7.0-rc branch, since it's a bugfix to a release feature :) cc @rauletorresc

rauletorresc commented 1 week ago

> Thanks @paul0403! This PR should be made to the v0.7.0-rc branch, since it's a bugfix to a release feature :) cc @rauletorresc

Thanks for the hint! I have rebased it with respect to v0.7.0-rc

josh146 commented 1 week ago

> Also, do we plan to accept pytree values as return values, for example {"result": x}? It does not seem very useful, as we are limited to one returned value.

For return values I don't mind either way, but it is definitely important to support pytrees as inputs (I've seen JAX optimizers that require this).
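As a rough picture of why pytree inputs matter for optimizers, here is a minimal hand-written gradient-descent step in plain JAX (not any particular optimizer library): parameters live in a dict, `jax.value_and_grad` returns gradients with the same structure, and `jax.tree_util.tree_map` applies the update leaf by leaf. The names `loss`, `sgd_step`, and the learning rate are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model: squared prediction, i.e. squared error against a zero target.
def loss(params, x):
    pred = jnp.dot(params["w"], x) + params["b"]
    return pred ** 2

# One plain gradient-descent step over a pytree of parameters.
def sgd_step(params, x, lr=0.1):
    value, grads = jax.value_and_grad(loss)(params, x)
    # grads has the same dict structure as params, so the update can be mapped leaf-wise.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return value, new_params

params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.5)}
x = jnp.array([2.0, 1.0])
value, new_params = sgd_step(params, x)
# pred = 1.5, so value = 2.25; grads: w -> [6., 3.], b -> 3.
```

Because the update only relies on the gradient mirroring the parameter structure, the same step works unchanged for any nesting of dicts, lists, or tuples of arrays.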

josh146 commented 1 week ago

Although, I think `jax.value_and_grad` works with returned pytrees, so it would be good to match its behaviour.
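For reference, one way JAX lets a differentiated function return pytree data is the `has_aux=True` option of `jax.value_and_grad`: the first returned element must still be a scalar (it is what gets differentiated), while the second can be an arbitrary pytree that is passed through without being differentiated. The function `cost_with_aux` and its `aux` contents below are hypothetical.

```python
import jax
import jax.numpy as jnp

def cost_with_aux(x):
    # The differentiated output must be a scalar...
    value = jnp.sum(jnp.sin(x))
    # ...but arbitrary pytree data can be returned alongside it as auxiliary output.
    aux = {"result": x, "norm": jnp.linalg.norm(x)}
    return value, aux

x = jnp.array([0.0, 1.0])
# With has_aux=True the result is ((value, aux), grad).
(value, aux), grad = jax.value_and_grad(cost_with_aux, has_aux=True)(x)
# grad equals cos(x); aux is returned unchanged and is not differentiated.
```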

josh146 commented 6 days ago

@paul0403, don't forget to add a link to this PR to the changelog! You don't need to create a new changelog entry; simply add this PR to the existing `value_and_grad` changelog entry.

paul0403 commented 5 days ago

> Nice job! Just a question about the assertions. Also, do we plan to accept pytree values as return values, for example {"result": x}? It does not seem very useful, as we are limited to one returned value.

Currently, Catalyst only supports differentiated functions that return a single float. This is not specific to `catalyst.value_and_grad`, since `catalyst.grad` behaves the same way.
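JAX imposes a comparable restriction on the primary output, which may help frame the discussion above. The sketch below uses plain JAX, not Catalyst: a function that returns its single value wrapped in a dict is rejected by `jax.value_and_grad` with a `TypeError`, while unwrapping the scalar makes it differentiable. The function `returns_pytree` is hypothetical.

```python
import jax
import jax.numpy as jnp

def returns_pytree(x):
    # A single value, but wrapped in a dict rather than returned as a bare scalar.
    return {"result": jnp.sum(x ** 2)}

x = jnp.array([1.0, 2.0])

try:
    jax.value_and_grad(returns_pytree)(x)
    rejected = False
except TypeError:
    # JAX only differentiates functions whose (primary) output is a scalar.
    rejected = True

# Extracting the scalar leaf makes the function differentiable.
value, grad = jax.value_and_grad(lambda v: returns_pytree(v)["result"])(x)
# value -> 5.0, grad -> [2., 4.]
```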