google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Let's split the VJP for odeint into JVP and transpose #1927

Open duvenaud opened 4 years ago

duvenaud commented 4 years ago

One of the main motivations for adding defvjp to Jax was to add the adjoint method for taking gradients through ODE solutions. However, this doesn't fit well with Jax's elegant approach of defining jvp for all nonlinear operations, transpose for all linear ones, and deriving vjps automatically from those. It also means we don't have forward-mode autodiff for odeint in Jax.

I initially thought it might be impossible, but now I think we understand everything necessary to break the vjp for odeint into a jvp and a transpose. @jessebett and I coded up the jvp for forward mode, which was pleasantly simple: https://github.com/duvenaud/jaxde/blob/master/jaxde/ode_jvp.py The basic idea is that it just integrates forward the jvp of the original dynamics.
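The forward-mode idea can be sketched roughly as follows. This assumes a fixed-step RK4 integrator standing in for an adaptive solver such as `jax.experimental.ode.odeint`: the tangent is carried alongside the state, and both are integrated together using a single `jax.jvp` call on the dynamics. `rk4`, `odeint_jvp` and `dynamics` are hypothetical names, not the jaxde code, and only the tangent with respect to the initial state is handled.

```python
import jax
import jax.numpy as jnp

def rk4(f, y0, t0, t1, n_steps=100):
    """Hypothetical fixed-step RK4 for dy/dt = f(y, t); works on any pytree state."""
    h = (t1 - t0) / n_steps
    step = lambda y, k, a: jax.tree_util.tree_map(lambda u, v: u + a * v, y, k)
    y, t = y0, t0
    for _ in range(n_steps):
        k1 = f(y, t)
        k2 = f(step(y, k1, h / 2), t + h / 2)
        k3 = f(step(y, k2, h / 2), t + h / 2)
        k4 = f(step(y, k3, h), t + h)
        incr = jax.tree_util.tree_map(
            lambda a, b, c, d: (a + 2 * b + 2 * c + d) / 6, k1, k2, k3, k4)
        y = step(y, incr, h)
        t = t + h
    return y

def odeint_jvp(f, y0, y0_dot, t0, t1):
    """Push the tangent y0_dot forward by integrating (y, y_dot) jointly."""
    def augmented(state, t):
        y, y_dot = state
        # One jvp call returns both f(y, t) and J_f(y, t) @ y_dot.
        return jax.jvp(lambda u: f(u, t), (y,), (y_dot,))
    return rk4(augmented, (y0, y0_dot), t0, t1)

def dynamics(y, t):
    # Hypothetical nonlinear, time-dependent dynamics for the demo.
    return jnp.sin(y) * jnp.cos(t) - 0.1 * y

y0 = jnp.array([0.5, -1.0, 2.0])
y0_dot = jnp.array([1.0, 0.0, 0.0])
y1, y1_dot = odeint_jvp(dynamics, y0, y0_dot, 0.0, 2.0)

# Reference: forward-mode autodiff straight through the same solver.
ref_y1, ref_y1_dot = jax.jvp(lambda y: rk4(dynamics, y, 0.0, 2.0), (y0,), (y0_dot,))
print(jnp.allclose(y1_dot, ref_y1_dot, atol=1e-4))
```

With fixed steps, integrating the augmented system performs the same arithmetic as differentiating each RK4 step in forward mode, so the two tangents agree up to rounding.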

As for the transpose, I think I've worked out the math. The basic idea is that although odeint isn't a linear function in general, the ODE that we'll need to transpose (the jvp of the original dynamics) is a first-order linear homogeneous ODE, and its solutions are linear in the initial conditions. Based on some identities in that document, it's not too hard to show that

transpose(lin_homo_odeint)(z0, f, t0, t1) = lin_homo_odeint(z1, -transpose(f), t1, t0)

That is to say, the transpose of the solution to a linear homogeneous ODE wrt its initial state is simply another linear homogeneous ODE with negated transposed dynamics, run backwards in time. As a sanity check, applying this transpose to jvp_odeint gives us the same algorithm as vjp_odeint.
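The following is a quick numerical check of this identity in the simplest case: constant linear dynamics f(z) = A z, so transpose(f) is A.T. `lin_homo_odeint` here is a hypothetical fixed-step RK4 helper that takes the matrix A rather than a function. It is not the proposed JAX primitive. `jax.linear_transpose` derives the transpose directly from the forward solve for comparison.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

def lin_homo_odeint(z0, A, t0, t1, n_steps=50):
    # Hypothetical helper: fixed-step RK4 for dz/dt = A @ z from t0 to t1.
    # t1 < t0 is allowed, in which case it integrates backwards in time.
    h = (t1 - t0) / n_steps
    z = z0
    for _ in range(n_steps):
        k1 = A @ z
        k2 = A @ (z + 0.5 * h * k1)
        k3 = A @ (z + 0.5 * h * k2)
        k4 = A @ (z + h * k3)
        z = z + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return z

A = jnp.array([[-0.5, 1.0, 0.0],
               [-1.0, -0.2, 0.3],
               [0.0, 0.4, -0.1]])      # arbitrary non-symmetric dynamics
t0, t1 = 0.0, 1.5
z0 = jnp.ones(3)                       # only supplies shape/dtype to linear_transpose
z1_bar = jnp.array([1.0, -2.0, 0.5])   # cotangent on z(t1)

# 1) Let JAX transpose the forward solve, which is linear in z0.
transpose_fn = jax.linear_transpose(lambda z: lin_homo_odeint(z, A, t0, t1), z0)
(z0_bar_auto,) = transpose_fn(z1_bar)

# 2) The identity: negated, transposed dynamics run from t1 back to t0.
z0_bar_ode = lin_homo_odeint(z1_bar, -A.T, t1, t0)

# 3) Closed form for constant A: exp(A (t1 - t0))^T @ z1_bar.
z0_bar_exact = expm(A * (t1 - t0)).T @ z1_bar

print(jnp.allclose(z0_bar_auto, z0_bar_ode, atol=1e-4))
print(jnp.allclose(z0_bar_ode, z0_bar_exact, atol=1e-4))
```

With fixed steps, 1) and 2) agree up to rounding. An RK4 step for a linear system multiplies by a polynomial in hA, and evaluating that polynomial at (-h)(-A.T) = hA.T yields its transpose. An adaptive solver would generally only agree to within its error tolerance.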

So, I think we can now refactor the Jax implementation of the adjoint method! Here are the main steps needed to get the basic version working:

At this point we should be able to automatically get the vjp wrt the initial state of odeint. There are a few more steps to get the full-featured version:

If this project ends up being too hairy, we could warm up by trying to do the same thing for fixed-point solvers. Ultimately, we might be able to remove defvjp entirely, armed with the knowledge that even these relatively hairy vjps could be broken up and handled the standard way.

shoyer commented 4 years ago

I agree, this seems like a good idea. Thanks for writing out this plan!

transpose(lin_homo_odeint(z0, f, t0, t1)) = lin_homo_odeint(z1, -transpose(f), t1, t0)

I think this should be: transpose(lin_homo_odeint)(z0, f, t0, t1) = lin_homo_odeint(z1, -transpose(f), t1, t0)?

  • So something like lin_homo_odeint(nonlin_f, nonlin_z0, lin_f, lin_z0, t0, t1) that outputs both ODEs' final states. Internally, it just needs to concatenate the two ODEs together, call odeint, and split them back apart. This function is only linear in its second output wrt lin_z0.

Just to be clear, the signatures here would look like nonlin_f(nonlin_z) and lin_f(nonlin_z, lin_z), with lin_f (typically wrapped call to jvp) guaranteed to be linear with respect to lin_z?

If this project ends up being too hairy, we could warm up by trying to do the same thing for fixed-point solvers.

JVP/transpose rules for fixed-point solves would definitely be a worthwhile addition, though I suspect you could do a pretty reasonable version of this in terms of the existing custom_root and custom_linear_solve.
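As a rough illustration of the `lax.custom_root` route, here is a toy scalar root solve (not an ODE). Gradients with respect to the closed-over parameter come from the implicit function theorem, not from unrolling the solver. The Newton solver, the `sqrt_via_root` name and the problem itself are hypothetical. Only `lax.custom_root(f, initial_guess, solve, tangent_solve)` is JAX API.

```python
import jax
from jax import lax

def newton_solve(f, x0):
    # Plain Newton iterations; custom_root does not differentiate through this loop.
    x = x0
    for _ in range(20):
        x = x - f(x) / jax.grad(f)(x)
    return x

def scalar_tangent_solve(g, y):
    # For a scalar linear map g, dividing by g(1.0) solves g(x) = y.
    return y / g(1.0)

def sqrt_via_root(a):
    # The root of f(x) = x**2 - a is sqrt(a); the derivative w.r.t. the closed-over
    # `a` is obtained implicitly from f(x*, a) = 0.
    f = lambda x: x ** 2 - a
    return lax.custom_root(f, 1.0, newton_solve, scalar_tangent_solve)

print(sqrt_via_root(4.0))            # ~2.0
print(jax.grad(sqrt_via_root)(4.0))  # ~0.25, i.e. 1 / (2 * sqrt(4))
```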

custom_linear_solve is probably the best precedent for what writing this sort of operation in JAX would look like. It relies on the identity transpose(linear_solve)(matvec, b) == linear_solve(transpose(matvec), b).
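The following sketch shows that identity in action. It assumes a small dense non-symmetric system, with `jnp.linalg.solve` standing in for an iterative solver that would only call `matvec`. In reverse mode, JAX transposes the solve and calls the supplied `transpose_solve`, so the cotangent returned for b is A^{-T} applied to the output cotangent.

```python
import jax
import jax.numpy as jnp
from jax import lax

A = jnp.array([[4.0, 1.0],
               [2.0, 3.0]])  # non-symmetric, so the transpose actually matters

def matvec(x):
    return A @ x

def solve(matvec_fn, b):
    # Dense stand-in for an iterative solver that would only use matvec_fn.
    return jnp.linalg.solve(A, b)

def transpose_solve(vecmat_fn, b):
    # Solves A^T x = b; used when JAX transposes the solve (e.g. in reverse mode).
    return jnp.linalg.solve(A.T, b)

def linear_solve(b):
    return lax.custom_linear_solve(matvec, b, solve, transpose_solve)

b = jnp.array([1.0, 2.0])
ct = jnp.array([0.5, -1.0])
x, vjp_fn = jax.vjp(linear_solve, b)
(b_bar,) = vjp_fn(ct)

# transpose(linear_solve)(matvec, ct) == linear_solve(transpose(matvec), ct)
print(jnp.allclose(b_bar, jnp.linalg.solve(A.T, ct)))
```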

shoyer commented 4 years ago
  • [ ] Work out the math for transposing a linear ODE wrt parameters and time

I think we can get away without this. In general, JAX only needs transpose rules for linear arguments, but linear ODEs are only linear with respect to the initial condition, not the parameters or time. (This is a good example of why using transpose rules is convenient, because often you don't need to write them at all!)

duvenaud commented 4 years ago

Thanks for the detailed feedback!

I think this should be: transpose(lin_homo_odeint)(z0, f, t0, t1) = lin_homo_odeint(z1, -transpose(f), t1, t0)?

You're correct, fixed.

the signatures here would look like nonlin_f(nonlin_z) and lin_f(nonlin_z, lin_z), with lin_f (typically wrapped call to jvp) guaranteed to be linear with respect to lin_z?

Yes, that's what I had in mind. Although on second thought, we should probably combine these into one function for efficiency: d_nonlin, d_lin = nonlin_and_lin_f(nonlin_z, lin_z). This would allow us to match the efficient call to both the original dynamics and its jvp here: https://github.com/duvenaud/jaxde/blob/master/jaxde/ode_jvp.py#L31

linear ODEs are only linear with respect to the initial condition, not the parameters or time.

That's true in general, but in the case we're interested in, the ODE is also linear in parameters and time, since its solution is the jvp of odeint, which is linear by definition. I'm pretty sure the extra bookkeeping will be simple, since the adjoint dynamics for the initial state, parameters, and time can still be computed with a single call to vjp.

mattjj commented 4 years ago

Thanks so much for proposing this, @duvenaud! I don't have anything substantive to add yet; I just wanted to share my excitement :D I can't wait to learn more about ODEs by working on this.

shoyer commented 4 years ago

linear ODEs are only linear with respect to the initial condition, not the parameters or time.

That's true in general, but in the case we're interested in, the ODE is also linear in parameters and time, since its solution is the jvp of odeint, which is linear by definition. I'm pretty sure the extra bookkeeping will be simple, since the adjoint dynamics for the initial state, parameters, and time can still be computed with a single call to vjp.

If you look at your JVP rule, tan_t0 and tan_t1 make a separate contribution to the sensitivity -- they don't appear in the augmented ODE but rather only on this line:

    # Sensitivities of y(t1) wrt t0 and t1
    jvp_t_total = (tan_t1 - tan_t0) * func(yt, t1, fargs)

I think this implies we don't need ODE transpose rules for these arguments, because we only need transpose rules for function arguments where the tangent appears in the JVP calculation.

Actually, this brings up a question with your derivation. If I understand correctly, the JVP rule for odeint should be a generalization of the Leibniz integral rule.
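For reference, the standard form of the Leibniz integral rule is given below. Its two boundary terms are the ones that the tan_t1 and tan_t0 terms in the expression suggested below mirror.

```latex
\frac{d}{dx}\int_{a(x)}^{b(x)} f(x,t)\,dt
  = f\big(x, b(x)\big)\,\frac{db}{dx}
  - f\big(x, a(x)\big)\,\frac{da}{dx}
  + \int_{a(x)}^{b(x)} \frac{\partial f}{\partial x}(x,t)\,dt
```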

This suggests to me that the right sensitivity due to the limits should rather be something like:

    jvp_t_total = tan_t1 * func(yt, t1, fargs) - tan_t0 * func(y0, t0, fargs)

What do you think? Am I missing something here?

duvenaud commented 4 years ago

tan_t0 does appear in the dynamics, on this line, and the jvp depends on both of them, so I think they'll both need transpose rules. But as I said, I think the rules will be simple, and for now I think we should just ignore them.

I'm pretty sure all the jvp code is correct, both because it passes numerical tests, and because it matches the vjp conceptually. One bit of intuition that helped me was realizing that moving t0 forward moves the entire integral forward.
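Here is a toy numerical check of this intuition. It assumes autonomous dynamics (f depends on y only) and a hypothetical fixed-step RK4 `solve` whose step size depends on t0 and t1, so `jax.grad` can differentiate with respect to the start time. For autonomous dynamics, shifting t0 shifts the whole trajectory, so dy(t1)/dt0 should equal -f(y(t1)).

```python
import jax

def f(y):
    # Hypothetical autonomous toy dynamics dy/dt = -y, so y(t1) = y0 * exp(-(t1 - t0)).
    return -y

def solve(y0, t0, t1, n_steps=100):
    # Fixed-step RK4; the step size depends on t0 and t1, so autodiff sees both.
    h = (t1 - t0) / n_steps
    y = y0
    for _ in range(n_steps):
        k1 = f(y)
        k2 = f(y + 0.5 * h * k1)
        k3 = f(y + 0.5 * h * k2)
        k4 = f(y + h * k3)
        y = y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return y

y0, t0, t1 = 1.0, 0.0, 2.0
y1 = solve(y0, t0, t1)
dy1_dt0 = jax.grad(solve, argnums=1)(y0, t0, t1)

print(dy1_dt0)  # autodiff through the solver
print(-f(y1))   # trajectory-shift form: -f(y(t1))
print(-f(y0))   # boundary-term form: -f(y(t0))
```

In this autonomous example, the autodiff result matches -f(y(t1)) rather than -f(y(t0)). With time-dependent dynamics, tan_t0 also enters through the dynamics, so this check does not settle the general case.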