pnkraemer / probdiffeq

Probabilistic solvers for differential equations in JAX. Adaptive ODE solvers with calibration, state-space model factorisations, and custom information operators. Compatible with the broader JAX scientific computing ecosystem.
https://pnkraemer.github.io/probdiffeq/
MIT License

Question about optimisation using simulate_terminal_values #452

Closed by adam-hartshorne 1 year ago

adam-hartshorne commented 1 year ago

I apologise in advance in case I have misunderstood something obvious; I haven't used probabilistic ODE solvers before and am coming from Diffrax.

If one wants to use simulate_terminal_values with a NODE, this doesn't seem possible: the _advance_ivp_solution_adaptively method uses lax.while_loop, and lax.while_loop doesn't support reverse-mode differentiation. The (silly) minimal example below illustrates this.

import jax
import jax.numpy as jnp
import optax
from diffeqzoo import backend, ivps
from probdiffeq import solution_routines, solvers
from probdiffeq.implementations import recipes
from probdiffeq.strategies import filters

backend.select("jax")
f, u0, (t0, t1), f_args = ivps.neural_ode_mlp(layer_sizes=(2, 20, 1))
yt = 5.0  # target terminal value

@jax.jit
def vf(y, *, t, p):
    return f(y, t, *p)

strategy = filters.Filter(recipes.IsoTS0.from_params(num_derivatives=4))
solver = solvers.DynamicSolver(strategy)

optim = optax.adam(learning_rate=1e-2)
p = f_args
state = optim.init(p)
num_iterations = 100  # illustrative value

def loss_fn(p):
    # Adaptive-step solve that returns only the terminal value.
    ekf0sol = solution_routines.simulate_terminal_values(
        vf,
        initial_values=(u0,),
        t0=0.0,
        t1=1.0,
        solver=solver,
        parameters=p,
    )
    return jnp.mean(jnp.square(ekf0sol.u - yt))

@jax.jit
def update(params, opt_state):
    # Fails here: reverse-mode differentiation through lax.while_loop is unsupported.
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optim.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

for i in range(num_iterations):
    p, state = update(p, state)
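The limitation can be reproduced without probdiffeq. The sketch below (a hypothetical toy problem, du/dt = -theta * u solved with explicit Euler, not anything from the library) runs the same loop once via lax.while_loop and once via lax.scan with a fixed number of steps; only the scan version can be differentiated in reverse mode.

```python
import jax
import jax.numpy as jnp
from jax import lax


def decay_while(theta):
    """Explicit Euler for du/dt = -theta * u, looping with lax.while_loop."""
    dt, num_steps = 0.1, 10

    def cond(state):
        i, _ = state
        return i < num_steps

    def body(state):
        i, u = state
        return i + 1, u + dt * (-theta * u)

    _, u_final = lax.while_loop(cond, body, (0, jnp.ones(())))
    return u_final


def decay_scan(theta):
    """Same integrator, but with a fixed number of steps via lax.scan."""
    dt, num_steps = 0.1, 10

    def step(u, _):
        return u + dt * (-theta * u), None

    u_final, _ = lax.scan(step, jnp.ones(()), xs=None, length=num_steps)
    return u_final


# Reverse mode through the scan-based loop works.
grad_scan = jax.grad(decay_scan)(0.5)

# Reverse mode through lax.while_loop raises a ValueError.
try:
    jax.grad(decay_while)(0.5)
    while_loop_failed = False
except ValueError:
    while_loop_failed = True
```

Forward-mode (jax.jvp) still works through lax.while_loop; it is only the transpose step needed by jax.grad that is missing.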
pnkraemer commented 1 year ago

Hi! Thanks for dropping by :)

What exactly is your question? If it is about using ProbDiffEq for NODEs:

Reverse-mode differentiation of simulate_terminal_values (akin to e.g. diffrax.BacksolveAdjoint) is a work in progress. For the time being, we must use fixed time-steps instead.

For example, here is an example notebook that does something similar to your example but with two differences:

Does this help? What do you think?

adam-hartshorne commented 1 year ago

Thanks for the quick response.

I did see that example, but as I understand it, it requires data along the path. If you have a dataset with only a set of input locations and terminal locations, I couldn't see how to use that example.

I am interested in a probabilistic solution to such a setup.

pnkraemer commented 1 year ago

Ah, I see!

Essentially, one would only replace ekf0sol.u with ekf0sol.u[-1] in your example.

To adapt the NODE example notebook (including the loss function I mentioned above), replace the loss_fn with something like the logposterior_fn from the sampling example (i.e. the BlackJAX example, which deals with terminal-value data):

@jax.jit
def logposterior_fn(theta, *, data, ts, solver, obs_stdev=0.1):
    y_T = solve_fixed(theta, ts=ts, solver=solver)
    marginals, _ = y_T.posterior.condition_on_qoi_observation(
        data, observation_std=obs_stdev
    )
    return marginals.logpdf(data)  # removed prior PDF from notebook

# Fixed steps for reverse-mode differentiability:

@jax.jit
def solve_fixed(theta, *, ts, solver):
    sol = solution_routines.solve_fixed_grid(
        vf, initial_values=(theta,), grid=ts, solver=solver
    )
    return sol[-1]

The sampling example might be useful to look at if you deal with terminal-value data. But in general, the differences between the kinds of data are comparatively small.
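To make the "fixed grid, keep only the last point" idea concrete, here is a minimal plain-JAX sketch. It is not probdiffeq: it swaps in a deterministic explicit-Euler solver on a fixed grid, a hypothetical linear vector field in place of the NODE, made-up data, and an independent Gaussian log-likelihood per component in place of condition_on_qoi_observation. The point is only that a fixed-length lax.scan plus indexing with [-1] gives a terminal-value loss that jax.value_and_grad can handle.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def vector_field(u, theta):
    # Hypothetical stand-in for a neural vector field: linear decay.
    return -theta * u


def solve_fixed_grid(theta, u0, ts):
    """Explicit Euler on a fixed grid; returns the whole trajectory."""

    def step(u, dt):
        u_next = u + dt * vector_field(u, theta)
        return u_next, u_next

    _, us = jax.lax.scan(step, u0, jnp.diff(ts))
    return jnp.concatenate([u0[None], us])


def log_likelihood(theta, u0, ts, data, obs_stdev=0.1):
    """Gaussian log-likelihood of terminal-value data only."""
    u_T = solve_fixed_grid(theta, u0, ts)[-1]  # keep only the terminal value
    return jnp.sum(norm.logpdf(data, loc=u_T, scale=obs_stdev))


ts = jnp.linspace(0.0, 1.0, num=101)
u0 = jnp.array([1.0, 2.0])
data = jnp.array([0.6, 1.2])  # made-up terminal observations

value, dlogp_dtheta = jax.value_and_grad(log_likelihood)(0.5, u0, ts, data)
```

Intermediate grid points only serve the solver; the data enters the loss solely through the terminal value.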

Does this help?

adam-hartshorne commented 1 year ago

Great, thank you very much for your help. My misunderstanding was that the fixed grid methods were for use exclusively on datasets in which you have trajectory data, not just the terminal values.

FYI, Diffrax uses special loops (defined in Equinox), including ones that handle adaptive solves efficiently and allow reverse-mode differentiation. I would have thought you could build on those for your use case.

I will give it a try and see how I get on with my actual use case.

pnkraemer commented 1 year ago

Awesome, glad to hear that! If you run into more problems/misunderstandings, don't hesitate to ask more questions.

> FYI, Diffrax makes use of special loops (which are defined in Equinox) including ones to efficiently handle adaptive solves, and allow for reverse mode differentiation. I would have thought you could probably build off those for your use case.

Yes, I am aware of the bounded while loops and see how such functionality could be helpful. I made a note about a potential path forward in #453; if you're keen on this extension, let's continue discussing there :)
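For readers unfamiliar with bounded while loops, here is a minimal plain-JAX sketch of the basic idea: run a lax.scan for a fixed maximum number of steps and freeze the carry with jnp.where once the loop condition turns False. This is a simplified illustration, not Equinox's or Diffrax's implementation (those are more sophisticated, e.g. regarding memory cost); the function names and the toy Euler integration are hypothetical.

```python
import jax
import jax.numpy as jnp


def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    """Run body_fun while cond_fun holds, for at most max_steps iterations.

    Built on lax.scan, so reverse-mode differentiation works. Iterations
    after the condition turns False leave the carry unchanged.
    """

    def step(carry, _):
        keep_going = cond_fun(carry)
        new_carry = body_fun(carry)
        carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_carry, carry
        )
        return carry, None

    final, _ = jax.lax.scan(step, init_val, xs=None, length=max_steps)
    return final


def integrate(theta, t1=1.0, dt=0.1):
    """Euler for du/dt = -theta * u up to t1; the step count is data-dependent."""

    def cond(state):
        t, _ = state
        return t < t1

    def body(state):
        t, u = state
        h = jnp.minimum(dt, t1 - t)  # do not step past t1
        return t + h, u + h * (-theta * u)

    init = (jnp.zeros(()), jnp.ones(()))
    _, u = bounded_while_loop(cond, body, init, max_steps=50)
    return u


g = jax.grad(integrate)(0.5)
```

The trade-off is that all max_steps iterations always run and their carries are stored for the backward pass. Also, if body_fun produces NaNs or infs in masked-out iterations, they can leak into the gradient through jnp.where.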

Feel free to close this issue if your original question is resolved; if not, let me know.

adam-hartshorne commented 1 year ago

Sorry if this is a stupid question. Consider a problem with n points in m-dimensional space (e.g. 10 points in 2D) whose initial and final positions we know; let's call them X and Y.

After looking at this example (https://pnkraemer.github.io/probdiffeq/benchmarks/pleiades/external/), am I right in thinking that initial_values is a flattened version of X, i.e. a tuple whose first element is an array of shape (nm,), e.g. (20,)? And that we then reshape it back to (10, 2) inside the function f that defines the vector field?

And in terms of

marginals, _ = y_T.posterior.condition_on_qoi_observation(data, observation_std=obs_stdev)
return marginals.logpdf(data)  # removed prior PDF from notebook

does data here refer to the flattened version of Y, e.g. of shape (20,)?

pnkraemer commented 1 year ago

Are you referring to matrix-valued differential equations? I.e. d/dt M(t) = f(M(t)), where M(t) is a matrix, not a vector?

In this case, I'd say you're right; rewriting this equation as a vector-valued (i.e. flattened) version seems to make sense. Instead of a (10,2)-shaped equation, one would solve a (20,)-shaped equation, and all derived quantities (e.g. data in your example) would be reshaped accordingly.
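A minimal sketch of the flattening described above, in plain JAX. The (10, 2) rotation field, the placeholder X and Y, and all names are hypothetical; the only point is that the solver sees a (20,)-shaped state, the vector field reshapes internally, and the terminal data is flattened with the same (row-major) ordering.

```python
import jax.numpy as jnp

NUM_POINTS, DIM = 10, 2  # e.g. 10 points in 2D


def matrix_vector_field(M, t):
    """Hypothetical matrix-valued field dM/dt = f(M, t) on a (10, 2) array."""
    # Rotate every point about the origin (illustrative choice only).
    rotation = jnp.array([[0.0, -1.0], [1.0, 0.0]])
    return M @ rotation.T


def flat_vector_field(y, t):
    """Vector-valued wrapper: (20,) -> (10, 2) -> f -> (20,)."""
    M = y.reshape(NUM_POINTS, DIM)
    return matrix_vector_field(M, t).reshape(-1)


# Placeholder start and end positions, one row per point.
X = jnp.arange(NUM_POINTS * DIM, dtype=jnp.float32).reshape(NUM_POINTS, DIM)
Y = X + 1.0

initial_values = (X.reshape(-1),)  # tuple holding one (20,)-shaped array
data = Y.reshape(-1)               # terminal data, flattened the same way
```

As long as the same reshape convention is used everywhere, the (20,)-shaped results can be mapped back to (10, 2) with .reshape(NUM_POINTS, DIM).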

Does that help?

adam-hartshorne commented 1 year ago

I am trying to learn a flow field, i.e. a vector field defined by a NODE, that models the advection of a set of points, given their known start and end locations in 2D.

pnkraemer commented 1 year ago

I see. I think that, for the moment, "flattening the equation" is the best way forward. I noted a potential extension of ProbDiffEq to matrix-valued equations in #457.

Since we're kind of drifting away from the original question (about simulate_terminal_values), I will close this issue for now. Please reopen if the original question has not been answered yet!

Let's move the discussion about matrix-valued equations to #457 :) And please feel invited to open more issues if you run into more problems!