Support brainpy.math.defjvp and brainpy.math.XLACustomOp.defjvp, which work like jax.interpreters.ad.defjvp but, unlike the JAX version, also support defining JVP rules for primitives with multiple results.
See the examples in test_ad_support.py.
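The idea behind a multi-result defjvp can be illustrated with a minimal pure-Python sketch. It assumes per-input rules in the style of jax.interpreters.ad.defjvp (each rule receives one input's tangent plus all primal inputs), extended so that each rule returns one tangent contribution per output; contributions from all inputs are then summed per output. make_multi_result_jvp is a hypothetical helper written for illustration only. It is not the BrainPy API, whose actual usage is shown in test_ad_support.py.

```python
from typing import Callable, Optional, Sequence, Tuple


def make_multi_result_jvp(
    primal_fn: Callable[..., Tuple[float, ...]],
    *jvp_rules: Optional[Callable[..., Tuple[float, ...]]],
):
    """Hypothetical helper: combine per-input JVP rules for a multi-output function.

    jvp_rules[i](tangent_i, *primals) returns the contribution of input i's
    tangent to every output, as a tuple with one entry per output.
    """

    def jvp(primals: Sequence[float], tangents: Sequence[Optional[float]]):
        outs = primal_fn(*primals)
        out_tangents = [0.0] * len(outs)
        for rule, t in zip(jvp_rules, tangents):
            # No rule, or a symbolically zero tangent (None): this input contributes nothing.
            if rule is None or t is None:
                continue
            contrib = rule(t, *primals)
            # Sum this input's contribution into each output tangent.
            for j, c in enumerate(contrib):
                out_tangents[j] += c
        return outs, tuple(out_tangents)

    return jvp


# Example: f(x, y) = (x * y, x + y), with one rule per input.
jvp = make_multi_result_jvp(
    lambda x, y: (x * y, x + y),
    lambda t, x, y: (t * y, t),  # d(out)/dx scaled by t
    lambda t, x, y: (t * x, t),  # d(out)/dy scaled by t
)
outs, tans = jvp((2.0, 3.0), (1.0, 0.0))
```

The per-input structure mirrors how jax.interpreters.ad.defjvp composes rules for a single-result primitive; the only change is that each rule yields a tuple of tangents, one per result.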
How Has This Been Tested
Types of changes
Bug fix (non-breaking change which fixes an issue)
New feature (non-breaking change which adds functionality)
Documentation (non-breaking change which updates documentation)
Breaking change (fix or feature that would cause existing functionality to change)
Code style (formatting, renaming)
Refactoring (no functional changes, no api changes)
Other (please describe here):
Checklist
[ ] Code follows the code style of this project.
[ ] Changes follow the CONTRIBUTING guidelines.
[ ] Update necessary documentation accordingly.
[ ] Lint and tests pass locally with the changes.
[ ] Check issues and pull requests first. You don't want to duplicate effort.
Other information