benjaminpope closed this issue 2 years ago.
One easy suggestion: compare the performance of the gradients on each presaved model.
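A minimal timing harness for this comparison might look like the sketch below. It assumes each presaved model can be wrapped as a scalar loss of a parameter array; the `models` dictionary and the `time_grad` helper are hypothetical stand-ins, not the project's code. The first call is timed separately because it includes tracing and jit compilation, which can dominate and would otherwise hide the per-evaluation cost.

```python
import time

import jax
import jax.numpy as jnp


def time_grad(loss_fn, params, n_repeats=5):
    """Return (first_call_seconds, mean_later_call_seconds) for jax.grad(loss_fn)."""
    grad_fn = jax.jit(jax.grad(loss_fn))
    t0 = time.perf_counter()
    grad_fn(params).block_until_ready()  # first call: tracing + compilation + run
    first_s = time.perf_counter() - t0
    t0 = time.perf_counter()
    for _ in range(n_repeats):
        grad_fn(params).block_until_ready()  # wait for async dispatch to finish
    run_s = (time.perf_counter() - t0) / n_repeats
    return first_s, run_s


# Hypothetical stand-ins for the presaved models.
models = {
    "quadratic": lambda p: jnp.sum(p ** 2),
    "oscillatory": lambda p: jnp.sum(jnp.sin(p) * jnp.exp(-p)),
}
params = jnp.linspace(0.0, 1.0, 16)
timings = {name: time_grad(fn, params) for name, fn in models.items()}
for name, (first_s, run_s) in timings.items():
    print(f"{name}: first call {first_s:.3f} s, later calls {run_s * 1e3:.3f} ms")
```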
So, for clarity: I have been working on this issue today, and I think it is not possible to resolve with the DP5 solver. Changing the number of arguments that go into the dc14 family of functions and differentiating with respect to a single argument pair takes ~10 s. Over sixteen values this adds up to minutes, which matches the time that @Jordan-Dennis and @qingyuanzhang3 found for computing the hessian.
Weirdly, the hessians have previously converged in ~1 s (see the Slack). This discrepancy is disturbing but consistent, and I cannot find a good explanation for it, since at a high level the code behaves nearly identically.
Using the linear solver, this time can be brought down to the ~100 ms range.
This is an improvement, but the linear solver does not achieve the same accuracy as DP5.
So the bin_data function seems to be a major culprit for the autodiff slowdown. This is unusual, because the function itself evaluates quickly.
Having investigated the bin_data function, I found that the resolution of the time sampling was the primary culprit. After reducing it from oversample=1008 to oversample=48, the hessians now run in:
On further investigation, the bin_data function was generating a very large array internally, which was congesting the autodiff. The speed can be improved significantly by reducing the oversample parameter.
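To illustrate why oversample matters for autodiff even when the forward evaluation is cheap, here is a hypothetical bin-averaging sketch; `bin_data_sketch`, the 10 one-year bins and the 11-year sinusoid are assumptions for illustration, not the project's bin_data. The signal is sampled `oversample` times per bin, so the intermediate grid has n_bins × oversample points, and reverse-mode AD generally has to keep intermediates on that grid for the backward pass. For a smooth signal the gradient barely changes between oversample=1008 and oversample=48, while the grid shrinks 21-fold.

```python
import jax
import jax.numpy as jnp


def bin_data_sketch(signal_fn, t_edges, oversample):
    """Hypothetical bin averaging: evaluate signal_fn at `oversample` midpoints per bin, then average."""
    frac = (jnp.arange(oversample) + 0.5) / oversample  # sample positions within a bin, in (0, 1)
    widths = jnp.diff(t_edges)                           # shape (n_bins,)
    t_fine = t_edges[:-1, None] + widths[:, None] * frac[None, :]  # shape (n_bins, oversample)
    return signal_fn(t_fine).mean(axis=1)                # one average per bin


t_edges = jnp.arange(0.0, 11.0)  # hypothetical: 10 one-year bins
n_bins = t_edges.shape[0] - 1


def loss(amp, oversample):
    # Toy 11-year cycle with a single amplitude parameter.
    binned = bin_data_sketch(lambda t: amp * jnp.sin(2 * jnp.pi * t / 11.0), t_edges, oversample)
    return jnp.sum(binned ** 2)


fine_points = {ov: n_bins * ov for ov in (1008, 48)}
grads = {ov: float(jax.grad(loss)(1.0, ov)) for ov in (1008, 48)}
for ov in (1008, 48):
    print(f"oversample={ov}: {fine_points[ov]} fine-grid points, dloss/damp = {grads[ov]:.6f}")
```

Because the per-bin error of this midpoint-style average shrinks roughly with the square of the sub-step, a much coarser oversample can be nearly as accurate for slowly varying signals; sharp features would need checking case by case.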
We are limited in our ability to use HMC because derivatives through odeint are very slow, and we can't use the Laplace approximation efficiently because odeint doesn't support forward-mode autodiff. We should contact the JAX team, or anyone else we know who is a JAX ninja, about forward-mode autodiff (https://github.com/google/jax/issues/7401).
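One possible workaround, sketched on a toy problem and not tested against the real model: `jax.hessian` is implemented as forward-over-reverse (`jacfwd(jacrev(f))`), so it needs forward mode. Composing reverse mode twice, `jax.jacrev(jax.jacrev(f))`, avoids forward mode through `odeint`, at the cost of extra backward ODE solves. The decay ODE `dy/dt = -k*y` below is a hypothetical stand-in whose Hessian with respect to `k` is known analytically.

```python
import math

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

TS = jnp.array([0.0, 2.0])  # integrate from t=0 to t=2


def decay_rhs(y, t, k):
    # Toy linear ODE standing in for the real carbon box model.
    return -k * y


def final_value(k):
    # odeint is an adaptive Dormand-Prince integrator; tolerances loosened for float32.
    ys = odeint(decay_rhs, jnp.array([1.0]), TS, k, rtol=1e-5, atol=1e-5)
    return ys[-1, 0]  # y(2) = exp(-2k)


# Reverse-over-reverse Hessian: no forward-mode pass through odeint.
hessian_rr = jax.jacrev(jax.jacrev(final_value))
h = float(hessian_rr(0.5))
analytic = 4.0 * math.exp(-1.0)  # d^2/dk^2 exp(-2k) = 4 exp(-2k), evaluated at k = 0.5
print(h, analytic)
```

For many parameters this costs one reverse pass per Hessian row, so it may still be slow for the full model; it only removes the forward-mode restriction.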
As for why even the easy derivatives are slow, this seems to be because the ODE is stiff. We should consider whether there are less-stiff approximations to the ODE (e.g. treating the ocean fluxes as constant) or reparametrizations that might be faster. See also https://github.com/google/jax/issues/3993.
We could also consider whether there are stiff integrators, implemented or implementable in JAX, that could perform better. Or should we switch to Julia for the numerical parts?
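A toy illustration of why stiffness hurts explicit integrators and why implicit ("stiff") integrators can help. The 2-box exchange matrix below is hypothetical, with a fast exchange rate of 100 and a slow return rate of 1, not the real model's values. Its eigenvalues are 0 and -101, so explicit Euler is stable only for steps h < 2/101 ≈ 0.0198, even though the slow dynamics would tolerate much larger steps. Backward (implicit) Euler stays stable at h = 0.05, at the cost of a linear solve per step. Treating a fast flux as constant would, in this picture, remove the fast eigenvalue that sets the step limit.

```python
import jax.numpy as jnp

# Hypothetical 2-box exchange; columns sum to zero, so total carbon is conserved.
A = jnp.array([[-100.0, 1.0],
               [100.0, -1.0]])
y0 = jnp.array([1.0, 0.0])

eigs = jnp.linalg.eigvals(A).real
lam_fast = float(jnp.max(jnp.abs(eigs)))  # magnitude of the fastest mode (101 here)
h_limit = 2.0 / lam_fast                  # explicit Euler is stable only for h < h_limit


def explicit_euler(y, h, n_steps):
    for _ in range(n_steps):
        y = y + h * (A @ y)
    return y


def implicit_euler(y, h, n_steps):
    M = jnp.eye(2) - h * A  # backward Euler: solve (I - hA) y_next = y each step
    for _ in range(n_steps):
        y = jnp.linalg.solve(M, y)
    return y


h = 0.05  # above h_limit
norm_explicit = float(jnp.linalg.norm(explicit_euler(y0, h, 20)))
norm_implicit = float(jnp.linalg.norm(implicit_euler(y0, h, 20)))
print(f"h_limit={h_limit:.4f}, explicit norm={norm_explicit:.3e}, implicit norm={norm_implicit:.3f}")
```

Here the explicit solution grows by a factor of about |1 - 0.05 × 101| ≈ 4 per step in the fast mode, while the implicit one damps it and converges to the conserved equilibrium.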