adam-hartshorne closed this issue 1 year ago.
Hi! Thanks for dropping by :)
What exactly is your question? If it is about using ProbDiffEq for NODEs:
Reverse-mode differentiation of simulate_terminal_values (akin to, e.g., diffrax.BacksolveAdjoint) is a work in progress. For the time being, we must use fixed time steps instead.
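To illustrate why fixed steps matter, here is a minimal sketch in plain JAX. It does not show ProbDiffEq's internals. A toy explicit-Euler solver (a swapped-in, non-probabilistic stand-in) is written twice: once with `jax.lax.scan` over a fixed number of steps, and once with `jax.lax.while_loop`, which is the kind of loop an adaptive step-size solver needs. The names `vf`, `terminal_value_fixed`, `terminal_value_while`, and `while_loop_grad_fails` are hypothetical. Reverse-mode `jax.grad` goes through the scan but raises a `ValueError` for the while loop.

```python
import jax
import jax.numpy as jnp


def vf(y, theta):
    # Toy linear vector field dy/dt = theta * y (hypothetical stand-in for a NODE).
    return theta * y


def terminal_value_fixed(theta, y0, dt=0.01, num_steps=100):
    # Fixed number of explicit Euler steps via lax.scan: reverse-mode differentiable.
    def step(y, _):
        return y + dt * vf(y, theta), None

    y_T, _ = jax.lax.scan(step, y0, xs=None, length=num_steps)
    return y_T


def terminal_value_while(theta, y0, dt=0.01, t1=1.0):
    # Data-dependent loop via lax.while_loop, as an adaptive solver would use.
    def cond(carry):
        t, _ = carry
        return t < t1

    def body(carry):
        t, y = carry
        return t + dt, y + dt * vf(y, theta)

    _, y_T = jax.lax.while_loop(cond, body, (0.0, y0))
    return y_T


y0 = jnp.array(1.0)
grad_fixed = jax.grad(terminal_value_fixed)(0.5, y0)  # works


def while_loop_grad_fails():
    try:
        jax.grad(terminal_value_while)(0.5, y0)
    except ValueError:
        # JAX rejects reverse-mode differentiation through lax.while_loop.
        return True
    return False
```

Forward-mode differentiation (`jax.jvp`, `jax.jacfwd`) does work through `lax.while_loop`; it is only the reverse-mode transpose that is unsupported.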
For example, here is an example notebook that does something similar to your example, but with two differences, one of which is that it replaces the loss jnp.mean(jnp.square(ekf0sol.u - yt)) by a probabilistic equivalent (which takes the statistical nature of the probabilistic solution into account, something a non-probabilistic solver cannot do). Does this help? What do you think?
Thanks for the quick response.
I did see that example, but as I understand it, it requires you to have data along the path. If you have a dataset where you only have a set of initial locations and terminal locations, I couldn't see how to use that example.
I am interested in a probabilistic solution to such a setup.
Ah, I see!
Essentially, one would only replace ekf0sol.u with ekf0sol.u[-1] in your example.
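To make the one-line change concrete, here is a minimal sketch, assuming the solution values are stored as an array of shape (num_time_points, d), as `ekf0sol.u` is in the example. The arrays and the function names `trajectory_loss` and `terminal_loss` are hypothetical placeholders, not ProbDiffEq output.

```python
import jax.numpy as jnp


def trajectory_loss(u, y_obs):
    # u: (num_time_points, d) solution values; y_obs: data at every time point.
    return jnp.mean(jnp.square(u - y_obs))


def terminal_loss(u, y_T):
    # Only the last time point is observed, so compare u[-1] with the terminal data.
    return jnp.mean(jnp.square(u[-1] - y_T))


# Hypothetical (11, 2) trajectory running from (0, 0) to (1, 1).
u = jnp.linspace(0.0, 1.0, 11)[:, None] * jnp.ones((1, 2))
y_T = jnp.array([1.0, 1.0])  # hypothetical terminal observation
```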
To adapt the NODE example notebook (including the loss function I mentioned above), replace the loss_fn with something like the logposterior_fn from the sampling example (i.e. the BlackJAX example, which deals with terminal-value data):
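import jax
from probdiffeq import solution_routines

# `vf` (the ODE vector field) is defined elsewhere in the notebook.


@jax.jit
def logposterior_fn(theta, *, data, ts, solver, obs_stdev=0.1):
    # Solve on a fixed grid and keep only the terminal value.
    y_T = solve_fixed(theta, ts=ts, solver=solver)
    # Condition the terminal-value posterior on the noisy observation of the data.
    marginals, _ = y_T.posterior.condition_on_qoi_observation(
        data, observation_std=obs_stdev
    )
    return marginals.logpdf(data)  # removed prior PDF from notebook


# Fixed steps for reverse-mode differentiability:
@jax.jit
def solve_fixed(theta, *, ts, solver):
    sol = solution_routines.solve_fixed_grid(
        vf, initial_values=(theta,), grid=ts, solver=solver
    )
    return sol[-1]  # terminal value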
The sampling example might be useful to look at if you deal with terminal-value data. That said, the differences between the two kinds of data are comparatively small.
Does this help?
Great, thank you very much for your help. My misunderstanding was that the fixed grid methods were for use exclusively on datasets in which you have trajectory data, not just the terminal values.
FYI, Diffrax makes use of special loops (defined in Equinox), including ones that efficiently handle adaptive solves and allow for reverse-mode differentiation. I would have thought you could build on those for your use case.
I will give it a try and see how I get on with my actual use case.
Awesome, glad to hear that! If you run into more problems/misunderstandings, don't hesitate to ask more questions.
Yes, I am aware of the bounded while loops and see how such functionality could be helpful. I noted a potential path forward in #453; if you're keen on this extension, let's continue the discussion there :)
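As a rough illustration of the bounded-while-loop idea, here is a minimal sketch, assuming a plain masked `jax.lax.scan` over `max_steps` iterations. This is not Equinox's implementation, which is considerably more careful about memory and compute. `bounded_while_loop` and `terminal_value` are hypothetical names. The loop keeps running until `cond_fun` is False and then freezes the carry, so it behaves like a while loop with an upper bound on the step count, yet stays reverse-mode differentiable.

```python
import jax
import jax.numpy as jnp


def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    # Run body_fun at most max_steps times; once cond_fun is False, the carry is frozen.
    # Built on lax.scan, so reverse-mode differentiation works.
    def step(carry, _):
        keep_going = cond_fun(carry)
        new_carry = body_fun(carry)
        carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_carry, carry
        )
        return carry, None

    final, _ = jax.lax.scan(step, init_val, xs=None, length=max_steps)
    return final


def terminal_value(theta, dt=0.01, t1=1.0):
    # Explicit Euler for dy/dt = theta * y from t=0 to t1, stopping via cond (not a fixed count).
    def cond(carry):
        t, _ = carry
        return t < t1 - 0.5 * dt  # half-step margin guards against float round-off

    def body(carry):
        t, y = carry
        return t + dt, y + dt * theta * y

    init = (jnp.asarray(0.0), jnp.asarray(1.0))
    _, y_T = bounded_while_loop(cond, body, init, max_steps=200)
    return y_T
```

One design caveat: `body_fun` is still evaluated on every iteration after the loop has logically finished, which wastes compute and, through `jnp.where`, can leak NaN gradients if the masked-out branch produces NaNs.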
Feel free to close this issue if your original question is resolved; if not, let me know.
Sorry if this is a stupid question: consider a problem with n points in m-dimensional space (e.g. 10 points in 2D) whose initial and final positions we know; let's call them X and Y.
After looking at this example https://pnkraemer.github.io/probdiffeq/benchmarks/pleiades/external/ , am I right in thinking that initial_values is a flattened version of X, i.e. a tuple whose first element is an array of shape (nm,), e.g. (20,)? And that we then reshape it back to (10, 2) inside the function f that implements the vector field?
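A minimal sketch of that flattening, assuming the vector field receives the flat state `y` and a keyword argument `t`. The exact signature ProbDiffEq expects depends on its version, so treat `vf(y, *, t)` as an assumption. `velocity` is a hypothetical stand-in for the learned per-point dynamics (here, a rotation about the origin).

```python
import jax.numpy as jnp

N_POINTS, DIM = 10, 2  # n points in m dimensions (sizes from the question)


def velocity(points):
    # Hypothetical per-point dynamics: rotate each point about the origin.
    return jnp.stack([-points[:, 1], points[:, 0]], axis=-1)


def vf(y, *, t):
    # The solver sees a flat state of shape (n*m,); reshape to (n, m) ...
    points = y.reshape(N_POINTS, DIM)
    # ... apply the per-point dynamics, then flatten back to (n*m,).
    return velocity(points).reshape(-1)


X = jnp.arange(N_POINTS * DIM, dtype=float).reshape(N_POINTS, DIM)
initial_values = (X.reshape(-1),)  # tuple whose first element has shape (20,)
```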
And regarding
marginals, _ = y_T.posterior.condition_on_qoi_observation(data, observation_std=obs_stdev)
return marginals.logpdf(data) # removed prior PDF from notebook
does data here refer to the flattened version of Y, e.g. of shape (20,)?
Are you referring to matrix-valued differential equations? I.e. d/dt M(t) = f(M(t)), where M(t) is a matrix, not a vector?
In this case, I'd say you're right; rewriting this equation as a vector-valued (i.e. flattened) version seems to make sense. Instead of a (10,2)-shaped equation, one would solve a (20,)-shaped equation, and all derived quantities (e.g. data in your example) would be reshaped accordingly.
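A minimal sketch of the flattening on the data side, assuming an isotropic Gaussian observation model with standard deviation `obs_stdev` and a point estimate `y_T_mean` of the flattened terminal value. This is a simplified stand-in: unlike `condition_on_qoi_observation`, it ignores the solver's own uncertainty about the terminal value. `terminal_loglik` is a hypothetical name.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def terminal_loglik(y_T_mean, Y, obs_stdev=0.1):
    # Flatten the (n, m) target exactly like the initial values: shape (n*m,).
    data = Y.reshape(-1)
    # Isotropic Gaussian observation model: data ~ N(y_T_mean, obs_stdev**2 * I).
    return jnp.sum(norm.logpdf(data, loc=y_T_mean, scale=obs_stdev))
```

The only requirement is that `Y` and the initial values are flattened in the same (row-major) order, so that entry i of `data` refers to the same point and coordinate as entry i of the solution.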
Does that help?
I am trying to learn a vector flow field, defined by a NODE, that models the advection of a set of points whose start and end locations in 2D we know.
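For this setting, here is a minimal sketch of the whole pipeline in plain JAX. It rests on three assumptions. First, a hypothetical tiny tanh MLP (`init_mlp`, `mlp`) maps a 2D position to a 2D velocity. Second, the state is the flattened (n*2,) vector discussed above. Third, a fixed-step explicit Euler loop via `jax.lax.scan` serves as a non-probabilistic stand-in for a fixed-grid ProbDiffEq solve. The loss only compares terminal positions with Y, and `jax.grad` differentiates it with respect to the network parameters.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes=(2, 16, 2)):
    # Hypothetical tiny MLP: 2D position -> 2D velocity.
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, (d_in, d_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (d_in, d_out)) / jnp.sqrt(d_in)
        params.append((w, jnp.zeros(d_out)))
    return params


def mlp(params, x):
    # Applied row-wise: x has shape (n, 2).
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def vf_flat(params, y):
    # Flat state (n*2,) -> (n, 2) points -> per-point velocities -> flat again.
    return mlp(params, y.reshape(-1, 2)).reshape(-1)


def advect(params, X, dt=0.05, num_steps=20):
    # Fixed-step Euler via lax.scan, so reverse-mode differentiation works.
    def step(y, _):
        return y + dt * vf_flat(params, y), None

    y_T, _ = jax.lax.scan(step, X.reshape(-1), xs=None, length=num_steps)
    return y_T.reshape(X.shape)


def loss(params, X, Y):
    # Terminal-value data only: compare advected start points with the end points.
    return jnp.mean(jnp.square(advect(params, X) - Y))


params = init_mlp(jax.random.PRNGKey(0))
X = jax.random.normal(jax.random.PRNGKey(1), (10, 2))
Y = X + 1.0  # hypothetical targets: every point shifted by (1, 1)
grads = jax.grad(loss)(params, X, Y)
```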
I see. I think that, for the moment, "flattening the equation" is the best way forward. I noted a potential extension of ProbDiffEq to matrix-valued equations in #457.
Since we're kind of drifting away from the original question (about simulate_terminal_values), I will close this issue for now. Please reopen if the original question has not been answered yet!
Let's move the discussion about matrix-valued equations to #457 :) And please feel invited to open more issues if you run into more problems!
I apologise in advance if I have misunderstood something obvious; I haven't used probabilistic ODE solvers before and am coming from Diffrax.
If one wants to use simulate_terminal_values with a NODE, this isn't going to be possible, because _advance_ivp_solution_adaptively uses lax.while_loop, which doesn't support reverse-mode differentiation (see, e.g., the silly minimal example shown below).