SharmaLlama / ticktack

An open-source carbon box model implementation built on JAX.
https://sharmallama.github.io/ticktack
MIT License

fast autodiff #11

Closed: benjaminpope closed this issue 2 years ago

benjaminpope commented 3 years ago

We are limited in our ability to use HMC because the derivatives through odeint are very slow, and we can't use the Laplace approximation efficiently because computing its Hessian efficiently needs forward-mode autodiff, which odeint doesn't support.
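To make the forward-mode limitation concrete, here is a minimal sketch using `jax.experimental.ode.odeint` on a toy exponential-decay ODE (an assumed stand-in, not ticktack's carbon box model). `odeint` defines its derivative with a custom VJP (an adjoint solve), so reverse mode works, including reverse-over-reverse for second derivatives, while forward mode (`jax.jacfwd`, and therefore `jax.hessian`, which is forward-over-reverse) is rejected.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

jax.config.update("jax_enable_x64", True)  # float64 for tighter comparisons


def decay(y, t, k):
    # Toy ODE dy/dt = -k * y, a hypothetical stand-in for the box model.
    return -k * y


def final_value(k):
    ts = jnp.array([0.0, 2.0])
    ys = odeint(decay, jnp.array(1.0), ts, k, rtol=1e-8, atol=1e-8)
    return ys[-1]  # analytically y(2) = exp(-2k)


k = 0.5
dy_dk = jax.grad(final_value)(k)              # reverse mode: supported
d2y_dk2 = jax.grad(jax.grad(final_value))(k)  # reverse-over-reverse: supported

try:
    jax.jacfwd(final_value)(k)  # forward mode, also used inside jax.hessian
    forward_mode_failed = False
except TypeError as err:
    # odeint's derivative is a custom VJP, which has no forward-mode rule.
    forward_mode_failed = True
    print("jacfwd failed:", err)

print(dy_dk, d2y_dk2)  # compare with the analytic -2*exp(-1) and 4*exp(-1)
```

Reverse-over-reverse is one workaround for second derivatives, but each outer reverse pass has to differentiate through the adjoint solve again, which is part of why Hessians through `odeint` are expensive.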

We should contact the JAX team, or anyone else we know who is a JAX ninja, about forward-mode autodiff (https://github.com/google/jax/issues/7401).

As for why even the easy derivatives are slow, this seems to be because the ODE is stiff. We should consider whether there are less-stiff approximations to the ODE (e.g. treating the ocean fluxes as constant) or reparametrizations that might be faster.
(https://github.com/google/jax/issues/3993)
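One rough way to quantify stiffness for a linear box model dy/dt = A y is the ratio of the largest to the smallest eigenvalue magnitude of A: an explicit solver's step size is limited by the fastest mode, while the integration has to run long enough to follow the slowest one. The sketch below builds a hypothetical three-box exchange matrix (atmosphere, surface ocean, deep ocean) with made-up rates, adds 14C decay (half-life taken as 5730 yr), and computes that ratio. The rates are illustrative assumptions, not ticktack's parameters.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical exchange rates (1/yr): atmosphere<->surface, surface<->deep.
k_as, k_sa, k_sd, k_ds = 0.125, 0.1, 0.5, 0.001
lam_c14 = jnp.log(2.0) / 5730.0  # 14C decay constant (1/yr)

# Mass-conserving exchange matrix (columns sum to zero), then decay in every box.
A = jnp.array([
    [-k_as, k_sa, 0.0],
    [k_as, -k_sa - k_sd, k_ds],
    [0.0, k_sd, -k_ds],
]) - lam_c14 * jnp.eye(3)

# Eigenvalues are real for this tridiagonal exchange structure.
eigs = jnp.real(jnp.linalg.eigvals(A))
stiffness_ratio = jnp.max(jnp.abs(eigs)) / jnp.min(jnp.abs(eigs))
print(eigs, stiffness_ratio)
```

In this toy setup the slowest mode is set by radioactive decay alone, so the ratio stays in the thousands even though no single exchange rate is extreme. Approximations or reparametrizations that remove the fastest modes would shrink it.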

We could also consider whether there are stiff integrators implemented (or implementable) in JAX that would perform better. Or should we switch to Julia for the numerical parts?
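As a sketch of what an implicit integrator looks like in plain JAX, the block below implements backward Euler for a linear system dy/dt = A y with `jax.lax.scan`. It assumes a constant matrix A and is only first-order accurate (both simplifications); the point is that it stays stable at step sizes where explicit Euler would diverge, and because it is built from ordinary JAX operations it supports forward as well as reverse mode. Higher-order implicit solvers are available in JAX libraries such as Diffrax. The two-box matrix and `loss` function are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def backward_euler_linear(A, y0, h, n_steps):
    """Backward (implicit) Euler for dy/dt = A y.

    Each step solves (I - h A) y_{n+1} = y_n, which remains stable for stiff A
    even when h is much larger than the fastest timescale.
    """
    M = jnp.eye(A.shape[0], dtype=A.dtype) - h * A

    def step(y, _):
        y_next = jnp.linalg.solve(M, y)
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, None, length=n_steps)
    return ys  # shape (n_steps, dim): the state after each step


def stiff_matrix(k_fast):
    # Hypothetical 2-box system: one fast mode (rate k_fast) and one slow mode (1/yr).
    base = jnp.array([[0.0, 1.0], [0.0, -1.0]])
    fast = jnp.array([[1.0, 0.0], [0.0, 0.0]])
    return base - k_fast * fast


def loss(k_fast):
    ys = backward_euler_linear(stiff_matrix(k_fast), jnp.array([1.0, 1.0]), 0.1, 50)
    return jnp.sum(ys[-1])


# h = 0.1 with k_fast = 1000: explicit Euler's amplification factor would be -99.
ys = backward_euler_linear(stiff_matrix(1000.0), jnp.array([1.0, 1.0]), 0.1, 50)
g_fwd = jax.jacfwd(loss)(1000.0)  # forward mode works: only plain JAX ops
g_rev = jax.grad(loss)(1000.0)    # reverse mode works too
print(ys[-1], g_fwd, g_rev)
```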

benjaminpope commented 3 years ago

One easy suggestion: compare the performance of the gradients on each presaved model.
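A minimal timing harness for this comparison might look like the following. It times both a plain evaluation and a gradient, so the gradient overhead of each model can be read off directly, and it warms up first so compilation is excluded. The `models` dictionary and the toy loss functions are hypothetical placeholders for the presaved models.

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, n_repeat=10):
    """Median wall-clock time (s) of jax.jit(fn), excluding compilation."""
    jitted = jax.jit(fn)
    jax.block_until_ready(jitted(*args))  # compile and warm up
    times = []
    for _ in range(n_repeat):
        start = time.perf_counter()
        jax.block_until_ready(jitted(*args))
        times.append(time.perf_counter() - start)
    return sorted(times)[len(times) // 2]


def profile_models(models, params):
    """models: dict of name -> scalar loss fn(params). Returns eval and grad timings."""
    results = {}
    for name, loss in models.items():
        results[name] = {
            "eval_s": time_fn(loss, params),
            "grad_s": time_fn(jax.grad(loss), params),
        }
    return results


# Hypothetical stand-ins; replace with losses built from each presaved model.
models = {
    "quadratic": lambda p: jnp.sum(p ** 2),
    "sine": lambda p: jnp.sum(jnp.sin(p)),
}
params = jnp.linspace(0.0, 1.0, 16)
results = profile_models(models, params)
for name, r in results.items():
    print(f"{name}: eval {r['eval_s']:.2e} s, grad {r['grad_s']:.2e} s")
```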

Jordan-Dennis commented 2 years ago

So, for clarity: I have been working on this issue today, and I think it cannot be resolved with the DP5 solver. Changing the number of arguments that go into the dc14 family of functions and differentiating with respect to a single argument pair takes ~10 s. Over sixteen values this adds up to minutes, which matches the time that @Jordan-Dennis and @qingyuanzhang3 found for computing the Hessian.
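Because `odeint` only supports reverse mode, second derivatives have to come from reverse-over-reverse passes. The sketch below (generic JAX, not ticktack code) builds Hessian-vector products that way: one product yields a whole Hessian column, so a single entry costs as much as a column, and a full Hessian over n parameters needs n products. That linear scaling is consistent with a ~10 s per-pair cost adding up to minutes over sixteen values. The quadratic test function is an assumption used only to check the result.

```python
import jax
import jax.numpy as jnp


def hvp(f, x, v):
    """Hessian-vector product H(x) v via reverse-over-reverse: grad of <grad f(x), v>."""
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)


def hessian_entry(f, x, i, j):
    """Entry H[i, j]: one full HVP with basis vector e_j, then read component i."""
    e_j = jnp.zeros_like(x).at[j].set(1.0)
    return hvp(f, x, e_j)[i]


def full_hessian(f, x):
    """n HVPs (batched with vmap). Row j holds H e_j; equal to H since H is symmetric."""
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    return jax.vmap(lambda v: hvp(f, x, v))(basis)


# Check on a quadratic with 16 parameters, whose Hessian is exactly Q.
B = jnp.arange(256.0).reshape(16, 16) / 100.0
Q = B + B.T
f = lambda x: 0.5 * x @ Q @ x
x0 = jnp.linspace(-1.0, 1.0, 16)

H = full_hessian(f, x0)
h25 = hessian_entry(f, x0, 2, 5)
print(H.shape, h25)
```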

Weirdly, the Hessians have converged in ~1 s before (see Slack). This discrepancy is disturbing but consistent, and I cannot find a good explanation for it, given that at a high level the code behaves nearly identically.

Jordan-Dennis commented 2 years ago

Using the linear solver, this time drops to the ~100 ms range. This is an improvement, but the linear solver does not achieve the same accuracy as DP5.
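One hedged way to quantify an accuracy gap like this is to compare each solver against a reference solution. For a linear box model with constant coefficients, the matrix exponential gives one exactly: y(t) = expm(A t) y0. The sketch below uses it to measure the error of `jax.experimental.ode.odeint` (adaptive Dormand-Prince) at a loose and a tight tolerance. The matrix, horizon, and tolerances are illustrative assumptions, and this is not necessarily how ticktack's linear solver works.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)

# Hypothetical mass-conserving 3-box exchange matrix (1/yr); columns sum to zero.
A = jnp.array([
    [-0.125, 0.1, 0.0],
    [0.125, -0.6, 0.001],
    [0.0, 0.5, -0.001],
])
y0 = jnp.array([1.0, 0.0, 0.0])
ts = jnp.linspace(0.0, 50.0, 11)

# Reference: exact solution of dy/dt = A y for constant A.
reference = jnp.stack([expm(A * t) @ y0 for t in ts])


def solve_numerically(rtol):
    # Adaptive Dormand-Prince integration via jax.experimental.ode.odeint.
    return odeint(lambda y, t: A @ y, y0, ts, rtol=rtol, atol=rtol)


def scaled_max_error(approx, ref):
    # Max absolute error, scaled by the largest reference value.
    return jnp.max(jnp.abs(approx - ref)) / jnp.max(jnp.abs(ref))


err_loose = scaled_max_error(solve_numerically(1e-4), reference)
err_tight = scaled_max_error(solve_numerically(1e-10), reference)
print(err_loose, err_tight)
```

The same `scaled_max_error` comparison could be applied to any faster solver's output to see how much accuracy it gives up relative to the reference.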

Jordan-Dennis commented 2 years ago

So the bin_data function seems to be a major culprit in the autodiff slowdown. This is unusual, because the function itself evaluates quickly.

Jordan-Dennis commented 2 years ago

Having investigated the bin_data function, I found that the resolution of the time sampling was the primary culprit. After reducing it from oversample=1008 to oversample=48, the Hessians now run in: [timing screenshot not preserved]

Jordan-Dennis commented 2 years ago

On further investigation, the bin_data function was generating a very large array internally that was congesting the autodiff. The speed can be improved significantly by reducing the oversample parameter.
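A hypothetical sketch of the mechanism: the `bin_data` below is not ticktack's implementation, but it averages an oversampled signal into bins in a similar spirit, so its intermediate arrays have `n_bins * oversample` elements. Under reverse-mode autodiff those intermediates are kept for the backward pass, so their size directly drives gradient cost; going from `oversample=1008` to `oversample=48` shrinks them 21-fold. The signature and `linear_signal` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def bin_data(signal_fn, params, t0, bin_width, n_bins, oversample):
    """Hypothetical binning: average an oversampled signal into n_bins bins.

    n_bins and oversample must be Python ints (they set array shapes). The
    intermediates have n_bins * oversample elements, which reverse-mode autodiff
    stores for the backward pass, so gradient cost grows with oversample.
    """
    # Midpoints of `oversample` equal sub-intervals inside each bin.
    offsets = (jnp.arange(oversample) + 0.5) / oversample * bin_width
    starts = t0 + jnp.arange(n_bins) * bin_width
    t_fine = (starts[:, None] + offsets[None, :]).reshape(-1)
    y_fine = signal_fn(t_fine, params)  # shape (n_bins * oversample,)
    return y_fine.reshape(n_bins, oversample).mean(axis=1)


def linear_signal(t, params):
    slope, intercept = params
    return slope * t + intercept


params = jnp.array([2.0, 1.0])
binned = bin_data(linear_signal, params, 0.0, 1.0, 10, 48)
# For a linear signal, each bin average equals the value at the bin centre.
expected = 2.0 * (jnp.arange(10) + 0.5) + 1.0

grad_params = jax.grad(
    lambda p: jnp.sum(bin_data(linear_signal, p, 0.0, 1.0, 10, 48))
)(params)

size_ratio = 1008 // 48  # intermediate arrays shrink by this factor
print(binned, grad_params, size_ratio)
```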