patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Support for hybrid (mixed discrete-continuous) systems #423

Open allen-adastra opened 6 months ago

allen-adastra commented 6 months ago

https://github.com/patrick-kidger/diffrax/issues/343 seems somewhat related but not quite the same.

It would be great to support discrete mode transitions between time steps. We could of course use the example in https://docs.kidger.site/diffrax/usage/manual-stepping/, but it seems that approach doesn't directly allow the use of adjoints?

To be honest, I'm not sure what the theory of adjoints for hybrid systems is, either.

Would be great to get thoughts on both fronts.

patrick-kidger commented 6 months ago

Probably this can/should be done using events (#387, an upcoming feature!) to detect whenever you want a discrete transition, and halting the solve. You can apply whatever change you would like, and then kick off another diffeqsolve afterwards (perhaps by placing the whole solve inside of a lax.while_loop or lax.scan).

Does that sound about right?
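The suggestion above (solve until an event, halt, apply the discrete change, restart, all inside a `lax.while_loop`) can be sketched as follows. To avoid depending on the exact diffrax events API, which was still upcoming at this point in the thread, this is plain JAX: a hand-written fixed-step Euler loop stands in for each `diffeqsolve` call, and a bouncing ball is a hypothetical example supplying the event and the jump. The names `solve_until_event`, `event_cond` and `apply_jump` are illustrative only, not diffrax functions, and the event is only detected at the end of a step (no root-finding for the exact crossing time).

```python
import jax
import jax.numpy as jnp

G = 9.81   # gravitational acceleration
DT = 1e-3  # fixed Euler step; an assumption standing in for a real (adaptive) solver

def vector_field(y):
    # Continuous dynamics of a falling ball, y = [height, velocity].
    return jnp.array([y[1], -G])

def event_cond(y):
    # Event condition: the ball has hit the ground while moving downwards.
    return (y[0] <= 0.0) & (y[1] < 0.0)

def solve_until_event(t, y, t1):
    # Hypothetical stand-in for one solve: integrate until the event fires or t1 is reached.
    def cond(carry):
        t, y = carry
        return (t < t1) & ~event_cond(y)

    def body(carry):
        t, y = carry
        return t + DT, y + DT * vector_field(y)

    return jax.lax.while_loop(cond, body, (t, y))

def apply_jump(y, restitution):
    # Discrete change applied once the solve halts: bounce, losing some speed.
    return jnp.array([0.0, -restitution * y[1]])

@jax.jit
def simulate(y0, t1, restitution):
    # Outer loop: alternate "solve until event" and "apply jump" until t1.
    def cond(carry):
        t, _, _ = carry
        return t < t1

    def body(carry):
        t, y, num_jumps = carry
        t, y = solve_until_event(t, y, t1)
        fired = event_cond(y)
        y = jnp.where(fired, apply_jump(y, restitution), y)
        return t, y, num_jumps + fired.astype(jnp.int32)

    init = (jnp.asarray(0.0, dtype=jnp.float32), y0, jnp.asarray(0, dtype=jnp.int32))
    return jax.lax.while_loop(cond, body, init)

t_end, y_end, num_bounces = simulate(jnp.array([1.0, 0.0]), 3.0, 0.8)
```

One design caveat: `jax.lax.while_loop` does not support reverse-mode differentiation, so if `jax.grad` is needed through the outer loop, a `lax.scan` over a fixed maximum number of segments is the usual workaround.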

allen-adastra commented 6 months ago

Gotcha, thanks!

So if I have a state variable that acts like a "mode" (say, taking values in 0, 1, 2, 3) and we want transitions between the modes, we would write some condition function transition_state() that returns a bool?

More generally, do you have a recommended resource on differentiable events?
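One way to sketch the mode idea, assuming a fixed-step integrator in plain JAX in the spirit of the manual-stepping example linked earlier: the mode is an integer carried alongside the continuous state, `jax.lax.switch` selects that mode's vector field, and a guard named `transition_state` (as in the question) returns a bool that is checked between steps. The thermostat model, its thresholds and the step size are hypothetical. Because the guard is only checked on the fixed step grid, `jax.grad` through this loop treats the switching step as fixed and does not capture how the switching time depends on parameters; that dependence is what differentiable-event methods address.

```python
import jax
import jax.numpy as jnp

DT = 0.01                    # fixed step size (assumption; stands in for a real solver step)
T_LOW, T_HIGH = 19.0, 21.0   # switching thresholds of a hypothetical thermostat

def heater_off(temp):
    # Mode 0: cooling towards a 10-degree ambient temperature.
    return -0.1 * (temp - 10.0)

def heater_on(temp):
    # Mode 1: the same cooling plus constant heating power.
    return -0.1 * (temp - 10.0) + 2.0

# One vector field per mode; modes 0, 1, 2, 3, ... would just mean more entries.
VECTOR_FIELDS = (heater_off, heater_on)

def transition_state(mode, temp):
    # Guard function: returns a bool saying whether to leave the current mode.
    return jnp.where(mode == 0, temp <= T_LOW, temp >= T_HIGH)

def step(carry, _):
    mode, temp = carry
    # Continuous part: one explicit Euler step using the current mode's dynamics.
    temp = temp + DT * jax.lax.switch(mode, VECTOR_FIELDS, temp)
    # Discrete part: between steps, flip the mode (0 <-> 1) if the guard fires.
    mode = jnp.where(transition_state(mode, temp), 1 - mode, mode)
    return (mode, temp), (mode, temp)

@jax.jit
def simulate(temp0):
    init = (jnp.asarray(0, dtype=jnp.int32), jnp.asarray(temp0, dtype=jnp.float32))
    _, (modes, temps) = jax.lax.scan(step, init, None, length=5000)
    return modes, temps

modes, temps = simulate(15.0)
```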

patrick-kidger commented 6 months ago

Yup. I think this is the main resource on differentiable events: https://arxiv.org/abs/2011.03902

allen-adastra commented 5 months ago


Gotcha, thanks. Revisiting this because I keep finding a practical need to support states that don't evolve continuously in time but instead transition from one time step to the next. So it's sort of like an event happening at every time step.

x_{t+1} = f(x_t)

Wondering if you have thoughts/guidance on handling this?

patrick-kidger commented 5 months ago

It's not clear to me that this is even a differential equation? You seem to have a discrete system.
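If the state really is purely discrete, no ODE solver or adjoint is needed: iterating x_{t+1} = f(x_t) is a `jax.lax.scan`, and `jax.grad` differentiates through it with ordinary reverse-mode autodiff. A minimal sketch; the helper name `rollout` and the logistic map used as `f` are illustrative assumptions, not anything from diffrax.

```python
import jax
import jax.numpy as jnp

def rollout(f, params, x0, num_steps):
    # Iterate the discrete map x_{t+1} = f(x_t, params), returning x_0, ..., x_N.
    def step(x, _):
        x_next = f(x, params)
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, None, length=num_steps)
    return jnp.concatenate([x0[None], xs])

def logistic(x, r):
    # Example map (hypothetical choice for illustration): the logistic map.
    return r * x * (1.0 - x)

def final_state(r):
    return rollout(logistic, r, jnp.asarray(0.2, dtype=jnp.float32), 10)[-1]

# Reverse-mode autodiff straight through the scan.
dfinal_dr = jax.grad(final_state)(3.0)
```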