jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
29.99k stars 2.75k forks

forward-mode autodiff for odeint #7401

Open qingyuanzhang3 opened 3 years ago

qingyuanzhang3 commented 3 years ago

Hi JAX team,

We want to calculate Hessians of a likelihood function involving an ODE integration so that we can do variational inference. We are running into an issue with custom_vjp that we don't understand how to fix; we have the impression that forward-mode differentiation is not implemented for odeint. Our package is called ticktack, which is distributed on PyPI. The dataset miyake12.csv is hosted on GitHub here. Do you have any advice? Can we implement this easily, or are there plans to do this for odeint?

A minimal example:

import numpy as np
from jax import jit, hessian

import ticktack
from ticktack import fitting

cbm = ticktack.load_presaved_model('Guttler14', production_rate_units='atoms/cm^2/s')
cf = fitting.CarbonFitter(cbm)
default_params = [775., 1. / 12, np.pi / 2., 81. / 12]
cf.load_data('miyake12.csv')

g = jit(hessian(cf.log_prob))
g(default_params)

We get the following output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_31044/2345364585.py in <module>
----> 1 g(default_params)

    [... skipping hidden 49 frame]

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in log_prob(self, params)
    107         # call log_like and log_prior, for later MCMC
    108         lp = self.log_prior(params)
--> 109         pos = self.log_like(params)
    110         return lp + pos
    111 

    [... skipping hidden 25 frame]

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in log_like(self, params)
     92     def log_like(self, params):
     93         # calls dc14 and compare to data, (can be gp or gaussian loglikelihood)
---> 94         d_14_c = self.dc14(params)
     95 
     96         chi2 = jnp.sum(((self.d14c_data[:-1] - d_14_c)/self.d14c_data_error[:-1])**2)

    [... skipping hidden 25 frame]

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in dc14(self, params)
     78     def dc14(self, params):
     79     # calls CBM on production_rate of params
---> 80         burn_in = self.run(self.burn_in_time, params, self.steady_state_y0)
     81         d_14_c = self.run_D_14_C_values(self.time_data, self.time_oversample, params, burn_in[-1, :])
     82         return d_14_c - 22.72

    [... skipping hidden 25 frame]

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in run(self, time_values, params, y0)
     65     @partial(jit, static_argnums=(0,))
     66     def run(self, time_values, params, y0):
---> 67         burn_in, _ = self.cbm.run(time_values, production=self.miyake_event, args=params, y0=y0)
     68         return burn_in
     69 

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/ticktack.py in run(self, time_values, production, y0, args, target_C_14, steady_state_production)
    323 
    324         if USE_JAX:
--> 325             states = odeint(derivative, y_initial, time_values)
    326         else:
    327             states = odeint(derivative, y_initial, time_values)

~/.local/lib/python3.8/site-packages/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
    171 
    172   converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
--> 173   return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
    174 
    175 @partial(jax.jit, static_argnums=(0, 1, 2, 3))

    [... skipping hidden 25 frame]

~/.local/lib/python3.8/site-packages/jax/experimental/ode.py in _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args)
    177   y0, unravel = ravel_pytree(y0)
    178   func = ravel_first_arg(func, unravel)
--> 179   out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    180   return jax.vmap(unravel)(out)
    181 

    [... skipping hidden 4 frame]

~/.local/lib/python3.8/site-packages/jax/experimental/ode.py in _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args)
    216 
    217 def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
--> 218   ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    219   return ys, (ys, ts, args)
    220 

    [... skipping hidden 5 frame]

~/.local/lib/python3.8/site-packages/jax/interpreters/ad.py in _raise_custom_vjp_error_on_jvp(*_, **__)
    676 
    677 def _raise_custom_vjp_error_on_jvp(*_, **__):
--> 678   raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
    679                   "function.")
    680 custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)

TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
hawkinsp commented 3 years ago

The issue is that odeint has a custom VJP, but functions with custom VJPs cannot be used in forward-mode autodiff.

We have ideas about a custom transpose feature that would allow this.

Here's a possible workaround. jax.hessian is defined using forward-over-reverse Jacobians:

def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
            holomorphic: bool = False) -> Callable:
  return jacfwd(jacrev(fun, argnums, holomorphic), argnums, holomorphic)

There's no fundamental reason we need to use forward-mode autodiff here, but it is usually the more efficient choice.

You could define your own hessian that uses jacrev in place of jacfwd, although reverse-over-reverse is usually worse for computational complexity.
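A minimal sketch of that workaround, applied to a toy odeint problem rather than the ticktack likelihood (the loss function below is an assumption for illustration only):

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def loss(theta):
    # Toy scalar loss around a linear ODE dy/dt = -theta[0] * y.
    def dydt(y, t, k):
        return -k * y
    ts = jnp.linspace(0.0, 1.0, 5)
    ys = odeint(dydt, jnp.ones(1), ts, theta[0])
    return jnp.sum(ys ** 2) + theta[1] ** 2

def hessian_rev(fun):
    # Reverse-over-reverse: avoids the jvp path that odeint's
    # custom_vjp rejects, unlike jax.hessian (jacfwd-over-jacrev).
    return jax.jacrev(jax.jacrev(fun))

theta = jnp.array([0.5, 2.0])
H = hessian_rev(loss)(theta)  # jax.hessian(loss)(theta) would raise TypeError
```

Since the loss is separable in theta[0] and theta[1], the off-diagonal entries come out near zero and H[1, 1] is close to 2.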

benjaminpope commented 3 years ago

Great point; we should just do two passes of jacrev for now. We don't have truly enormous matrices and only have to call it a couple of times for basic variational inference, so that might solve it for the moment.

But I wonder: there are definitely good reasons to be able to do forward-mode autodiff through an ODE solve. For instance, you might want to calculate the high-dimensional derivative of the whole output time series with respect to a single input parameter, in order to get a Lyapunov exponent, or to calculate large Hessians not just once for a Laplace-approximation inference but repeatedly inside certain optimizations. Is there any fundamental reason this functionality couldn't be built?
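In the meantime, that whole-trajectory sensitivity can still be computed in reverse mode with jacrev, at a cost that scales with the number of saved time points rather than the number of parameters; a toy sketch with assumed linear dynamics:

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def trajectory(k):
    # dy/dt = -k * y, a stand-in (assumption) for the real dynamics.
    ts = jnp.linspace(0.0, 1.0, 20)
    return odeint(lambda y, t, k: -k * y, jnp.ones(1), ts, k)

# One reverse pass per saved time point; jacfwd would need only a
# single forward pass here, which is why forward mode is the natural
# fit for few-parameters, many-outputs Jacobians.
dtraj_dk = jax.jacrev(trajectory)(0.5)
```

For the exact solution exp(-k t), the sensitivity at t = 1 is -exp(-0.5), which the solver reproduces closely.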

froystig commented 3 years ago

there are definitely good reasons to be able to do forwards-mode autodiff for an ode

Indeed, we agree.

Are there fundamental reasons this wouldn't be possible to build this functionality?

Nothing fundamental – just a limitation of custom_vjp's implementation today. As @hawkinsp mentioned, we hope to upgrade this once we've introduced machinery for custom transposition (and possibly custom partial evaluation). In the meantime, there's also the possible workaround of creating an explicit odeint primitive and defining JVP, transposition, and partial evaluation rules for it.
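For reference, the restriction is specific to custom_vjp: forward mode already works through a solver written in plain JAX ops. A toy fixed-step RK4 integrator (an illustrative sketch, not the proposed odeint primitive with custom JVP/transpose rules):

```python
import jax
import jax.numpy as jnp

def rk4_integrate(f, y0, ts, *args):
    # Classic fourth-order Runge-Kutta, stepping between the given ts.
    def step(y, t_pair):
        t0, t1 = t_pair
        h = t1 - t0
        k1 = f(y, t0, *args)
        k2 = f(y + 0.5 * h * k1, t0 + 0.5 * h, *args)
        k3 = f(y + 0.5 * h * k2, t0 + 0.5 * h, *args)
        k4 = f(y + h * k3, t1, *args)
        y_next = y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
        return y_next, y_next
    _, ys = jax.lax.scan(step, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([jnp.array([y0]), ys])

def f(y, t, k):
    return -k * y

ts = jnp.linspace(0.0, 1.0, 11)
# jvp straight through the integrator: trajectory and its sensitivity
# to k in a single forward pass, with no custom_vjp in the way.
ys, dys_dk = jax.jvp(lambda k: rk4_integrate(f, 1.0, ts, k), (0.5,), (1.0,))
```

With exact solution exp(-k t), the endpoint is about exp(-0.5) and its k-sensitivity about -exp(-0.5).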

benjaminpope commented 3 years ago

Cool, awesome. I think for now we'll give it a go with jacrev, but we'd love to know if odeint gets an upgrade. I'm really excited by the stuff you're doing.

dumanah commented 2 years ago

For those who are interested in this problem, I would suggest having a look at diffrax. Not only can forward-mode differentiation be applied to its ODE solvers, it also offers many different solvers, such as Euler, Heun, Dopri5, and other Runge-Kutta methods.