Open qingyuanzhang3 opened 3 years ago
The issue is that odeint has a custom VJP, but functions with custom VJPs cannot be used in forward-mode autodiff.
We have ideas about a custom transpose feature that would allow this.
Here's a possible workaround. `hessian` is defined using forward-over-reverse Jacobians:

    def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
                holomorphic: bool = False) -> Callable:
      return jacfwd(jacrev(fun, argnums, holomorphic), argnums, holomorphic)

There's no fundamental reason we need to use forward-mode autodiff here, but it is usually the more efficient choice.
You could define your own `hessian` that uses `jacrev` in place of the `jacfwd`, although it may not do good things to the computational complexity.
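A minimal sketch of that reverse-over-reverse workaround follows. It does not call `odeint`: `log_cosh` is a hypothetical stand-in for any function that, like `odeint`, is differentiable only through a `custom_vjp` rule, and `hessian_rev` and `loss` are illustrative names, not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax import jacrev

# Hypothetical stand-in for a function that, like odeint, only has a custom VJP.
@jax.custom_vjp
def log_cosh(x):
    return jnp.log(jnp.cosh(x))

def log_cosh_fwd(x):
    # Forward pass: return the output and save x as the residual.
    return log_cosh(x), x

def log_cosh_bwd(x, g):
    # Backward pass: d/dx log(cosh(x)) = tanh(x).
    return (g * jnp.tanh(x),)

log_cosh.defvjp(log_cosh_fwd, log_cosh_bwd)

def hessian_rev(fun, argnums=0):
    """Reverse-over-reverse Hessian: jacrev applied twice, no jacfwd."""
    return jacrev(jacrev(fun, argnums), argnums)

def loss(theta):
    # Illustrative scalar objective built on the custom-VJP function.
    return jnp.sum(log_cosh(theta))

theta = jnp.array([0.3, -1.2, 2.0])
H = hessian_rev(loss)(theta)

# Analytic Hessian for comparison: diag(sech(theta)^2).
H_expected = jnp.diag(1.0 / jnp.cosh(theta) ** 2)
```

For an n-dimensional input the outer `jacrev` runs one reverse pass per gradient component, so this is likely slower than forward-over-reverse, but it should be acceptable for the small Hessians described in this thread.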
Great point, should just do 2x jacrev for now. We don't have truly enormous matrices and only have to call it a couple of times for basic variational inference, so that might solve it for the moment.
But I wonder - there are definitely good reasons to be able to do forward-mode autodiff for an ODE. For instance, you might want to calculate the high-dimensional derivative of the whole output time series with respect to a single input parameter, in order to get a Lyapunov exponent... or calculate large Hessians not just once for a Laplace approximation inference but repeatedly in certain optimizations. Are there fundamental reasons it wouldn't be possible to build this functionality?
> there are definitely good reasons to be able to do forward-mode autodiff for an ODE
Indeed, we agree.
> Are there fundamental reasons it wouldn't be possible to build this functionality?
Nothing fundamental, just a limitation of `custom_vjp`'s implementation today. As @hawkinsp mentioned, we hope to upgrade this once we've introduced machinery for custom transposition (and possibly custom partial evaluation). In the meantime, there's also the possible workaround of creating an explicit `odeint` primitive and defining JVP, transposition, and partial evaluation rules for it.
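To illustrate that nothing fundamental blocks forward mode through an ODE solve, here is a rough sketch that swaps in a different technique: a hand-written fixed-step RK4 integrator built on `jax.lax.scan`, rather than the adaptive `odeint` or a custom primitive. Because it has no `custom_vjp`, `jacfwd` can differentiate the whole output time series with respect to a single parameter. The names `rk4_solve`, `decay`, and `trajectory` are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

def rk4_solve(f, y0, ts, k):
    """Fixed-step RK4 over the grid ts; returns the state at every grid point."""
    def step(y, t_pair):
        t0, t1 = t_pair
        h = t1 - t0
        k1 = f(y, t0, k)
        k2 = f(y + 0.5 * h * k1, t0 + 0.5 * h, k)
        k3 = f(y + 0.5 * h * k2, t0 + 0.5 * h, k)
        k4 = f(y + h * k3, t1, k)
        y_next = y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
        return y_next, y_next

    # Scan over consecutive (t_i, t_{i+1}) pairs, collecting each new state.
    _, ys = jax.lax.scan(step, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0[None], ys])

def decay(y, t, k):
    # Toy dynamics dy/dt = -k * y, with exact solution y(t) = y0 * exp(-k * t).
    return -k * y

ts = jnp.linspace(0.0, 2.0, 201)
y0 = jnp.array([1.0])

def trajectory(k):
    # Full time series of the single state component as a function of k.
    return rk4_solve(decay, y0, ts, k)[:, 0]

k = jnp.array(1.5)
# Forward mode: one JVP gives d y(t_i) / dk for every time point at once.
dy_dk = jax.jacfwd(trajectory)(k)

# Analytic sensitivity for comparison: -t * exp(-k * t).
dy_dk_exact = -ts * jnp.exp(-k * ts)
```

The trade-off is losing adaptive step-size control, so the time grid has to be fine enough for the problem at hand.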
Cool, awesome. I think for now we'll give it a go with `jacrev`, but we'd love to know if `odeint` gets an upgrade. I'm really excited by the stuff you're doing.
Hi JAX team,
We want to calculate Hessians of a likelihood function involving an ODE integration so that we can do variational inference. We are running into an issue with `custom_vjp`, which we don't understand how to fix. We have the impression that it is not implemented for `odeint`. Our package, ticktack, is distributed on PyPI. The dataset miyake12.csv is hosted on GitHub here. Do you have any advice? Can we implement this easily, or are there plans to do this for `odeint`?
A minimal example:
We are getting the output,