jecampagne opened this issue 1 year ago
I just added some tests for this: https://github.com/f0uriest/interpax/blob/e3c7455571265561dba40d4d4e3de3583def561b/tests/test_interpolate.py#L338
The details change a bit depending on what you want to differentiate with respect to (the query points or the data values). Let me know if it's not clear.
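A minimal sketch of the two cases, under stated assumptions. So that it runs without interpax installed, it uses `jnp.interp` (plain piecewise-linear interpolation from `jax.numpy`) as a stand-in. The assumption is that interpax's `interp1d`, which is also written in JAX, should work the same way when swapped in, for example as `interpax.interp1d(xq, x, f, method="cubic")`. That call signature is an assumption here, so check the linked tests for the actual usage. The sample data `x`, `f = x**2` and the helper name `interp` are hypothetical. The only thing that changes between the two cases is the `argnums` passed to `jax.grad`.

```python
import jax
import jax.numpy as jnp

# Hypothetical sample data: knots at 0, 0.5, 1, 1.5, 2 with values f = x**2
x = jnp.linspace(0.0, 2.0, 5)
f = x**2

def interp(xq, f):
    # Piecewise-linear stand-in for an interpax interpolant (assumption)
    return jnp.interp(xq, x, f)

# Derivative w.r.t. the query point: the slope of the segment containing xq
dydxq = jax.grad(interp, argnums=0)(0.75, f)

# Derivative w.r.t. the data values: the interpolation weights at xq
dydf = jax.grad(interp, argnums=1)(0.75, f)

print(dydxq)  # slope of [0.5, 1.0] segment: (1.0 - 0.25) / 0.5
print(dydf)   # weight 0.5 on each of the two neighbouring knots
```

With a cubic method, `dydxq` would be the derivative of the cubic at `xq`, and `dydf` would generally have more than two non-zero entries, because cubic stencils use more neighbouring points.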
Nice job indeed. It would be nice to provide an example of how to apply JAX transformations (`grad`, `jacfwd`, `jacrev`) to the interpolated function. Thanks!
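A sketch of `jacfwd`, `jacrev` and a vectorized `grad` on an interpolated function, evaluated at several query points at once. As before, `jnp.interp` stands in for the interpax interpolant. The assumption is that the interpax function can be substituted directly because it is built on JAX. The data (`sin` sampled on 5 knots), the query points and the helper `interp` are all hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical data and query points (none of them lie exactly on a knot)
x = jnp.linspace(0.0, 2.0, 5)
f = jnp.sin(x)
xq = jnp.array([0.25, 0.75, 1.6])

def interp(xq, f):
    # Stand-in for an interpax interpolant (assumption); vector in, vector out
    return jnp.interp(xq, x, f)

# Jacobian w.r.t. query points, shape (3, 3). It is diagonal, because each
# output depends only on its own query point.
J_xq_fwd = jax.jacfwd(interp, argnums=0)(xq, f)  # forward mode
J_xq_rev = jax.jacrev(interp, argnums=0)(xq, f)  # reverse mode, same result

# Jacobian w.r.t. data values, shape (3, 5): the linear interpolation matrix.
J_f = jax.jacrev(interp, argnums=1)(xq, f)

# When only the pointwise derivative is needed, vmap(grad) avoids building
# the full (mostly zero) Jacobian. in_axes=(0, None) maps over xq only.
slopes = jax.vmap(jax.grad(interp, argnums=0), in_axes=(0, None))(xq, f)
```

Both `jacfwd` and `jacrev` give the same Jacobian. As a rule of thumb, `jacfwd` is cheaper when there are few inputs and many outputs, and `jacrev` is cheaper in the opposite case. For many query points, `vmap(grad)` is usually the better choice because it only computes the diagonal.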