Closed: paul0403 closed this 5 days ago
Thanks @paul0403! This PR should be made to the `v0.7.0-rc` branch, since it's a bugfix to a release feature :) cc @rauletorresc
Thanks for the hint! I have rebased it onto `v0.7.0-rc`.
> Also do we plan to accept pytree values as return, for example `{"result": x}`? It does not seem very useful as we are limited to one value returned.
For the return value I don't mind either way, but it is definitely important to support pytrees as input (I've seen JAX optimizers that require this).
That said, I think `jax.value_and_grad` works with returned pytrees, so it would be good to match its behaviour.
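For reference, here is a minimal plain-JAX sketch (not part of this PR) of the input behaviour being discussed: `jax.value_and_grad` accepts pytree (e.g. dict) arguments and returns gradients with the same pytree structure, which is what gradient-based optimizers typically consume.

```python
import jax
import jax.numpy as jnp

def loss(params):
    # params is a pytree: a dict of arrays
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
value, grads = jax.value_and_grad(loss)(params)
# value is a scalar; grads is a dict with the same keys ("w", "b") as params
```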
@paul0403 don't forget to add a link to this PR to the changelog! You don't need to create a new changelog entry, simply add this PR to the existing `value_and_grad` changelog entry.
Nice job! Just a question about the assertions. Also do we plan to accept pytree values as return, for example `{"result": x}`? It does not seem very useful as we are limited to one value returned.
Currently Catalyst only supports the differentiated function having a single float as return; this is not a `catalyst.value_and_grad` issue, as it is the behavior of `catalyst.grad` as well.
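As a rough illustration of that constraint (the circuit and device below are made up for illustration, not taken from this PR), the callee passed to `catalyst.grad` or `catalyst.value_and_grad` is expected to return a single scalar, such as one expectation value:

```python
import pennylane as qml
from catalyst import qjit, grad

dev = qml.device("lightning.qubit", wires=1)  # illustrative device choice

@qml.qnode(dev)
def cost(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))  # single float return: supported

@qjit
def workflow(x: float):
    # Returning a pytree such as {"result": ...} from `cost` would not fit
    # the single-scalar-return constraint described above.
    return grad(cost)(x)
```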
Before submitting
Please complete the following checklist when submitting a PR:

- [x] All new functions and code must be clearly commented and documented.
- [x] Ensure that code is properly formatted by running `make format`. The latest version of `black` and `clang-format-14` are used in CI/CD to check formatting.
- [x] All new features must include a unit test. Integration and frontend tests should be added to `frontend/test`, Quantum dialect and MLIR tests should be added to `mlir/test`, and Runtime tests should be added to `runtime/tests`.

When all the above are checked, delete everything above the dashed line and fill in the pull request template.
Context: Bug fix: currently `catalyst.value_and_grad` only works on scalar inputs, from the work done in #804. This PR extends it to arbitrary pytree shapes; a usage sketch follows the change description below.

Description of the Change:

- Add pytree support for `value_and_grad` in `differentiation.py` and `jax_primitives.py`.
- For `ValueAndGradOp`, check the gradient result types against the callee input argument types from the right instead of from the left, since if the function has constant values, the cloned callee will take in these constant values as arguments at the start of the argument list.
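A rough usage sketch of what the change enables; the circuit, the `params` structure, and the device choice are illustrative assumptions rather than code from this PR:

```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit, value_and_grad

dev = qml.device("lightning.qubit", wires=1)  # illustrative device choice

@qml.qnode(dev)
def circuit(params):
    # params is a pytree (a dict holding an array), not a scalar
    qml.RX(params["angles"][0], wires=0)
    qml.RY(params["angles"][1], wires=0)
    return qml.expval(qml.PauliZ(0))

@qjit
def workflow(params):
    return value_and_grad(circuit)(params)

# workflow({"angles": jnp.array([0.1, 0.2])}) returns the scalar circuit value
# together with a gradient pytree matching the {"angles": ...} structure.
```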
Benefits: `catalyst.value_and_grad` now can take in functions whose numerical arguments are not scalars.

Possible Drawbacks:
Related GitHub Issues: closes #841 [sc-66764]