patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Approaching multi trajectory adaptive stepping #481

Open lockwo opened 1 month ago

lockwo commented 1 month ago

In line with some of the weak solvers we are working on getting into diffrax via a PR, we are also implementing a variety of adaptive methods. One of these schemes relies on estimating errors by looking at multiple trajectories (https://onlinelibrary.wiley.com/doi/abs/10.1002/pamm.200410005): the error is estimated from some quantity computed over simultaneously simulated trajectories.
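As a generic illustration only (this is not the specific scheme from the linked paper), a weak quantity such as E[g(X_t)] can be estimated as the sample mean over M trajectories, with the standard error of that mean serving as a statistical error indicator. The function name `mean_and_standard_error` is hypothetical.

```python
import jax.numpy as jnp


def mean_and_standard_error(samples):
    """Sample mean and its standard error across trajectories.

    `samples` has shape (M, ...): one entry per trajectory.
    """
    num_trajectories = samples.shape[0]
    mean = jnp.mean(samples, axis=0)
    # Unbiased sample variance (ddof=1) across the trajectory axis.
    var = jnp.var(samples, axis=0, ddof=1)
    # Standard error of the Monte Carlo mean estimate.
    return mean, jnp.sqrt(var / num_trajectories)


# Example: g(X_t) evaluated on four trajectories.
mean, stderr = mean_and_standard_error(jnp.array([1.0, 2.0, 3.0, 4.0]))
```

The point is that the indicator needs every trajectory at once, which a plain per-trajectory `jax.vmap` of a solve does not provide on its own.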

I wanted to think about how best to integrate this into diffrax philosophically: our code works as a wrapper on top of diffrax, but it isn't trivial to implement in the framework itself. Since integrate.py conceptually works over a single trajectory, the usual way to get multiple trajectories is just to vmap. I was thinking of playing around inside integrate.py and making an unvmapped version of the computations we needed, but defining custom unvmaps seemed very hacky. Have you thought about this more, and do you have opinions on adaptive schemes that rely on multiple trajectories?

patrick-kidger commented 1 month ago

Hmm. You've got a couple of options, I think. The first would be to bundle multiple trajectories together into one gigantic vector field (with each piece independent of the others). Diffrax then just sees a single integration like normal. This means a batch of solves would get fairly gigantic (each batch element has its own 'inner batch' of trajectories), but it would preserve batch independence.
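A minimal sketch of the bundling option, using only JAX. It assumes a toy linear-decay vector field `single_vector_field` (hypothetical) and lifts it over an inner trajectory axis with `jax.vmap`, so the bundled function keeps the `(t, y, args)` signature that `diffrax.ODETerm` expects. `bundle_error_indicator` is likewise a hypothetical example of a statistic that can only be computed with all trajectories in hand.

```python
import jax
import jax.numpy as jnp


def single_vector_field(t, y, args):
    # Toy per-trajectory dynamics: dy/dt = -k * y.
    k = args
    return -k * y


def bundled_vector_field(t, y_bundle, args):
    # y_bundle has shape (num_trajectories, state_dim). Each row is evolved
    # independently; the solver only ever sees one large state.
    return jax.vmap(single_vector_field, in_axes=(None, 0, None))(t, y_bundle, args)


def bundle_error_indicator(y_bundle):
    # Because the whole inner batch lives in one state, cross-trajectory
    # statistics (here: standard error of the mean state) are plain array ops.
    n = y_bundle.shape[0]
    return jnp.sqrt(jnp.var(y_bundle, axis=0, ddof=1) / n)


y0 = jnp.stack([jnp.full((2,), 1.0), jnp.full((2,), 3.0), jnp.full((2,), 5.0)])
dy = bundled_vector_field(0.0, y0, 2.0)
err = bundle_error_indicator(y0)
```

With diffrax, one would pass `diffrax.ODETerm(bundled_vector_field)` with a `y0` of shape `(num_trajectories, state_dim)`. An outer `jax.vmap` over solves then carries arrays of shape `(batch, num_trajectories, state_dim)`, which is the size blow-up described above, while separate batch elements stay independent.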

The alternative would be to reach across the batch and explicitly create a cross-batch dependence. JAX provides tools for this in the form of jax.lax.p{sum, ...}. Take a look at eqx.nn.BatchNorm for an example. Typically you name a particular vmap (via its axis_name argument) and then reduce over that named axis.
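A minimal sketch of the cross-batch option: `jax.vmap` is given an `axis_name`, and `jax.lax.pmean` averages each trajectory's local error estimate over that named axis, so every batch element receives the same shared error and therefore the same proposed step. The axis name, the mean-of-errors rule, the tolerance, and the update exponent of 0.5 are all illustrative assumptions; wiring this into a diffrax step-size controller is not shown.

```python
import jax
import jax.numpy as jnp

AXIS = "trajectories"  # hypothetical name for the batch axis


def propose_step(dt, local_error, tol):
    # Collective over the named vmap axis: every trajectory sees the same mean.
    shared_error = jax.lax.pmean(local_error, axis_name=AXIS)
    accept = shared_error <= tol
    # Safety-factor step update with an illustrative exponent of 0.5,
    # clipped so the step cannot shrink or grow too abruptly.
    factor = jnp.clip(0.9 * (tol / shared_error) ** 0.5, 0.2, 5.0)
    return accept, dt * factor


dts = jnp.full((4,), 0.1)
local_errors = jnp.array([1e-3, 2e-3, 3e-3, 2e-3])
accept, new_dt = jax.vmap(propose_step, in_axes=(0, 0, None), axis_name=AXIS)(
    dts, local_errors, 1e-3
)
```

This follows the same pattern as eqx.nn.BatchNorm, whose axis_name has to match the name given to the enclosing jax.vmap. The trade-off versus bundling is that each result now depends on which other trajectories happen to share the batch.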