patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Support for uncertainty propagation? #438

Open · nstarman opened this issue 3 weeks ago

nstarman commented 3 weeks ago

It would be amazing to be able to specify uncertainties on initial conditions (ICs) and have them propagate through to the solutions. I know this is a very difficult problem in general. Off the top of my head, some challenges are:

  1. Even Gaussian uncertainties are hard to propagate, and nonlinear operations on arbitrary (e.g. non-symmetric) distributions often have no analytic form at all. A first-order (linearised) approximation could be a starting point...
  2. What would the new API look like?
  3. How to support dense solutions?
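For point 1, the usual first-order approach linearises the solution map: if y(t) = Phi_t(y0) and y0 has mean mu0 and covariance Sigma0, then y(t) has approximate mean Phi_t(mu0) and covariance J Sigma0 J^T, where J is the Jacobian of Phi_t at mu0. A minimal JAX sketch, assuming the closed-form solution of the logistic equation dy/dt = y(1 - y) as a stand-in for a diffeqsolve call (the names `first_order_cov` and `logistic_flow` are hypothetical):

```python
import jax
import jax.numpy as jnp


def first_order_cov(flow, mean0, cov0):
    """Linearised (first-order) propagation of a Gaussian through `flow`.

    Returns flow(mean0) and J @ cov0 @ J.T, with J the Jacobian of `flow` at mean0.
    """
    mean1 = flow(mean0)
    jac = jax.jacfwd(flow)(mean0)
    return mean1, jac @ cov0 @ jac.T


def logistic_flow(y0, t=1.0):
    # Hypothetical stand-in for diffeqsolve: exact solution of dy/dt = y (1 - y)
    # at time t, applied elementwise to the state vector.
    e = jnp.exp(t)
    return y0 * e / (1.0 - y0 + y0 * e)


mean0 = jnp.array([0.5, 0.2])
cov0 = jnp.diag(jnp.array([0.01, 0.004]))  # independent components
mean1, cov1 = first_order_cov(logistic_flow, mean0, cov0)
```

For a linear vector field such as -y the linearisation is exact; for nonlinear ones its accuracy degrades as the spread of y0 grows.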

Thankfully, I think point 2 at least has a workable solution. @patrick-kidger, you've written Quax to allow array-like objects in JAX. My suggestion is to make diffeqsolve(y0=...) accept Quax classes that represent the distribution and handle its propagation.

Point 1 remains hard, but there is still a useful starting point. The simplest "uncertainty" to support isn't even Gaussian: it's an interval given by a lower and an upper bound. That would be a proof of concept, but still a useful one! I know the same result could be obtained by calling diffeqsolve twice, once per bound, but a) a unified API would be more convenient, and b) it would hopefully pave the way for Gaussian and more complex distributions.
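The two-solve baseline can at least be batched into a single call with jax.vmap. A sketch, assuming a hand-written fixed-step RK4 loop (`rk4_solve`, hypothetical and not Diffrax's API) in place of diffeqsolve so that it runs standalone:

```python
import jax
import jax.numpy as jnp


def rk4_solve(vector_field, y0, t0, t1, num_steps):
    """Hypothetical fixed-step classical RK4 from t0 to t1; returns y(t1)."""
    dt = (t1 - t0) / num_steps

    def step(i, y):
        t = t0 + i * dt
        k1 = vector_field(t, y)
        k2 = vector_field(t + dt / 2, y + dt / 2 * k1)
        k3 = vector_field(t + dt / 2, y + dt / 2 * k2)
        k4 = vector_field(t + dt, y + dt * k3)
        return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    return jax.lax.fori_loop(0, num_steps, step, y0)


vector_field = lambda t, y: -y
bounds0 = jnp.array([0.9, 1.1])  # lower and upper bound on y0
# One batched solve over both endpoints instead of two separate calls.
bounds3 = jax.vmap(lambda y0: rk4_solve(vector_field, y0, 0.0, 3.0, 30))(bounds0)
```

The bracketing is valid here because trajectories of a scalar ODE with unique solutions cannot cross, so the two endpoint solutions bound every solution starting in between. For nonlinear systems, solving from the corners of a box does not bound the reachable set in general.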

Using the opening example from https://docs.kidger.site/diffrax/usage/getting-started/:

from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController
from diffrax import Interval  # hypothetical: proposed here, not part of Diffrax

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1,
                  y0=Interval(0.9, 1.1),  # note the Interval
                  saveat=saveat, stepsize_controller=stepsize_controller)

print(sol.ts)  # DeviceArray([0.   , 1.   , 2.   , 3.    ])
print(sol.ys)  # Interval(...); internal representation still to be decided
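What might such an Interval type need to do? A minimal sketch of plain interval arithmetic (the `Interval` class here is hypothetical, not part of Diffrax or Quax), applied to one explicit Euler step of dy/dt = -y. It also illustrates a known difficulty: naive interval arithmetic treats every occurrence of y as independent (the "dependency problem"), so two algebraically equivalent forms of the same step give different widths.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    """Hypothetical closed interval [lo, hi] with naive interval arithmetic."""

    lo: float
    hi: float

    def __add__(self, other):
        return Interval(self.lo + other.lo, self.hi + other.hi)

    def __neg__(self):
        return Interval(-self.hi, -self.lo)

    def __mul__(self, c):
        # Scaling by a real scalar; a negative scalar swaps the endpoints.
        a, b = self.lo * c, self.hi * c
        return Interval(min(a, b), max(a, b))

    __rmul__ = __mul__


y0 = Interval(0.9, 1.1)
dt = 0.1
# Euler step y + dt * f(y) with f(y) = -y: y appears twice, so the bounds widen.
naive = y0 + dt * (-y0)
# The same step written as y * (1 - dt): y appears once, so the bounds stay tight.
tight = y0 * (1 - dt)
```

Here `naive` comes out at approximately [0.79, 1.01] while `tight` is approximately [0.81, 0.99], the exact image of [0.9, 1.1] under one Euler step. Over many steps this overestimation compounds, which is one reason interval propagation through an ODE solver is harder than it first looks.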

Relatedly, Quax classes would also make it easy to bundle an array of sampled y0 values into a Monte Carlo approximation of the uncertainty distribution (an MCMeasurement):

sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1,
                  y0=MCMeasurement(...),  # hypothetical Monte Carlo wrapper
                  saveat=saveat, stepsize_controller=stepsize_controller)

print(sol.ts)  # DeviceArray([0.   , 1.   , 2.   , 3.    ])
print(sol.ys)  # MCMeasurement(...)
print(sol.ys.mean())  # DeviceArray([1.   , 0.368, 0.135, 0.0498])
print(sol.ys.std())  # DeviceArray([...])
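Until such a wrapper exists, the same statistics can be computed by batching over samples with jax.vmap. A sketch, assuming y0 ~ N(1, 0.05^2) and using the exact solution y0 * exp(-t) as a stand-in for diffeqsolve; in practice the vmapped function would call diffeqsolve with the sampled y0, and a hypothetical MCMeasurement would simply hide the sample axis.

```python
import jax
import jax.numpy as jnp

ts = jnp.array([0.0, 1.0, 2.0, 3.0])


def solve(y0):
    # Stand-in for diffeqsolve on dy/dt = -y, saved at `ts`.
    return y0 * jnp.exp(-ts)


key = jax.random.PRNGKey(0)
# Assumed distribution on the initial condition: y0 ~ N(1, 0.05^2).
samples = 1.0 + 0.05 * jax.random.normal(key, (10_000,))
ys = jax.vmap(solve)(samples)  # shape (10_000, 4): one trajectory per sample
mean, std = ys.mean(axis=0), ys.std(axis=0)
```

Because this ODE is linear, the mean tracks exp(-t) (about 1, 0.368, 0.135, 0.0498) and the standard deviation tracks 0.05 * exp(-t), up to sampling error.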

patrick-kidger commented 3 weeks ago

I think this'd be really cool!

As you've highlighted, Quax is the way to go about implementing something like this. It would work by providing overloads for all the JAX primitives that diffeqsolve uses, and then calling quax.quaxify(diffrax.diffeqsolve)(y0=Interval(...), ...).
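To get a feel for which overloads would be needed, one can trace a function to a jaxpr and collect the primitive names, recursing into the sub-jaxprs held by control-flow primitives. The helper `collect_primitives` below is a hypothetical sketch; it is demonstrated on a small Euler loop rather than on diffeqsolve itself, for which one would trace the full diffeqsolve call instead.

```python
import jax


def collect_primitives(jaxpr, names=None):
    """Recursively collect primitive names from a jaxpr, including sub-jaxprs."""
    names = set() if names is None else names
    for eqn in jaxpr.eqns:
        names.add(eqn.primitive.name)
        for param in eqn.params.values():
            subs = param if isinstance(param, (tuple, list)) else (param,)
            for sub in subs:
                # A ClosedJaxpr wraps a Jaxpr in `.jaxpr`; a bare Jaxpr has `.eqns`.
                inner = getattr(sub, "jaxpr", sub)
                if hasattr(inner, "eqns"):
                    collect_primitives(inner, names)
    return names


def euler(y0):
    # 30 explicit Euler steps of dy/dt = -y with dt = 0.1.
    return jax.lax.fori_loop(0, 30, lambda i, y: y - 0.1 * y, y0)


closed = jax.make_jaxpr(euler)(1.0)
names = collect_primitives(closed.jaxpr)
print(sorted(names))
```

The loop body contributes arithmetic primitives such as `mul` and `sub`, and the loop itself appears as a control-flow primitive (`scan` or `while`, depending on how JAX lowers `fori_loop`); each of these would need a rule for a given Quax type.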

In particular, the point of a library like Quax is that this shouldn't require changing Diffrax at all.

Providing all of those overloads is probably fairly ambitious (and might stress-test just how load-bearing Quax really is 😅), but if you do something like that I'd love to see it!

nstarman commented 3 weeks ago

As a simple example, here is a pass-through wrapper that applies every primitive to the underlying array:

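from collections.abc import Sequence

import equinox as eqx
import jax
import jax.numpy as jnp
import quax
from diffrax import diffeqsolve, Dopri5, ODETerm, PIDController, SaveAt
from jax.typing import ArrayLike


class ExactMeasurement(quax.ArrayValue):
    # Wraps an exact (zero-uncertainty) array; every primitive passes straight through.
    array: ArrayLike = eqx.field(converter=jnp.asarray)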

    def aval(self):
        shape = jnp.shape(self.array)
        dtype = jnp.result_type(self.array)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        msg = "Refusing to materialise."
        raise ValueError(msg)

    @staticmethod
    def default(
        primitive: jax.core.Primitive,
        values: Sequence[ArrayLike | quax.Value],
        params: dict,
    ):
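        # Unwrap every input to a plain array so the primitive can run as normal.
        raw_values: list[ArrayLike] = []
        for value in values:
            if eqx.is_array_like(value):
                raw_values.append(value)
            elif isinstance(value, ExactMeasurement):
                raw_values.append(value.array)
            elif isinstance(value, quax.Value):
                # Mixing with other Quax types is not supported in this example.
                raise NotImplementedError
            else:
                raise AssertionError(f"Unexpected input type: {type(value)}")

        # Apply the primitive to the raw arrays, then re-wrap every output.
        out = primitive.bind(*raw_values, **params)
        return (
            [ExactMeasurement(x) for x in out]
            if primitive.multiple_results
            else ExactMeasurement(out)
        )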

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0.0, 1.0, 2.0, 3.0])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = quax.quaxify(diffeqsolve)(
    term,
    solver,
    t0=0,
    t1=3,
    dt0=0.1,
    y0=ExactMeasurement(1.0),
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)

print(sol.ts)  # ExactMeasurement(array=f32[4])
print(sol.ys)  # ExactMeasurement(array=f32[4])

Relying on the default rule alone is obviously too permissive: even sol.ts comes back as an ExactMeasurement, although the save times carry no uncertainty. Still, this shows it's possible.