patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.37k stars 124 forks

statefully evolving an auxiliary variable #462

Open jbial opened 1 month ago

jbial commented 1 month ago

Given any vector field f(t, y, args), I want to keep track of an auxiliary variable w that depends and evolves with y but non-differentiably in the sense that there is no vector field for w, it's just a state that gets updated every step call. For example, at every step of the integrator I would ideally like to be able to do something like the following pseudocode:

def update_w(t: float, y: Array, w: Array) -> Array:
    return jac(t, y)@w

def step(...):
    ....
    y1 = ....
    y_error = ...
    w = update_w(t, y1, w)
    return ((y1, w), y_error, dense, ...)
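One way to make the pseudocode above concrete, outside of any Diffrax solver, is a hand-rolled fixed-step loop. This is only a sketch under stated assumptions: the explicit Euler step, the linear vector field `f`, and the helper names `euler_step_with_aux` and `integrate` are hypothetical and not part of Diffrax.

```python
import jax
import jax.numpy as jnp


def f(t, y, args):
    # Hypothetical linear vector field dy/dt = A @ y, so the Jacobian is just A.
    return args @ y


def update_w(t, y, w, args):
    # Non-integrated update: multiply w by the Jacobian of f with respect to y.
    jac = jax.jacfwd(f, argnums=1)(t, y, args)
    return jac @ w


def euler_step_with_aux(carry, t, dt, args):
    y, w = carry
    y1 = y + dt * f(t, y, args)          # integrate y (explicit Euler, for illustration)
    w1 = update_w(t + dt, y1, w, args)   # update w once per step; w has no vector field
    return (y1, w1)


def integrate(y0, w0, t0, dt, num_steps, args):
    def body(i, carry):
        return euler_step_with_aux(carry, t0 + i * dt, dt, args)

    return jax.lax.fori_loop(0, num_steps, body, (y0, w0))


A = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
y_final, w_final = integrate(
    jnp.array([1.0, 0.0]), jnp.array([1.0, 0.0]), 0.0, 0.01, 3, A
)
```

Here `w` is carried alongside `y` but never enters the integration itself, which is the behaviour the pseudocode asks for.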

I'm currently trying to wrap an arbitrary solver with some success like so:

def wrap_solver(solver_class):

    class WrappedSolver(solver_class):

        def update_aux(self, t0, x, aux):
             ....
             return aux

        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            (y, y_error, dense_info, solver_state, solver_result) = super().step(terms, t0, t1, y0, args, solver_state, made_jump)
            x, aux = y
            new_aux = self.update_aux(t0, x, aux)
            return ((x, new_aux), y_error, dense_info, solver_state, solver_result)

    return WrappedSolver

But this seems very janky and I'm wondering if there's a cleaner way without exposing the step method. It's as if I need a custom term where one term is an ODETerm with control dt and the other is like a "StateTerm" also with control dt but it doesn't get integrated.

etienney commented 1 month ago

I'm also very interested, as it is similar to my problem #447. It is not the most satisfying solution because:

import diffrax as dfx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax 
from jax import Array
import equinox as eqx

class State(eqx.Module):
    y: Array
    aux: Array

def wrap_solver(solver_class):

    class WrappedSolver(solver_class):

        def update_aux(self, t0, x, aux):
            aux+=1
            return aux

        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            (y, y_error, dense_info, solver_state, solver_result) = super().step(terms, t0, t1, y0, args, solver_state, made_jump)
            x = y.y
            aux = y.aux
            new_aux = self.update_aux(t0, x, aux)

            return (State(x, new_aux), y_error, dense_info, solver_state, solver_result)

    return WrappedSolver

def ode_func(t, y, args):
    truey = y.y
    dy = -2 * truey
    daux = 0
    return State(dy, daux)

y0 = jnp.array([1.0])
t0 = 0.0             
t1 = 5.0           

solver = wrap_solver(dfx.Dopri5)()

stepsize_controller = dfx.PIDController(1e-6, 1e-6)
term = dfx.ODETerm(ode_func)
saveat = dfx.SaveAt(ts=jnp.linspace(t0, t1, 100))

solution = dfx.diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=0.1,
    y0=State(y0,0),
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)

times = solution.ts
values = solution.ys

print(values.aux)       

Result gives

[ 0.  0.  1.  1.  1.  1.  2.  2.  2.  3.  3.  3.  4.  4.  4.  5.  5.  5.
  5.  6.  6.  6.  7.  7.  7.  7.  8.  8.  8.  8.  8.  9.  9.  9.  9. 10.
 10. 10. 10. 10. 11. 11. 11. 11. 11. 12. 12. 12. 12. 12. 12. 13. 13. 13.
 13. 13. 13. 14. 14. 14. 14. 14. 14. 14. 14. 15. 15. 15. 15. 15. 15. 15.
 15. 16. 16. 16. 16. 16. 16. 16. 16. 16. 17. 17. 17. 17. 17. 17. 17. 17.
 17. 17. 17. 18. 18. 18. 18. 18. 18. 18.]

owing to the interpolation problem

jusevitch commented 1 month ago

I'm interested in this feature as well. This would simplify a number of common robotic / control theoretic scenarios.

To give a concrete example of where this would be useful, consider the following scenario:

It currently isn't possible to perform this within a single diffeqsolve command because:

(Granted this would likely be possible to simulate with a sequence of separate ODE solver calls, but this likely wouldn't scale well when simulating large numbers of autonomous agents.)

I agree with @jbial and @etienney that it would be nice to have a way to evolve this auxiliary state without having it get integrated along with the ODE state.

Adding this feature would likely come with some practical and theoretical sharp edges, but as long as the user is fully informed I don't think it should be too much of an issue. Potential issues I can think of include possible non-differentiability of the solution and (for ODEs) potentially large numbers of discontinuities in the vector field.

lockwo commented 1 month ago

Having a clearer set of guidelines/tools for more complex auxiliary states is of interest to me as well. I will note that in the above MVC, I think if there is no vector field (so you don't need to integrate over the auxiliary, you just want to update it each solver step), you could do it with solver state instead, like:

import diffrax as dfx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax 
from jax import Array
import equinox as eqx

def wrap_solver(solver_class):

    class WrappedSolver(solver_class):

        def update_aux(self, t0, x, aux):
            return aux + 1

        def init(self, terms, t0, t1, y0, args):
            init_aux = jnp.array([0.0])
            return (super().init(terms, t0, t1, y0, args), init_aux)

        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            (y, y_error, dense_info, new_solver_state, solver_result) = super().step(terms, t0, t1, y0, args, solver_state[0], made_jump)
            x = y
            aux = solver_state[1]
            new_aux = self.update_aux(t0, x, aux)

            return (y, y_error, dense_info, (new_solver_state, new_aux), solver_result)

    return WrappedSolver

def ode_func(t, y, args):
    truey = y
    dy = -2 * truey
    return dy

y0 = jnp.array([1.0])
t0 = 0.0             
t1 = 5.0           

solver = wrap_solver(dfx.Dopri5)()

stepsize_controller = dfx.PIDController(1e-6, 1e-6)
term = dfx.ODETerm(ode_func)
saveat = dfx.SaveAt(ts=jnp.linspace(t0, t1, 100), solver_state=True)

solution = dfx.diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=0.1,
    y0=y0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)

times = solution.ts
values = solution.ys

print(solution.ys)       
print("State", solution.solver_state[1])

This avoids the interpolation problem:

[[1.00000000e+00]
 [9.03923571e-01]
 [8.17078471e-01]
...
State [20.]
jusevitch commented 1 month ago

To add to the minimum example shown above, it would be useful to have an interface for passing the auxiliary state to the vector field call of the dynamical system.

Below is a modification of @lockwo's code to simulate a single integrator dynamical system iteratively tracking waypoints in the shape of a square. The auxiliary state is the current target waypoint, which needs to be passed into the vector field computation as the goal location.

I made this work by adding the auxiliary state to the args parameter within the solver step function, but this feels like a bit of a nasty hack. It would be nice if the signature of the ODETerm was something similar to term(t, y, args, aux=None), where aux is an optional argument, but I realize this would be a fairly significant breaking change.

Minimal example (expand to see code)

```python
import diffrax as dfx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
from jax import Array
import equinox as eqx

class WaypointAux(eqx.Module):
    waypoints: Array
    waypoint_idx: int
    capture_rad: float

def wrap_solver(solver_class):

    class WrappedSolver(solver_class):

        def update_aux(self, t0, x, args, aux):
            # Change to next waypoint if state is within capture radius of current waypoint
            aux = jax.lax.cond(
                jnp.linalg.norm(x - aux.waypoints[aux.waypoint_idx]) < aux.capture_rad,
                lambda: WaypointAux(
                    aux.waypoints, aux.waypoint_idx + 1, aux.capture_rad
                ),
                lambda: aux,
            )
            return aux

        def init(self, terms, t0, t1, y0, args):
            init_aux = WaypointAux(
                jnp.array([
                    [1.0, 1.0],
                    [1.0, 2.0],
                    [2.0, 2.0],
                    [2.0, 1.0],
                    [1.0, 1.0],
                ]),
                0,
                0.01
            )
            return (super().init(terms, t0, t1, y0, args), init_aux)

        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            aux = solver_state[1]
            # Combine args with aux so that aux can be accessed by vector field function
            combined_args = (args, aux)
            (y, y_error, dense_info, new_solver_state, solver_result) = super().step(terms, t0, t1, y0, combined_args, solver_state[0], made_jump)
            x = y
            new_aux = self.update_aux(t0, x, args, aux)
            return (y, y_error, dense_info, (new_solver_state, new_aux), solver_result)

    return WrappedSolver

def track_waypoint(t, y, args):
    args, goal = args
    return (goal.waypoints[goal.waypoint_idx] - y)

y0 = jnp.array([0.0,0.0])
t0 = 0.0
t1 = 50.0

solver = wrap_solver(dfx.Dopri5)()

stepsize_controller = dfx.PIDController(1e-6, 1e-6)
term = dfx.ODETerm(track_waypoint)
saveat = dfx.SaveAt(ts=jnp.linspace(t0, t1, 100), solver_state=True)

solution = dfx.diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=0.1,
    y0=y0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)

times = solution.ts
values = solution.ys

print(solution.ys)
print("State", solution.solver_state[1])

plt.plot(values[:,0], values[:,1])
plt.plot([1, 1, 2, 2, 1], [1, 2, 2, 1, 1], 'ro')
plt.savefig("test_waypoints.png")
```

jbial commented 1 month ago

To add to the minimum example shown above, it would be useful to have an interface for passing the auxiliary state to the vector field call of the dynamical system.

Below is a modification of @lockwo's code to simulate a single integrator dynamical system iteratively tracking waypoints in the shape of a square. The auxiliary state is the current target waypoint, which needs to be passed into the vector field computation as the goal location.

I made this work by adding the auxiliary state to the args parameter within the solver step function, but this feels like a bit of a nasty hack. It would be nice if the signature of the ODETerm was something similar to term(t, y, args, aux=None), where aux is an optional argument, but I realize this would be a fairly significant breaking change.

Minimal example (expand to see code)

I like this solution actually, it's not that hacky and it's cleaner than artificially adding a zero vector field for the aux variable. Thanks for commenting everyone!

Edit: had to manually pass in args=(None, init_aux) into diffeqsolve to get the code to work

patrick-kidger commented 1 month ago

This is quite the outpouring of support for this!

So I'm pretty sure that implementing this at the solver level -- as a custom solver -- is the correct thing to do. Anything not part of the differentiable state can only be made sense of at the step-by-step level, and that is the abstraction that solvers represent. That said, as this issue demonstrates... perhaps we could do something to make this easier to accomplish! (Perhaps without having to grok the whole AbstractSolver API.)

As a first pass, here's an (untested) quick mockup of a general approach to this, that allows for updating arbitrary auxiliary state on each step:

class RequiresAux(eqx.Module):
    vector_field: Callable

    def __call__(self, t, y, args):
        args, aux_state = args
        return self.vector_field(t, y, args, aux_state=aux_state)

class AuxSolver(AbstractWrappedSolver):  # provides `.solver`
    init_aux: Callable
    step_aux: Callable

    def init(self, terms, t0, t1, y0, args):
        return self.solver.init(terms, t0, t1, y0, args), self.init_aux(t0, t1, y0, args)

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        solver_state, aux_state = solver_state
        y1, y_error, dense_info, new_solver_state, result = self.solver.step(terms, t0, t1, y0, (args, aux_state), solver_state, made_jump)
        new_aux_state = self.step_aux(terms, t0, t1, y0, args, aux_state, made_jump)
        return y1, y_error, dense_info, (new_solver_state, new_aux_state), result

And here's some example usage, based on the discussion above:

class WaypointAux(eqx.Module):
    waypoints: Float[Array, "waypoints ..."]
    waypoint_idx: Int[Array, ""]
    capture_rad: Float[Array, ""]

def init_aux(t0, t1, y0, args):
    return WaypointAux(
        jnp.array([
            [1.0, 1.0],
            [1.0, 2.0],
            [2.0, 2.0],
            [2.0, 1.0],
            [1.0, 1.0],
        ]),
        0,
        0.01
    )

def update_aux(terms, t0, t1, y0, args, aux_state, made_jump):
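    # Advance to the next waypoint once the state is inside the capture radius.
    update_idx = jnp.linalg.norm(y0 - aux_state.waypoints[aux_state.waypoint_idx]) < aux_state.capture_rad
    idx = aux_state.waypoint_idx + jnp.where(update_idx, 1, 0)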
    return eqx.tree_at(lambda a: a.waypoint_idx, aux_state, idx)

def track_waypoint(t, y, args, *, aux_state):
    return (aux_state.waypoints[aux_state.waypoint_idx] - y)

solver = AuxSolver(Dopri5(), init_aux, update_aux)
terms = ODETerm(RequiresAux(track_waypoint))
...

Happy to think about standardising on something like the above, and adding it into Diffrax.

jbial commented 1 month ago

@patrick-kidger It seems like standardization of this framework seems useful for a lot of problems. Do you think the AuxSolver template you've outlined above is enough as an initial version? Happy to contribute if it requires more detail!

jusevitch commented 1 month ago

@patrick-kidger thanks for these thoughts. I agree that streamlining this method would be useful. I'm also happy to help contribute or test out a potential solution.

After some more thought, I don't know if introducing an aux_state keyword would be a good approach overall. It might be simpler to just use the args = (args, aux_state) method or similar and let the vector field function handle its args input accordingly.

The only potential complication is that Diffrax takes gradients with respect to args by default, and it's not clear whether or not the aux_state should be included in those gradients by default. But that could be left to the user to handle.

patrick-kidger commented 1 month ago

Great! I'm glad this seems interesting.

So probably the main thing I'm bearing in mind for the design here is that it looks like we probably -- at some point -- want to be able to pass other information to the vector field as well. For example for delay differential equations, we will want to pass in some history:

def vector_field(t, y, args, *, history):
    ...

so really I'd like a way to pass in arbitrary extra information from the solver!

Probably that means something like the following design:

# This is analogous to the `RequiresAux` above.
class VectorField(eqx.Module):
    vector_field: Callable

    def __call__(self, t, y, args):
        args, kwargs = args
        return self.vector_field(t, y, args, **kwargs)

# Now have every term wrap their vector field on instantiation.
class ODETerm(AbstractTerm):
    vector_field: Callable = eqx.field(converter=VectorField)
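To see how this unpacking convention behaves independently of Diffrax, here is a small self-contained sketch. It uses a plain dataclass rather than `eqx.Module`, and the `decay_towards` function and the `aux_state` keyword value are placeholder assumptions; it does not reflect how Diffrax currently works.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class VectorField:
    vector_field: Callable[..., Any]

    def __call__(self, t, y, args):
        # By convention the solver passes `(user_args, extra_kwargs)` as `args`.
        user_args, kwargs = args
        return self.vector_field(t, y, user_args, **kwargs)


def decay_towards(t, y, args, *, aux_state):
    # Hypothetical vector field that needs one extra keyword from the solver.
    return args * (aux_state - y)


wrapped = VectorField(decay_towards)
# A solver would build the second tuple element itself on every step.
dydt = wrapped(0.0, 1.0, (2.0, {"aux_state": 3.0}))
```

With this shape, a delay-equation solver could pass `{"history": ...}` and an auxiliary-state solver `{"aux_state": ...}` without the term needing to know about either.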

It might be simpler to just use the args = (args, aux_state) method

Just to address this explicitly -- I think this approach is probably reasonable for specifically aux_state, but when (as above) we may want to generalise to other kinds of extra information, this might get too complicated to be reasonable.

The only potential complication is that Diffrax takes gradients with respect to args by default, and it's not clear whether or not the aux_state should be included in those gradients by default. But that could be left to the user to handle.

I think we should expect to take gradients when using the default adjoint method. (Which I think will just work automatically, no special work required.) You're right though -- for BacksolveAdjoint we should add a check for this case, and throw an error.
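If a user did want to keep the auxiliary state out of the gradients, one option in plain JAX is `jax.lax.stop_gradient`. The sketch below is an assumption-laden illustration: a single explicit Euler step stands in for a full solve, and `vector_field` and `loss` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    params, aux_state = args
    # Block gradients from flowing into the auxiliary state.
    aux_state = jax.lax.stop_gradient(aux_state)
    return params * (aux_state - y)


def loss(args, y0):
    # One explicit Euler step stands in for a full differential equation solve.
    y1 = y0 + 0.1 * vector_field(0.0, y0, args)
    return jnp.sum(y1**2)


args = (jnp.array(2.0), jnp.array(3.0))
grad_params, grad_aux = jax.grad(loss)(args, jnp.array(1.0))
```

The gradient with respect to `params` is unaffected, while the gradient with respect to the auxiliary entry of `args` comes back as zero.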

All in all I'd be happy to take a PR from anyone for a first crack at this!

jbial commented 1 month ago

@patrick-kidger After some testing, I'm unable to get your example of inheriting from AbstractWrappedSolver working; I can't get around some module initialization bugs. However, the def wrap_solver method originally mentioned (which references super() instead of self.solver) seems to work well, but I don't think it's very clean. Any thoughts on this?

Also, it would be nice to have a way to save the auxiliary state in time. Not too sure how to approach this. One thought is to preallocate an array for the auxiliary state trajectory, but then it's unclear how to fill the array each step, evenly spaced, unless you know exactly how many solver steps will occur.

Edit: The only way I'm able to save the trajectory of the auxiliary state is by simply making the vector field return (dxdt, 0) and making the auxiliary state part of the main trajectory. Obviously, this was the original solution and is not that clean.

patrick-kidger commented 1 month ago

I'm unable to get around some module initialization bugs.

Do you have a MWE?

Also, it would be nice to have a way to save the auxiliary state in time.

Put the evolving value in the solver_state! (Note that you'll likely get spurious copies if you also attempt to backprop through this state -- this is a tricky JAX limitation to work around.)

jusevitch commented 1 month ago

It would be nice if the times at which auxiliary state(s) are saved could be specified by SaveAt, similar to how the ODE state is saved. Currently the docs state that SaveAt can only return the solver state at the final time t1 (using the solver_state argument).

jbial commented 1 month ago

It would be nice if the times at which auxiliary state(s) are saved could be specified by SaveAt, similar to how the ODE state is saved. Currently the docs state that SaveAt can only return the solver state at the final time t1 (using the solver_state argument).

@patrick-kidger Yeah this is what I meant by saving the solver state in time. Here is a MWE that does the d[Aux]/dt = 0 trick which makes the auxiliary state amenable to the SubSaveAt API:

import diffrax as dfx
import jax.numpy as jnp
from jax import random
from typing import Callable

def auxiliary_solver_factory(solver_class: dfx.AbstractSolver, step_aux: Callable) -> dfx.AbstractSolver:
    class SolverWithState(solver_class):

        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            (y, _), y_error, dense_info, new_solver_state, result = super().step(
                terms, t0, t1, y0, args, solver_state, made_jump
            )

            new_aux_state = step_aux(terms, t0, t1, y0, args, made_jump)
            return (y, new_aux_state), y_error, dense_info, new_solver_state, result
    return SolverWithState

def wrapped_rhs(rhs: Callable):
    def _rhs(t, y, args):
        x, _ = y
        dx = rhs(t, x, args)
        # (dxdt, d[Aux]dt): the auxiliary state gets a dummy zero vector field
        return (dx, 0)
    return _rhs

def step_aux(terms, t0, t1, y0, args, made_jump):
    A, B, _ = args
    x0, aux = y0
    return A@x0 + B@aux

def rhs(t, y, args):
    _, _, C = args
    return jnp.sin(C@y)

key = random.PRNGKey(0)
mat_keys = random.split(key, 3)
A, B, C = [random.normal(k, shape) for k, shape in zip(mat_keys, [(4, 2), (4, 4), (2,2)])]

y0 = jnp.array([0.0,0.0])
t0 = 0.0             
t1 = 300.0

solver = auxiliary_solver_factory(dfx.Dopri5, step_aux)
stepsize_controller = dfx.PIDController(1e-6, 1e-6)
_rhs = wrapped_rhs(rhs)

state_traj = dfx.SubSaveAt(ts=jnp.linspace(t0, t1, 100), fn=lambda t, y, _: y[0])
aux_traj = dfx.SubSaveAt(ts=jnp.linspace(t0, t1, 40), fn=lambda t, y, _: y[1])
saveat = dfx.SaveAt(subs=[state_traj, aux_traj])

sol = dfx.diffeqsolve(
    dfx.ODETerm(_rhs), solver(), t0=t0, t1=t1, dt0=0.1, y0=(y0, jnp.zeros(4)),
    args=(A, B, C), saveat=saveat, stepsize_controller=stepsize_controller,
)

state_ts, aux_ts = sol.ts
state_ys, aux_ys = sol.ys

print(state_ts.shape, aux_ts.shape)
print(state_ys.shape, aux_ys.shape)

I'm just wondering if there's a cleaner way to do this without adding that dummy vector field for the auxiliary state.

@jusevitch I tried to make this work for your example, but it prob requires some more pytree grokking. Unfortunately, this MWE avoids shipping the auxiliary state around with args and I understand how that's prob necessary for your example.

jusevitch commented 1 month ago

@jbial Thanks for passing along that MWE. I think there are two fundamental issues here that are preventing the saving of aux_state using the current SaveAt interface:

If there's a clean place in loop to add the solver_state into the args parameter, that might work while keeping the current SaveAt interface. Perhaps we could create some sort of function or wrapper to handle this.

On the other hand, if there was a SubSaveAt class that took additional parameters in its fn callable, e.g. fn(t,y,args,**kwargs), we could pass in the solver state and allow the user to parse that when saving.

patrick-kidger commented 1 month ago

So solver_state deliberately doesn't interact with SaveAt, as the former evolves on each step, but the latter handles any combination of times/steps/etc. It wouldn't be defined to ask for the state at times not aligned with the step locations!

Rather, I am suggesting that you should allocate a buffer at the start of the loop, put it in solver_state, and then .at[].set() into it.
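A minimal sketch of that buffer pattern, written as a plain `jax.lax.fori_loop` rather than a real Diffrax solver. The names `init_state` and `step`, and the choice of recording a simple step counter, are assumptions for illustration; inside Diffrax the tuple below would live in `solver_state`.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 8  # the buffer size must be known at compile time


def init_state():
    # (write index, current auxiliary value, preallocated history buffer)
    return jnp.array(0), jnp.array(0.0), jnp.full((MAX_STEPS,), jnp.inf)


def step(i, state):
    idx, aux, buffer = state
    new_aux = aux + 1.0                    # per-step, non-integrated update
    buffer = buffer.at[idx].set(new_aux)   # record the value for this step
    return idx + 1, new_aux, buffer


num_steps = 5
idx, aux, history = jax.lax.fori_loop(0, num_steps, step, init_state())
```

Entries past the final write index keep their initial `inf` fill, so the number of valid entries can be recovered from `idx`.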

jusevitch commented 1 month ago

It wouldn't be defined to ask for the state at times not aligned with the step locations!

@patrick-kidger Fair point--I apologize for the lack of clarity. To be more specific, I believe my use case (and possibly @jbial's) could be considered an example of simulating outputs of basic hybrid automata (see also this paper). Hybrid systems have both states with continuous dynamics $x$ and discrete modes $q$. The continuous states and the discrete modes evolve as functions of each other.

An output $y = h(x,q)$ that is a function of both the continuous state and discrete mode can be well-defined at any time step that $x(t)$ is defined. The problem here is that the only way Diffrax can simulate and track $q$ is through the solver_state, but this can't currently be passed into SaveAt.

To give a specific example: for systems $\dot{x} = f(x,u(x,q),q)$ with control inputs $u$ it's fairly common to want to save the time history of control inputs $y(t) = u(x,q)$ alongside the state. In Diffrax the only way I know of to do this is to define a SubSaveAt where fn recomputes u(x,q) a second time. However for the waypoint example I gave above, this would require passing in the mode (i.e. current waypoint) q which can only be found in solver_state.

(Technically I guess u(x,q) or any other such y = h(x,q) could be reconstructed after the solve from solver_state, but it would be nice if this could be handled without that manual recomputation.)
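The after-the-fact reconstruction mentioned above could look roughly like this. It is a sketch under assumptions: `xs` and `qs` are made-up stand-ins for a saved state trajectory and the matching mode at each save time, and `control` is a hypothetical output function $h(x, q)$.

```python
import jax
import jax.numpy as jnp

waypoints = jnp.array([[1.0, 1.0], [1.0, 2.0], [2.0, 2.0]])


def control(x, q):
    # Hypothetical output y = h(x, q): steer towards the waypoint selected by mode q.
    return waypoints[q] - x


# Stand-ins for a saved trajectory: continuous states and the mode at each save.
xs = jnp.array([[0.0, 0.0], [1.0, 1.0], [1.0, 1.5]])
qs = jnp.array([0, 1, 1])

# Evaluate h at every saved (x, q) pair in one vectorised call.
us = jax.vmap(control)(xs, qs)
```

This only works if the mode is known at every saved time, which is exactly the piece that is hard to get out of `solver_state` today.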

Having said all that, I realize this scenario may be outside the scope of Diffrax. I'm a fan of the Equinox + Diffrax ecosystem though 🙂 and it would be beneficial to find some way to handle these cases. @jbial, @etienney, @lockwo I'm curious to know if this matches your needs; if not I'm happy to jump off this issue.

jbial commented 1 month ago

Rather, I am suggesting that you should allocate a buffer at the start of the loop, put it in solver_state, and then .at[].set() into it.

@patrick-kidger I thought to do this, but an issue arises when you want to save the solver state at, for example, evenly spaced intervals throughout the solve. The only easy solution I see requires always setting max_steps, preallocating that much buffer memory, and filling it. However, in my use cases, I'd much prefer to leave max_steps unspecified so that the solver can meet the tolerances I'm demanding - this makes filling a buffer in a structured manner unclear.

@jusevitch What you described does align with my needs and seems generally useful. At the very least, we could standardize a wrapped solver for this and add a basic example to the docs under the fixed max_steps assumption I mentioned? @patrick-kidger WDYT?

patrick-kidger commented 1 month ago

So my understanding is that these kinds of hybrid systems usually have a particular time at which the discrete mode switches? This means that a solver-based step-by-step probably isn't what you want, as the timesteps need not necessarily be aligned with the exact time at which the jump occurs!

For this reason the pattern I've been thinking of for such systems is to simulate them using an event that terminates the solve once the discrete condition has occurred. Then wrap this overall solve into a lax.while_loop (or an equinox.internal.while_loop if you need backpropagation), so that you start a new diffeqsolve for each piece. This then allows you to also perform arbitrary logic to update your discrete state in between each continuous solve.

Would this approach allow you to solve your systems?
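The control flow of that pattern can be sketched in plain JAX. Everything here is an assumption for illustration: the inner `solve_until_event` is a hand-rolled explicit Euler loop standing in for one `diffeqsolve` call with a terminating event, and the waypoint dynamics are borrowed from the example earlier in the thread.

```python
import jax
import jax.numpy as jnp

waypoints = jnp.array([[1.0, 1.0], [1.0, 2.0], [2.0, 2.0]])
capture_rad = 0.05
dt = 0.01
t_end = 50.0


def solve_until_event(t, y, q):
    # Stand-in for one continuous solve with a terminating event:
    # integrate dy/dt = waypoints[q] - y until y enters the capture radius.
    def not_captured(carry):
        t, y = carry
        dist = jnp.linalg.norm(y - waypoints[q])
        return (dist >= capture_rad) & (t < t_end)

    def euler_step(carry):
        t, y = carry
        return t + dt, y + dt * (waypoints[q] - y)

    return jax.lax.while_loop(not_captured, euler_step, (t, y))


def outer_cond(carry):
    t, y, q = carry
    return (q < waypoints.shape[0]) & (t < t_end)


def outer_body(carry):
    t, y, q = carry
    t, y = solve_until_event(t, y, q)
    # Arbitrary discrete logic between continuous solves: advance the mode.
    return t, y, q + 1


t_final, y_final, q_final = jax.lax.while_loop(
    outer_cond, outer_body, (jnp.array(0.0), jnp.zeros(2), jnp.array(0))
)
```

Because the mode only changes between the continuous pieces, each piece sees a smooth vector field, and the switch happens at the time the condition is detected rather than at whatever step boundary a single long solve would have used.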


As for leaving max_steps unspecified -- unfortunately this is incompatible, at the JAX level, with any approach in which you need to save an unknown number of values. JAX works exclusively with arrays whose size is known at compile time.

For example, try doing a solve with an adaptive step size controller and SaveAt(steps=True). In this case then the returned outputs will actually be of size max_steps. The first values will be filled in according to the steps that the solver actually took, and then after that they will be padded with jnp.inf.
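The same fixed-size-plus-padding behaviour can be reproduced in a few lines of plain JAX. This sketch only mimics the idea and is not Diffrax's implementation; `record_steps` and its step-doubling rule are hypothetical.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 16


@jax.jit
def record_steps(t1):
    # Step sizes double each step, so the step count depends on the runtime value t1.
    def cond(carry):
        i, t, dt, ts = carry
        return (t < t1) & (i < MAX_STEPS)

    def body(carry):
        i, t, dt, ts = carry
        t_next = jnp.minimum(t + dt, t1)
        return i + 1, t_next, 2.0 * dt, ts.at[i].set(t_next)

    init = (
        jnp.array(0),
        jnp.array(0.0),
        jnp.array(0.1),
        jnp.full((MAX_STEPS,), jnp.inf),  # output shape is fixed up front
    )
    _, _, _, ts = jax.lax.while_loop(cond, body, init)
    return ts


ts = record_steps(jnp.array(1.0))
```

The output always has length `MAX_STEPS`; only the leading entries hold step times, and the remainder stay at `inf`.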

etienney commented 1 month ago

So my understanding is that these kinds of hybrid systems usually have a particular time at which the discrete mode switches? This means that a solver-based step-by-step probably isn't what you want, as the timesteps need not necessarily be aligned with the exact time at which the jump occurs!

For this reason the pattern I've been thinking of for such systems is to simulate them using an event that terminates the solve once the discrete condition has occurred. Then wrap this overall solve into a lax.while_loop (or an equinox.internal.while_loop if you need backpropagation), so that you start a new diffeqsolve for each piece. This then allows you to also perform arbitrary logic to update your discrete state in between each continuous solve.

Would this approach allow you to solve your systems?

This is indeed exactly what I need and what I actually did in my special case. (Though in my case I need to leave Diffrax to change the size of the objects according to the simulation's needs, since JAX would not allow for it, and then enter it again.)

Having said all that, I realize this scenario may be outside the scope of Diffrax. I'm a fan of the Equinox + Diffrax ecosystem though 🙂 and it would be beneficial to find some way to handle these cases. @jbial, @etienney, @lockwo I'm curious to know if this matches your needs; if not I'm happy to jump off this issue.

I guess such a framework would be useful in my case. Though it is not exactly what I do, I could make use of it.

jusevitch commented 1 month ago

For this reason the pattern I've been thinking of for such systems is to simulate them using an event that terminates the solve once the discrete condition has occurred. Then wrap this overall solve into a lax.while_loop (or an equinox.internal.while_loop if you need backpropagation), so that you start a new diffeqsolve for each piece. This then allows you to also perform arbitrary logic to update your discrete state in between each continuous solve.

Would this approach allow you to solve your systems?

I'll give this method a try. Thanks for sharing these suggestions.