patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Forcing solver to stay in given region #200

Open grfrederic opened 1 year ago

grfrederic commented 1 year ago

I'm solving a system of ODEs that simulates concentrations of some substances. I know that the concentrations have to be real numbers in the range [0, 1]. Normally, after each step, I would simply clamp the values to be in that region (since values outside of that range can mess with the simulation). Is there a way to achieve this nicely with diffrax?

I'm trying out a pretty wide range of parameters (since I'm using the simulations for MCMC), so tightening the tolerances and shrinking the step sizes just to handle the outliers seems wasteful.

patrick-kidger commented 1 year ago

There are a couple of ways you could do this. The first, as you say, is to just clamp the values to the desired region. You can do this by wrapping the solver:

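from typing import Callable

import diffrax as dfx
import jax.numpy as jnp

class ClampSolver(dfx.AbstractSolver):
    solver: dfx.AbstractSolver
    clamp: Callable

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        # Run the wrapped solver, then clamp the proposed state y1.
        y1, y_error, dense_info, solver_state, result = self.solver.step(
            terms, t0, t1, y0, args, solver_state, made_jump
        )
        return self.clamp(y1), y_error, dense_info, solver_state, result

    # also forward any other methods to `self.solver` as needed
    # (e.g. term_structure, interpolation_cls, order, init, func)

clamp = lambda y: jnp.clip(y, 0, 1)  # positional bounds: lower=0, upper=1
solver = ClampSolver(dfx.Tsit5(), clamp)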
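
To see what clamping after each step buys you, here is a minimal sketch that does not use diffrax at all. It assumes a made-up decay equation dy/dt = -k * y and a hand-written fixed-step Euler loop (euler_step and integrate are hypothetical helpers for illustration only), with a step size deliberately too large so that the unclamped solution leaves [0, 1].

```python
import jax.numpy as jnp


def euler_step(y, dt, k=50.0):
    # One explicit Euler step for dy/dt = -k * y (toy problem, not from the thread).
    return y + dt * (-k * y)


def integrate(y0, dt, n_steps, clamp=None):
    # Fixed-step loop; optionally clamp the state after every step.
    y = y0
    ys = [y]
    for _ in range(n_steps):
        y = euler_step(y, dt)
        if clamp is not None:
            y = clamp(y)
        ys.append(y)
    return jnp.stack(ys)


clamp = lambda y: jnp.clip(y, 0, 1)
y0 = jnp.array(1.0)

# dt * k = 2.5, so every Euler step overshoots past zero and flips sign.
raw = integrate(y0, dt=0.05, n_steps=5)
clamped = integrate(y0, dt=0.05, n_steps=5, clamp=clamp)
```

Here raw oscillates outside [0, 1] with growing magnitude, while clamped stays inside the region. Note that clamping only hides the symptom: the step was still too large, which is why rejecting the step tends to give a more accurate answer.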

Another way, which will give more accurate solutions, is to have the step size controller reject any step that lands out of bounds. We do this by setting the error estimate to infinity for such steps, so the controller treats them as failed and retries with a smaller step. This assumes you're using an adaptive step size controller such as PIDController.

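from typing import Callable

import diffrax as dfx
import jax
import jax.numpy as jnp

class InBoundsSolver(dfx.AbstractSolver):
    solver: dfx.AbstractSolver
    out_of_bounds: Callable

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        y1, y_error, dense_info, solver_state, result = self.solver.step(
            terms, t0, t1, y0, args, solver_state, made_jump
        )
        # Wherever y1 is out of bounds, make the error estimate infinite
        # so that the step size controller rejects the step.
        oob = self.out_of_bounds(y1)
        mask = lambda err: jnp.where(oob, jnp.inf, err)
        y_error = jax.tree_util.tree_map(mask, y_error)
        return y1, y_error, dense_info, solver_state, result

    # also forward any other methods to `self.solver` as needed
    # (e.g. term_structure, interpolation_cls, order, init, func)

out_of_bounds = lambda y: (y < 0) | (y > 1)
solver = InBoundsSolver(dfx.Tsit5(), out_of_bounds)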
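
The masking step can be checked on its own, outside of any solver. The following is a small sketch using only JAX; mask_error is a hypothetical helper name, and the state and error values are made up for illustration. It assumes, as in the snippet above, that the state is a single array with the same shape as the error estimate.

```python
import jax
import jax.numpy as jnp


def out_of_bounds(y):
    # True wherever a concentration has left [0, 1].
    return (y < 0) | (y > 1)


def mask_error(y1, y_error):
    # Replace the error estimate by infinity wherever y1 is out of bounds.
    # Any infinite entry makes the error norm infinite, which an adaptive
    # controller would treat as a failed step.
    oob = out_of_bounds(y1)
    return jax.tree_util.tree_map(lambda e: jnp.where(oob, jnp.inf, e), y_error)


y_error = jnp.array([1e-6, 2e-6, 3e-6])  # made-up error estimate
inside = jnp.array([0.2, 0.5, 0.9])      # every component in [0, 1]
outside = jnp.array([0.2, -0.01, 1.3])   # last two components out of bounds

masked_inside = mask_error(inside, y_error)
masked_outside = mask_error(outside, y_error)
```

For the in-bounds state the error estimate passes through unchanged; for the out-of-bounds state the offending components become infinite while the first component keeps its original estimate.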

Eventually I intend to provide built-in support for this kind of thing, so I'm also going to mark this as a feature request.

grfrederic commented 1 year ago

Thank you! I really like the second solution.