patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

How to get derivative of final time with respect to a variable #451

Closed. nsteffen closed this issue 1 week ago.

nsteffen commented 1 week ago

Hello!

I am working on a problem where I want to propagate a trajectory using diffrax, end that trajectory with a discrete terminating event, and then get the derivative of the final time with respect to a variable. In every problem I have tried this on, the gradient of the final time comes out as zero. Below is an example where I model the simplified dynamics of a cannonball and try to evaluate the gradient of the final time with respect to the drag coefficient.

import jax
import jax.numpy as jnp
import diffrax as dfx
jax.config.update('jax_enable_x64', True)

def dynamics(t, y, CD):
    g = 9.80665 # m/s^2
    rho = 1.225 # kg/m^3
    S = 0.005 # m^2
    m = 1. # kg

    v = y[2]
    gam = y[3]

    D = (0.5*rho*v**2)*S*CD

    sin_gam = jnp.sin(gam)
    cos_gam = jnp.cos(gam)
    ydot = jnp.array([v*cos_gam,        # rdot
                      v*sin_gam,        # hdot
                      -D/m - g*sin_gam, # vdot
                      -(g*cos_gam)/v])  # gamdot

    return ydot

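def event(state, **kwargs):
    # Discrete terminating event: checked once per step, so the solve stops
    # at the end of the first step whose final altitude is negative.
    h = state.y[1]
    return h < 0.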

def obj(CD):
    t0 = 0.
    t1 = 20.
    dt0 = 0.1
    y0 = jnp.array([0.,                 # r0
                    100.,               # h0
                    5.,                 # v0
                    jnp.deg2rad(45.)])  # gam0

    solver = dfx.Tsit5()
    stepsize_controller = dfx.ConstantStepSize()
    saveat = dfx.SaveAt(ts=jnp.linspace(t0, t1, 100))

    sol = dfx.diffeqsolve(dfx.ODETerm(dynamics),
                          solver,
                          t0,
                          t1,
                          dt0,
                          y0,
                          args=CD,
                          stepsize_controller=stepsize_controller,
                          max_steps=None,
                          saveat=saveat,
                          discrete_terminating_event=event,
                          adjoint=dfx.RecursiveCheckpointAdjoint(checkpoints=100)
                          )

    # Keep only the save times reached before the event fired (the rest are inf).
    ts = sol.ts[jnp.where(jnp.isfinite(sol.ts))]
    return ts[-1]

if __name__ == '__main__':
    CD = 0.5
    print(jax.grad(obj)(CD))

Is there a correct/better way to do this? Thanks in advance!

patrick-kidger commented 1 week ago

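So this is because you have a "discrete" terminating event: the solve halts at the end of the step in which the event was triggered. The final time is therefore always a step endpoint, and a small change in CD does not change which step that is. As that's a discrete thing, there is (correctly) no gradient.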
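The effect of a discrete event can be reproduced with a toy sketch in plain JAX. This is not diffrax's implementation. It assumes a hypothetical drag-free drop from height `h0`, stepped with a fixed `dt`, and stops at the end of the first step whose endpoint has `h < 0`. The stopping time is always a multiple of `dt`, so it is piecewise constant in `h0` and has zero gradient. The exact crossing time `sqrt(2*h0/g)` has a nonzero derivative. The names `height`, `stepped_final_time` and `exact_final_time` are illustrative only.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

G = 9.80665  # m/s^2

def height(t, h0):
    # Hypothetical toy model: drag-free drop from rest at height h0.
    return h0 - 0.5 * G * t**2

def stepped_final_time(h0, dt=0.1, max_steps=200):
    # Mimics a discrete terminating event: take fixed steps of size dt and
    # stop at the end of the first step whose endpoint has h < 0.
    def body(i, carry):
        t, done = carry
        t_next = t + dt
        t = jnp.where(done, t, t_next)            # freeze t once the event has fired
        done = done | (height(t_next, h0) < 0.0)  # boolean test: no gradient flows through it
        return t, done

    t, _ = jax.lax.fori_loop(0, max_steps, body, (jnp.zeros(()), jnp.array(False)))
    return t

def exact_final_time(h0):
    # Exact positive root of height(t, h0) = 0.
    return jnp.sqrt(2.0 * h0 / G)

h0 = 100.0
print(stepped_final_time(h0), jax.grad(stepped_final_time)(h0))  # t ≈ 4.6, gradient 0.0
print(exact_final_time(h0), jax.grad(exact_final_time)(h0))      # t ≈ 4.516, gradient ≈ 0.0226
```

The stepped time only changes when `h0` moves the crossing into a different step. Between those jumps its derivative is zero, and that is what `jax.grad` reports.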

You may like to try #387, which is a more featureful approach to events. In particular, it adds the ability to (a) have an event return a real number, with the solve terminating where that number crosses zero, and (b) have that exact location determined using a root find.
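The implicit function theorem shows why a root-found event time can carry a gradient. Write $h(t; C_D)$ for the altitude along the trajectory. The terminal time $t^*$ is defined by $h(t^*(C_D); C_D) = 0$. Differentiating both sides with respect to $C_D$, and using $\dot h = v \sin\gamma$ from the `hdot` row of `dynamics`, gives the relation below. This is a general identity. It is not a description of how #387 computes the derivative internally.

```latex
\frac{\partial h}{\partial t}\bigg|_{t^*} \frac{\mathrm{d}t^*}{\mathrm{d}C_D}
  + \frac{\partial h}{\partial C_D}\bigg|_{t^*} = 0
\quad\Longrightarrow\quad
\frac{\mathrm{d}t^*}{\mathrm{d}C_D}
  = -\left.\frac{\partial h / \partial C_D}{v \sin\gamma}\right|_{t = t^*}
```

This relation needs $v \sin\gamma \neq 0$ at the crossing, which holds for a descending cannonball. With a discrete event, the final time is $t_0 + n\,\Delta t$ for an integer $n$. That value is locally constant in $C_D$, so its derivative is zero.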

Tagging @cholberg for visibility, but I believe that should give you gradients.

nsteffen commented 1 week ago

Just got to testing that out, and it looks like your suggestion worked! Thanks for the help; I'm excited for this feature to be part of the main code!