Open jbial opened 4 months ago
I'm also very interested, as it is similar to my problem #447. It is not the most satisfying solution because of the `daux = 0` in the terms:

```python
import diffrax as dfx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
from jax import Array
import equinox as eqx


class State(eqx.Module):
    y: Array
    aux: Array


def wrap_solver(solver_class):
    class WrappedSolver(solver_class):
        def update_aux(self, t0, x, aux):
            aux += 1
            return aux

        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            (y, y_error, dense_info, solver_state, solver_result) = super().step(
                terms, t0, t1, y0, args, solver_state, made_jump
            )
            x = y.y
            aux = y.aux
            new_aux = self.update_aux(t0, x, aux)
            return (State(x, new_aux), y_error, dense_info, solver_state, solver_result)

    return WrappedSolver


def ode_func(t, y, args):
    truey = y.y
    dy = -2 * truey
    daux = 0
    return State(dy, daux)


y0 = jnp.array([1.0])
t0 = 0.0
t1 = 5.0

solver = wrap_solver(dfx.Dopri5)()
stepsize_controller = dfx.PIDController(1e-6, 1e-6)
term = dfx.ODETerm(ode_func)
saveat = dfx.SaveAt(ts=jnp.linspace(t0, t1, 100))
solution = dfx.diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=0.1,
    y0=State(y0, 0),
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)
times = solution.ts
values = solution.ys
print(values.aux)
```
The result is:

```
[ 0.  0.  1.  1.  1.  1.  2.  2.  2.  3.  3.  3.  4.  4.  4.  5.  5.  5.
  5.  6.  6.  6.  7.  7.  7.  7.  8.  8.  8.  8.  8.  9.  9.  9.  9. 10.
 10. 10. 10. 10. 11. 11. 11. 11. 11. 12. 12. 12. 12. 12. 12. 13. 13. 13.
 13. 13. 13. 14. 14. 14. 14. 14. 14. 14. 14. 15. 15. 15. 15. 15. 15. 15.
 15. 16. 16. 16. 16. 16. 16. 16. 16. 16. 17. 17. 17. 17. 17. 17. 17. 17.
 17. 17. 17. 18. 18. 18. 18. 18. 18. 18.]
```

owing to the interpolation problem: `SaveAt(ts=...)` evaluates the dense interpolation between steps, so the per-step `aux` values get smeared across the save times rather than being recorded once per step.
I'm interested in this feature as well. It would simplify a number of common robotic / control-theoretic scenarios.

To give a concrete example of where this would be useful, consider the following scenario:

It currently isn't possible to perform this within a single `diffeqsolve` call because:

(Granted, this would likely be possible to simulate with a sequence of separate ODE solver calls, but that likely wouldn't scale well when simulating large numbers of autonomous agents.)
I agree with @jbial and @etienney that it would be nice to have a way to evolve this auxiliary state without having it get integrated along with the ODE state.
Adding this feature would likely come with some practical and theoretical sharp edges, but as long as the user is fully informed, I don't think it should be too much of an issue. Potential issues I can think of include possible non-differentiability of the solution and (for ODEs) potentially large numbers of discontinuities in the vector field.
Having a clearer set of guidelines/tools for more complex auxiliary states is of interest to me as well. I will note that for the above MWE, if there is no vector field (so you don't need to integrate the auxiliary variable, you just want to update it each solver step), you could do it with the solver state instead, like:
```python
import diffrax as dfx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
from jax import Array
import equinox as eqx


def wrap_solver(solver_class):
    class WrappedSolver(solver_class):
        def update_aux(self, t0, x, aux):
            return aux + 1

        def init(self, terms, t0, t1, y0, args):
            init_aux = jnp.array([0.0])
            return (super().init(terms, t0, t1, y0, args), init_aux)

        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            (y, y_error, dense_info, new_solver_state, solver_result) = super().step(
                terms, t0, t1, y0, args, solver_state[0], made_jump
            )
            x = y
            aux = solver_state[1]
            new_aux = self.update_aux(t0, x, aux)
            return (y, y_error, dense_info, (new_solver_state, new_aux), solver_result)

    return WrappedSolver


def ode_func(t, y, args):
    truey = y
    dy = -2 * truey
    return dy


y0 = jnp.array([1.0])
t0 = 0.0
t1 = 5.0

solver = wrap_solver(dfx.Dopri5)()
stepsize_controller = dfx.PIDController(1e-6, 1e-6)
term = dfx.ODETerm(ode_func)
saveat = dfx.SaveAt(ts=jnp.linspace(t0, t1, 100), solver_state=True)
solution = dfx.diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=0.1,
    y0=y0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)
times = solution.ts
values = solution.ys
print(solution.ys)
print("State", solution.solver_state[1])
```
This avoids the interpolation problem:

```
[[1.00000000e+00]
 [9.03923571e-01]
 [8.17078471e-01]
 ...
State [20.]
```
To add to the minimal example shown above, it would be useful to have an interface for passing the auxiliary state to the vector field call of the dynamical system.

Below is a modification of @lockwo's code to simulate a single-integrator dynamical system iteratively tracking waypoints in the shape of a square. The auxiliary state is the current target waypoint, which needs to be passed into the vector field computation as the goal location.

I made this work by adding the auxiliary state to the `args` parameter within the solver `step` function, but this feels like a bit of a nasty hack. It would be nice if the signature of the ODETerm were something like `term(t, y, args, aux=None)`, where `aux` is an optional argument, but I realize this would be a fairly significant breaking change.
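For concreteness, here is a minimal sketch of that args-packing idea. The names (`wrap_solver`, `update_aux`) are illustrative rather than Diffrax API; the auxiliary state is assumed to live in the solver state, with its initial value arriving through `diffeqsolve`'s `args`:

```python
import diffrax as dfx


def wrap_solver(solver_class, update_aux):
    class WrappedSolver(solver_class):
        def init(self, terms, t0, t1, y0, args):
            _, aux0 = args  # args is assumed to be (true args, initial aux)
            return (super().init(terms, t0, t1, y0, args), aux0)

        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            inner_state, aux = solver_state
            true_args, _ = args
            # Substitute the evolving aux into args, so the vector field
            # called during this step sees the current value.
            y1, y_error, dense_info, inner_state, result = super().step(
                terms, t0, t1, y0, (true_args, aux), inner_state, made_jump
            )
            new_aux = update_aux(t0, y1, aux)  # illustrative update rule
            return y1, y_error, dense_info, (inner_state, new_aux), result

    return WrappedSolver


def vector_field(t, y, args):
    true_args, aux = args  # unpack the smuggled auxiliary state
    return aux - y         # e.g. drive y towards the current target
```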
I like this solution actually; it's not that hacky, and it's cleaner than artificially adding a zero vector field for the aux variable. Thanks for commenting, everyone!

Edit: I had to manually pass `args=(None, init_aux)` into `diffeqsolve` to get the code to work.
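With the auxiliary state packed into `args` as above, the top-level call would then look something like this (a sketch; `init_aux` here is the initial value of the auxiliary state):

```python
solution = dfx.diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=0.1,
    y0=y0,
    args=(None, init_aux),  # (true args, initial auxiliary state)
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)
```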
This is quite the outpouring of support for this!

So I'm pretty sure that implementing this at the solver level -- as a custom solver -- is the correct thing to do. Anything not part of the differentiable state can only be made sense of at the step-by-step level, and that is the abstraction that solvers represent. That said, as this issue demonstrates... perhaps we could do something to make this easier to accomplish! (Perhaps without having to grok the whole `AbstractSolver` API.)

As a first pass, here's an (untested) quick mockup of a general approach, allowing arbitrary auxiliary state to be updated on each step:
```python
class RequiresAux(eqx.Module):
    vector_field: Callable

    def __call__(self, t, y, args):
        args, aux_state = args
        return self.vector_field(t, y, args, aux_state=aux_state)


class AuxSolver(AbstractWrappedSolver):  # provides `.solver`
    init_aux: Callable
    step_aux: Callable

    def init(self, terms, t0, t1, y0, args):
        return self.solver.init(terms, t0, t1, y0, args), self.init_aux(t0, t1, y0, args)

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        solver_state, aux_state = solver_state
        y1, y_error, dense_info, new_solver_state, result = self.solver.step(
            terms, t0, t1, y0, (args, aux_state), solver_state, made_jump
        )
        new_aux_state = self.step_aux(terms, t0, t1, y0, args, aux_state, made_jump)
        return y1, y_error, dense_info, (new_solver_state, new_aux_state), result
```
And here's some example usage, based on the discussion above:
```python
class WaypointAux(eqx.Module):
    waypoints: Float[Array, "waypoints ..."]
    waypoint_idx: Int[Array, ""]
    capture_rad: Float[Array, ""]


def init_aux(t0, t1, y0, args):
    return WaypointAux(
        jnp.array([
            [1.0, 1.0],
            [1.0, 2.0],
            [2.0, 2.0],
            [2.0, 1.0],
            [1.0, 1.0],
        ]),
        jnp.array(0),
        jnp.array(0.01),
    )


def update_aux(terms, t0, t1, y0, args, aux_state, made_jump):
    update_idx = jnp.linalg.norm(y0 - aux_state.waypoints[aux_state.waypoint_idx]) < aux_state.capture_rad
    idx = aux_state.waypoint_idx + jnp.where(update_idx, 1, 0)
    return eqx.tree_at(lambda a: a.waypoint_idx, aux_state, idx)


def track_waypoint(t, y, args, *, aux_state):
    return aux_state.waypoints[aux_state.waypoint_idx] - y


solver = AuxSolver(Dopri5(), init_aux, update_aux)
terms = ODETerm(RequiresAux(track_waypoint))
...
```
Happy to think about standardising on something like the above, and adding it into Diffrax.
@patrick-kidger It seems like standardization of this framework would be useful for a lot of problems. Do you think the `AuxSolver` template you've outlined above is enough as an initial version? Happy to contribute if it requires more detail!
@patrick-kidger thanks for these thoughts. I agree that streamlining this method would be useful. I'm also happy to help contribute or test out a potential solution.
After some more thought, I don't know if introducing an `aux_state` keyword would be a good approach overall. It might be simpler to just use the `args = (args, aux_state)` method or similar, and let the vector field function handle its `args` input accordingly.

The only potential complication is that Diffrax takes gradients with respect to `args` by default, and it's not clear whether or not the `aux_state` should be included in those gradients by default. But that could be left to the user to handle.
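If left to the user, one option would presumably be to block gradients through the auxiliary part of `args` inside the vector field. A sketch (the dynamics shown are illustrative):

```python
import jax


def vector_field(t, y, args):
    true_args, aux_state = args
    aux_state = jax.lax.stop_gradient(aux_state)  # exclude aux from gradients
    return true_args * (aux_state - y)            # illustrative dynamics
```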
Great! I'm glad this seems interesting.

Probably the main thing I'm bearing in mind for the design here is that we will probably -- at some point -- want to be able to pass other information to the vector field as well. For example, for delay differential equations we will want to pass in some `history`:

```python
def vector_field(t, y, args, *, history):
    ...
```

so really I'd like a way to pass in arbitrary extra information from the solver!

Probably that means something like the following design:
```python
# This is analogous to the `RequiresAux` above.
class VectorField(eqx.Module):
    vector_field: Callable

    def __call__(self, t, y, args):
        args, kwargs = args
        return self.vector_field(t, y, args, **kwargs)


# Now have every term wrap their vector field on instantiation.
class ODETerm(AbstractTerm):
    vector_field: Callable = eqx.field(converter=VectorField)
```
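Under this (hypothetical) design, usage might look like the following sketch, with extra information arriving at the user's function as keyword arguments:

```python
def vector_field(t, y, args, *, aux_state):
    return aux_state - y  # extra info arrives as a keyword argument


term = ODETerm(vector_field)  # the converter wraps this into a VectorField

# ...and internally the solver would call the wrapped function as:
# term.vector_field(t, y, (args, {"aux_state": aux_state}))
```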
> It might be simpler to just use the `args = (args, aux_state)` method

Just to address this explicitly -- I think this approach is probably reasonable for `aux_state` specifically, but when (as above) we may want to generalise to other kinds of extra information, this might get too complicated to be reasonable.
> The only potential complication is that Diffrax takes gradients with respect to args by default, and it's not clear whether or not the aux_state should be included in those gradients by default. But that could be left to the user to handle.

I think we should expect to take gradients when using the default adjoint method. (Which I think will just work automatically; no special work required.)

You're right, though -- for `BacksolveAdjoint` we should add a check for this case and throw an error.
All in all I'd be happy to take a PR from anyone for a first crack at this!
@patrick-kidger With some testing, I'm unable to get your example of inheriting from `AbstractWrappedSolver` working; I can't get around some module-initialization bugs. However, the `wrap_solver` method originally mentioned (which references `super()` instead of `self.solver`) seems to work well, though I don't think it's very clean. Any thoughts on this?

Also, it would be nice to have a way to save the auxiliary state over time. I'm not too sure how to approach this. One thought is to preallocate an array for the auxiliary state trajectory, but then it's unclear how to fill the array each step, evenly spaced, unless you know exactly how many solver steps will occur.

Edit: The only way I'm able to save the trajectory of the auxiliary state is by making the vector field return `(dxdt, 0)` and making the auxiliary state part of the main trajectory. Obviously, this was the original solution and is not that clean.
> I'm unable to get around some module initialization bugs.

Do you have a MWE?

> Also, it would be nice to have a way to save the auxiliary state in time.

Put the evolving value in the `solver_state`! (Note that you'll likely get spurious copies if you also attempt to backprop through this state -- this is a tricky JAX limitation to work around.)
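A sketch of that suggestion, assuming a fixed buffer length known up front (`wrap_solver`, `step_aux`, and `max_saves` are illustrative names, not Diffrax API):

```python
import jax.numpy as jnp

max_saves = 4096  # must bound the number of solver steps


def wrap_solver(solver_class, step_aux):
    class BufferedSolver(solver_class):
        def init(self, terms, t0, t1, y0, args):
            buffer = jnp.zeros(max_saves)  # preallocated aux history (scalar aux)
            idx = jnp.array(0)
            return (super().init(terms, t0, t1, y0, args), buffer, idx)

        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            inner_state, buffer, idx = solver_state
            y1, y_error, dense_info, inner_state, result = super().step(
                terms, t0, t1, y0, args, inner_state, made_jump
            )
            # Record this step's auxiliary value into the buffer.
            buffer = buffer.at[idx].set(step_aux(t0, y1, args))
            return y1, y_error, dense_info, (inner_state, buffer, idx + 1), result

    return BufferedSolver
```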
It would be nice if the times at which auxiliary state(s) are saved could be specified by `SaveAt`, similar to how the ODE state is saved. Currently the docs state that `SaveAt` can only return the solver state at the final time `t1` (using the `solver_state` argument).
@patrick-kidger Yeah, this is what I meant by saving the solver state in time. Here is a MWE that does the `d[aux]/dt = 0` trick, which makes the auxiliary state amenable to the `SubSaveAt` API:
```python
import diffrax as dfx
import jax.numpy as jnp
from jax import random
from typing import Callable


def auxiliary_solver_factory(solver_class: dfx.AbstractSolver, step_aux: Callable) -> dfx.AbstractSolver:
    class SolverWithState(solver_class):
        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            (y, _), y_error, dense_info, new_solver_state, result = super().step(
                terms, t0, t1, y0, args, solver_state, made_jump
            )
            new_aux_state = step_aux(terms, t0, t1, y0, args, made_jump)
            return (y, new_aux_state), y_error, dense_info, new_solver_state, result

    return SolverWithState


def wrapped_rhs(rhs: Callable):
    def _rhs(t, y, args):
        x, _ = y
        dx = rhs(t, x, args)
        # (dx/dt, d[aux]/dt)
        return (dx, 0)

    return _rhs


def step_aux(terms, t0, t1, y0, args, made_jump):
    A, B, _ = args
    x0, aux = y0
    return A @ x0 + B @ aux


def rhs(t, y, args):
    _, _, C = args
    return jnp.sin(C @ y)


key = random.PRNGKey(0)
mat_keys = random.split(key, 3)
A, B, C = [random.normal(k, shape) for k, shape in zip(mat_keys, [(4, 2), (4, 4), (2, 2)])]

y0 = jnp.array([0.0, 0.0])
t0 = 0.0
t1 = 300.0

solver = auxiliary_solver_factory(dfx.Dopri5, step_aux)
stepsize_controller = dfx.PIDController(1e-6, 1e-6)
_rhs = wrapped_rhs(rhs)
state_traj = dfx.SubSaveAt(ts=jnp.linspace(t0, t1, 100), fn=lambda t, y, _: y[0])
aux_traj = dfx.SubSaveAt(ts=jnp.linspace(t0, t1, 40), fn=lambda t, y, _: y[1])
saveat = dfx.SaveAt(subs=[state_traj, aux_traj])
sol = dfx.diffeqsolve(
    dfx.ODETerm(_rhs), solver(), t0=t0, t1=t1, dt0=0.1, y0=(y0, jnp.zeros(4)),
    args=(A, B, C), saveat=saveat, stepsize_controller=stepsize_controller,
)
state_ts, aux_ts = sol.ts
state_ys, aux_ys = sol.ys
print(state_ts.shape, aux_ts.shape)
print(state_ys.shape, aux_ys.shape)
```
I'm just wondering if there's a cleaner way to do this without adding that dummy vector field for the auxiliary state.

@jusevitch I tried to make this work for your example, but it probably requires some more pytree grokking. Unfortunately, this MWE avoids shipping the auxiliary state around with `args`, and I understand how that's probably necessary for your example.
@jbial Thanks for passing along that MWE. I think there are two fundamental issues here that prevent saving `aux_state` using the current `SaveAt` interface:

1. `SubSaveAt` only takes `t, y, args` as input parameters.
2. The `loop` function inherently assumes that the `args` variable is static for the entire solve process.

If there's a clean place in `loop` to add the `solver_state` into the `args` parameter, that might work while keeping the current `SaveAt` interface. Perhaps we could create some sort of function or wrapper to handle this.

On the other hand, if there were a `SubSaveAt` class that took additional parameters in its `fn` callable, e.g. `fn(t, y, args, **kwargs)`, we could pass in the solver state and allow the user to parse that when saving.
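A sketch of what that second proposal could look like (a hypothetical API, not current Diffrax):

```python
def save_fn(t, y, args, *, solver_state):
    _, aux = solver_state  # pull the auxiliary state out of the solver state
    return aux


# This would require SubSaveAt to pass solver_state through to fn:
subsave = dfx.SubSaveAt(ts=jnp.linspace(t0, t1, 100), fn=save_fn)
```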
So `solver_state` deliberately doesn't interact with `SaveAt`: the former evolves on each step, while the latter handles any combination of times/steps/etc. It wouldn't be well-defined to ask for the state at times not aligned with the step locations!

Rather, I am suggesting that you should allocate a buffer at the start of the loop, put it in `solver_state`, and then `.at[].set()` into it.
> It wouldn't be well-defined to ask for the state at times not aligned with the step locations!

@patrick-kidger Fair point -- I apologize for the lack of clarity. To be more specific, I believe my use case (and possibly @jbial's) could be considered an example of simulating outputs of basic hybrid automata (see also this paper). Hybrid systems have both states with continuous dynamics $x$ and discrete modes $q$, and the continuous states and discrete modes evolve as functions of each other.

An output $y = h(x,q)$ that is a function of both the continuous state and discrete mode is well-defined at any time at which $x(t)$ is defined. The problem here is that the only way Diffrax can simulate and track $q$ is through the `solver_state`, but this can't currently be passed into `SaveAt`.

To give a specific example: for systems $\dot{x} = f(x, u(x,q), q)$ with control inputs $u$, it's fairly common to want to save the time history of control inputs $y(t) = u(x,q)$ alongside the state. In Diffrax, the only way I know of to do this is to define a `SubSaveAt` whose `fn` recomputes $u(x,q)$ a second time. However, for the waypoint example I gave above, this would require passing in the mode (i.e. the current waypoint) $q$, which can only be found in `solver_state`.

(Technically, I guess $u(x,q)$ or any other such $y = h(x,q)$ could be reconstructed after the solve from `solver_state`, but it would be nice if this could be handled without that manual recomputation.)
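For reference, the recomputation workaround mentioned above might look like this sketch. Note it assumes (unlike the waypoint example) that the mode is carried in `args`; `u` is an illustrative controller:

```python
import diffrax as dfx
import jax.numpy as jnp


def u(x, q):
    return q - x  # illustrative feedback controller


def save_control(t, y, args):
    q = args        # works only when the mode is available via args,
    return u(y, q)  # not when it lives in solver_state


saveat = dfx.SaveAt(subs=dfx.SubSaveAt(ts=jnp.linspace(0.0, 5.0, 50), fn=save_control))
```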
Having said all that, I realize this scenario may be outside the scope of Diffrax. I'm a fan of the Equinox + Diffrax ecosystem though 🙂 and it would be beneficial to find some way to handle these cases. @jbial, @etienney, @lockwo I'm curious to know if this matches your needs; if not, I'm happy to jump off this issue.
> Rather, I am suggesting that you should allocate a buffer at the start of the loop, put it in `solver_state`, and then `.at[].set()` into it.

@patrick-kidger I thought to do this, but an issue arises when you want to save the solver state at, for example, evenly spaced intervals throughout the solve. The only easy solution I see requires always setting `max_steps`, preallocating that much buffer memory, and filling it. However, in my use cases I'd much prefer to leave `max_steps` unspecified so that the solver can meet the tolerances I'm demanding -- this makes filling a buffer in a structured manner unclear.
@jusevitch What you described does align with my needs and seems generally useful. At the very least, we could standardize a wrapped solver for this and add a basic example to the docs under the fixed-`max_steps` assumption I mentioned? @patrick-kidger WDYT?
So my understanding is that these kinds of hybrid systems usually have a particular time at which the discrete mode switches? This means that a solver-based step-by-step approach probably isn't what you want, as the timesteps need not be aligned with the exact time at which the jump occurs!

For this reason, the pattern I've been thinking of for such systems is to simulate them using an event that terminates the solve once the discrete condition has occurred. Then wrap this overall solve in a `lax.while_loop` (or an `equinox.internal.while_loop` if you need backpropagation), so that you start a new `diffeqsolve` for each piece. This then allows you to perform arbitrary logic to update your discrete state in between each continuous solve.

Would this approach allow you to solve your systems?
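A sketch of this pattern (untested; the event API shown is from recent Diffrax versions, and the waypoint setup is illustrative):

```python
import diffrax as dfx
import jax
import jax.numpy as jnp

waypoints = jnp.array([[1.0, 1.0], [1.0, 2.0], [2.0, 2.0], [2.0, 1.0]])
t0, t1 = 0.0, 20.0
y0 = jnp.zeros(2)

term = dfx.ODETerm(lambda t, y, args: waypoints[args] - y)  # track waypoint `args`
solver = dfx.Tsit5()


def reached(t, y, args, **kwargs):
    # Discrete condition: terminate this piece once close to the waypoint.
    return jnp.linalg.norm(y - waypoints[args]) < 1e-2


def solve_piece(carry):
    t_start, y_start, idx = carry
    sol = dfx.diffeqsolve(
        term, solver, t0=t_start, t1=t1, dt0=0.1, y0=y_start, args=idx,
        event=dfx.Event(cond_fn=reached),
    )
    # Arbitrary logic between pieces: here, advance the discrete mode.
    new_idx = jnp.minimum(idx + 1, len(waypoints) - 1)
    return sol.ts[-1], sol.ys[-1], new_idx


# Dtypes in the carry must match across iterations.
carry = (jnp.asarray(t0, dtype=jnp.float32), y0, jnp.asarray(0))
t_end, y_end, idx = jax.lax.while_loop(lambda c: c[0] < t1, solve_piece, carry)
```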
As for leaving `max_steps` unspecified -- unfortunately this is incompatible, at the JAX level, with any approach in which you need to save an unknown number of values. JAX works exclusively with arrays whose size is known at compile time.

For example, try doing a solve with an adaptive step size controller and `SaveAt(steps=True)`. In this case the returned outputs will actually be of size `max_steps`: the first values will be filled in according to the steps the solver actually took, and the rest will be padded with `jnp.inf`.
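A sketch illustrating that padding behaviour:

```python
import diffrax as dfx
import jax.numpy as jnp

sol = dfx.diffeqsolve(
    dfx.ODETerm(lambda t, y, args: -y),
    dfx.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array([1.0]),
    saveat=dfx.SaveAt(steps=True),
    stepsize_controller=dfx.PIDController(rtol=1e-6, atol=1e-6),
    max_steps=4096,
)
print(sol.ts.shape)             # (4096,): always of size max_steps
print(jnp.isinf(sol.ts).sum())  # unused entries are padded with jnp.inf
```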
> So my understanding is that these kinds of hybrid systems usually have a particular time at which the discrete mode switches? This means that a solver-based step-by-step approach probably isn't what you want, as the timesteps need not be aligned with the exact time at which the jump occurs!
>
> For this reason, the pattern I've been thinking of for such systems is to simulate them using an event that terminates the solve once the discrete condition has occurred. Then wrap this overall solve in a `lax.while_loop` (or an `equinox.internal.while_loop` if you need backpropagation), so that you start a new `diffeqsolve` for each piece. This then allows you to perform arbitrary logic to update your discrete state in between each continuous solve.
>
> Would this approach allow you to solve your systems?
This is indeed exactly what I need, and what I actually did in my special case. (Though in my case I need to go out of Diffrax to change the size of the objects according to the simulation's needs (JAX would not allow for it), and enter it again.)
> Having said all that, I realize this scenario may be outside the scope of Diffrax. I'm a fan of the Equinox + Diffrax ecosystem though 🙂 and it would be beneficial to find some way to handle these cases. @jbial, @etienney, @lockwo I'm curious to know if this matches your needs; if not, I'm happy to jump off this issue.

I guess such a framework would be useful in my case. Though it's not exactly what I do, I could leverage it usefully.
> For this reason, the pattern I've been thinking of for such systems is to simulate them using an event that terminates the solve once the discrete condition has occurred. Then wrap this overall solve in a `lax.while_loop` (or an `equinox.internal.while_loop` if you need backpropagation), so that you start a new `diffeqsolve` for each piece. This then allows you to perform arbitrary logic to update your discrete state in between each continuous solve.
>
> Would this approach allow you to solve your systems?

I'll give this method a try. Thanks for sharing these suggestions.
Given any vector field `f(t, y, args)`, I want to keep track of an auxiliary variable `w` that depends on and evolves with `y`, but non-differentiably, in the sense that there is no vector field for `w`: it's just a state that gets updated on every `step` call. For example, at every step of the integrator I would ideally like to be able to do something like the following pseudocode:

I'm currently trying to wrap an arbitrary solver, with some success, like so:

But this seems very janky, and I'm wondering if there's a cleaner way without exposing the `step` method. It's as if I need a custom term, where one term is an ODETerm with control `dt` and the other is like a "StateTerm", also with control `dt`, but which doesn't get integrated.
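As a hedged reconstruction, the kind of per-step update being asked for above is roughly (illustrative names, not actual code from the issue):

```python
w = w0
for (t0, t1) in solver_steps:
    y = ode_step(f, t0, t1, y)  # the usual differentiable ODE update
    w = update(t0, y, w)        # aux update: no vector field, just a map
```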