patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/

Zero gradient when using jnp.piecewise inside an ODE #363


dflocher commented 7 months ago

Hi, when applying jax.grad to a function that uses diffrax to integrate a piecewise-defined ODE, I observe that one partial derivative is unexpectedly zero. The ODE solver returns correct function values; only the gradient is wrong. I'm wondering whether this is a bug, or whether I'm doing something wrong. Thanks in advance! David

Example:

Consider the piecewise defined ODE

$$\frac{dy}{dt} = -k(t) \cdot y, \qquad y(0) = y_0, \qquad \mathrm{with} \ k(t) = \begin{cases} k_0, & t \leq T \\ 0, & t > T, \end{cases}$$

to which the solution reads

$$y(t) = y_0 \cdot \begin{cases} e^{-k_0 t}, & t \leq T \\ e^{-k_0 T}, & t > T \end{cases}$$

I'm interested in the partial derivatives w.r.t. $T$ and $k_0$. In the code example below, I compare the gradient obtained from integrating the ODE using diffrax to the analytical solution and to a finite difference calculation.
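For reference, in the regime $t > T$ (which applies to the numbers below), the analytical partial derivatives are

$$\frac{\partial y}{\partial T} = -k_0\, y_0\, e^{-k_0 T}, \qquad \frac{\partial y}{\partial k_0} = -T\, y_0\, e^{-k_0 T},$$

which for $y_0 = 100$, $k_0 = 5.5$, $T = 1$ evaluate to approximately $-2.2477$ and $-0.4087$.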

import jax
import jax.numpy as jnp
import diffrax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", True)

def calc_y_analytically(params, t, y0):
    T, k0 = params
    return y0 * jnp.piecewise(t, [t <= T, t > T], [lambda x: jnp.exp(-k0*x), lambda x: jnp.exp(-k0*T)])

def calc_y_ode(params, t, y0):

    def ode(t, y, args):
        T, k0 = args
        k_of_t = jnp.piecewise(t, [t <= T, t > T], [k0, 0.0])
        d_y = -1 * k_of_t * y
        return d_y

    term = diffrax.ODETerm(ode)
    solver = diffrax.Tsit5()
    sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=t, dt0=0.00001, y0=y0, args=params, max_steps=200000)
    return sol.ys[0]

if __name__ == '__main__':

    params = jnp.array([1.0, 5.5])  # (T, k0)
    t = 1.2
    y0 = 100.0

    # calculate y(t) and the gradient w.r.t. T and k0 analytically
    y_ana, grads_ana = jax.value_and_grad(calc_y_analytically)(params, t, y0)

    # propagate y(0)=y0 until t by solving the ODE and calculate the gradient w.r.t. T and k0
    y_diff, grads_diff = jax.value_and_grad(calc_y_ode)(params, t, y0)

    # perform finite differences method on 0th parameter for verification
    eps = 1e-4
    param0 = params[0]
    params_plus = params.at[0].set(param0 + eps)
    params_minus = params.at[0].set(param0 - eps)

    y_ana_plus = calc_y_analytically(params_plus, t, y0)
    y_ana_minus = calc_y_analytically(params_minus, t, y0)
    part_deriv_ana = (y_ana_plus - y_ana_minus) / (2 * eps)

    y_diff_plus = calc_y_ode(params_plus, t, y0)
    y_diff_minus = calc_y_ode(params_minus, t, y0)
    part_deriv_diff = (y_diff_plus - y_diff_minus) / (2 * eps)

    print('\ny(t):')
    print('Analytical: {y:.6f}'.format(y=y_ana))
    print('Diffrax:    {y:.6f}'.format(y=y_diff))

    print('\nGradient:')
    print('Analytical: ' + str(grads_ana))
    print('Diffrax:    ' + str(grads_diff))

    print('\nPartial derivative w.r.t. parameter 0 via finite difference:')
    print('Analytical: {p:.6f}'.format(p=part_deriv_ana))
    print('Diffrax:    {p:.6f}'.format(p=part_deriv_diff))

prints out the following:

y(t):
Analytical: 0.408677
Diffrax:    0.408675

Gradient:
Analytical: [-2.24772429 -0.40867714]
Diffrax:    [ 0.         -0.40867537]

Partial derivative w.r.t. parameter 0 via finite difference:
Analytical: -2.247724
Diffrax:    -2.247712

I'm using Python 3.11.7, jax 0.4.23, jaxlib 0.4.23.dev20231223, diffrax 0.5.0, macOS 14.2.1, x86_64, running on CPU.

patrick-kidger commented 7 months ago

Ah, this is a known (although pretty obscure) limitation of using autodifferentiable differential equation solvers with discontinuous vector fields.

TL;DR: the solution

First of all, the solution: explicitly declare the jump time in the stepsize controller, typically by doing stepsize_controller=PIDController(..., jump_ts=[T]). (I think you could also use StepTo here if you wanted a fixed step size rather than an adaptive one.)
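For instance, applied to calc_y_ode from the example above, this would look something like the following sketch (the rtol/atol values are illustrative choices of mine, not tuned for this problem):

import jax.numpy as jnp
import diffrax

def calc_y_ode_fixed(params, t, y0):

    def ode(t, y, args):
        T, k0 = args
        k_of_t = jnp.piecewise(t, [t <= T, t > T], [k0, 0.0])
        return -k_of_t * y

    T = params[0]
    term = diffrax.ODETerm(ode)
    solver = diffrax.Tsit5()
    # Declaring the jump makes the solver place a step boundary exactly at T,
    # so T enters the computation differentiably.
    controller = diffrax.PIDController(rtol=1e-8, atol=1e-8, jump_ts=jnp.array([T]))
    sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=t, dt0=0.0001, y0=y0,
                              args=params, stepsize_controller=controller,
                              max_steps=200000)
    return sol.ys[0]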

What went wrong?

As for what's going on, we can explain this in a few different ways; the shortest is this. The only place T enters the computation is the branch condition t <= T inside jnp.piecewise. A numerical solver evaluates the vector field at finitely many points, and at (almost) every such point the vector field is locally constant as a function of T, so each evaluation has derivative exactly zero with respect to T. Autodifferentiating the discretised solve therefore propagates a zero through every step: the computation graph contains no differentiable path from T to the output, even though the true solution does depend on T through the location of the jump.
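As a quick standalone check (this snippet is mine, not part of the original report), differentiating a single vector-field evaluation already gives a zero T-component:

import jax
import jax.numpy as jnp

def ode(t, y, args):
    # Same vector field as in the example above.
    T, k0 = args
    k_of_t = jnp.piecewise(t, [t <= T, t > T], [k0, 0.0])
    return -k_of_t * y

# Gradient of one evaluation at t=0.5, y=1.0 with respect to (T, k0):
print(jax.grad(lambda args: ode(0.5, 1.0, args))(jnp.array([1.0, 5.5])))
# -> [ 0. -1.]  (zero sensitivity to T; T only appears in the branch condition)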

Why does the solution above work?

So how do we fix this? We've seen that writing diffeqsolve(..., t0=0, t1=t) doesn't work.

Our first insight into fixing this is to observe that if we had split this into diffeqsolve(..., t0=0, t1=T) and diffeqsolve(..., t0=jnp.minimum(t, T), t1=t), then we would have a well-behaved ODE on each piece.

In fact, go ahead and test this, and you'll get the expected gradient! In terms of the computation graph described above, the reason is that T now has a differentiable dependence: it appears as an integration endpoint of each solve, and the derivative of the solution with respect to those endpoints is well-defined and nonzero.
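Concretely, a sketch of that two-solve version (the helper name is mine, and it assumes t >= T, as in the example above):

import jax.numpy as jnp
import diffrax

def calc_y_ode_split(params, t, y0):
    T, k0 = params

    def ode(t, y, args):
        T, k0 = args
        k_of_t = jnp.piecewise(t, [t <= T, t > T], [k0, 0.0])
        return -k_of_t * y

    term = diffrax.ODETerm(ode)
    solver = diffrax.Tsit5()
    # First piece: integrate up to the jump at T.
    sol1 = diffrax.diffeqsolve(term, solver, t0=0.0, t1=T, dt0=0.00001,
                               y0=y0, args=params, max_steps=200000)
    # Second piece: restart at T and integrate to t. T now appears as an
    # integration endpoint, so the gradient with respect to T is nonzero.
    sol2 = diffrax.diffeqsolve(term, solver, t0=T, t1=t, dt0=0.00001,
                               y0=sol1.ys[0], args=params, max_steps=200000)
    return sol2.ys[0]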

So using PIDController(..., jump_ts=[T]) basically does exactly this: it means we make a numerical step that lands exactly on T, so T again enters the computation differentiably. (And it will also be faster to compile than the double-diffeqsolve approach, which doubles the number of operations JAX has to compile.)

Can we do better?

This is an unfortunate user footgun! But I don't know of an automatic solution to this; as far as I know, it may be an open question in the theory of autodifferentiation.

Spitballing, I imagine this could maybe be solved by having the ODE solver try to detect when it thinks a jump has occurred, solve a root-finding problem to locate the jump if so, and then use that location in its step size control.

I think investigating this might be an interesting research question in autodifferentiation, for those curious enough to try :)

dflocher commented 7 months ago

Thank you, Patrick, for your detailed and instructive answer! Your proposed solution works fine. Best, David