patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Gradient calculation problems #307

Closed: lmriccardo closed this issue 10 months ago

lmriccardo commented 10 months ago

Hi, I'm using Diffrax to implement a tool for simulation and parameter estimation of systems biology models. Such models are, in general, DAEs (not just ODEs) and may include events. For this reason I have implemented the simulation as you suggested in issue #261 (including updating the SaveAt and the StepsizeController):

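```python
import jax
import equinox as eqx
from jax import Array
from diffrax import AbstractSolver, ODETerm, diffeqsolve

def _simulate(ode_term: ODETerm, solver: AbstractSolver, dt0: float, x0: Array, args0: Array, ts: Array):
    def step(carry, t):
        y, args = carry
        t0 = t
        t1 = t0 + dt0
        y, args = apply_events(t1, y, args)  # user-defined event handling
        # SaveAt, stepsize controller and other options omitted here.
        solution = diffeqsolve(ode_term, solver, t0, t1, dt0, y, args)
        # diffeqsolve returns a Solution object; with the default SaveAt(t1=True)
        # solution.ys holds one saved state. args is not changed by the solve.
        y = solution.ys[-1]
        args = apply_algebraic(t1, y, args)  # user-defined assignments to args
        return (y, args), y

    i_carry = (x0, args0)
    _, ys = jax.lax.scan(step, i_carry, ts)
    return ys

@eqx.filter_jit
def simulate(*args, **kwargs):  # actual signature elided
    # Do some computations and call _simulate
    ...
```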

Once I have the simulation results, I compute the objective function: a simple mean squared error between some target measurements and quantities computed from the simulation results. Since I want to minimize this loss, I need its gradient. The parameters with respect to which I need the gradient (let's call them theta) are substituted into some positions of args, so the gradient depends on the entire simulated trajectory.

Something like this:

```python
import equinox as eqx
from jax import Array

@eqx.filter_value_and_grad
def objective(theta: Array, x0: Array, args0: Array, model: eqx.Module, targets: Array):
    args0 = args0.at[...].set(theta[...])  # write theta into its slots of args (indices elided)
    sim_results = simulate(model, x0, args0, ...)
    return compute_mse(targets, f(sim_results))
```

The gradient can be computed since the output of diffrax is differentiable. However, computing the gradient is really slow; it can even take longer than the simulation itself. I understand that backpropagating through the entire simulated trajectory is not cheap. Even so, I'm trying to find a way to make the gradient computation faster.
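As a general JAX point (not something established in this thread): `jax.grad` uses reverse mode, which has to keep or recompute the whole forward trajectory for the backward pass, while forward mode (`jax.jacfwd`) costs roughly one extra pass per entry of theta and stores no trajectory, so it can be competitive when theta has only a handful of entries. The sketch below uses a hypothetical toy model and a hand-written explicit Euler loop in place of diffrax, and only checks that both modes give the same gradient. Whether forward mode works through `diffeqsolve`, and whether it is faster, depends on the diffrax version and adjoint, so check the diffrax documentation on adjoints.

```python
import jax
import jax.numpy as jnp


def simulate(theta, y0, n_steps=500, dt=0.01):
    # Toy model (hypothetical): dy/dt = -theta[0] * y + theta[1], explicit Euler in lax.scan.
    def step(y, _):
        y = y + dt * (-theta[0] * y + theta[1])
        return y, y

    _, ys = jax.lax.scan(step, y0, None, length=n_steps)
    return ys


def loss(theta, y0, targets):
    return jnp.mean((simulate(theta, y0) - targets) ** 2)


y0 = jnp.array(1.0, dtype=jnp.float32)
targets = simulate(jnp.array([0.7, 0.2]), y0)  # synthetic data
theta = jnp.array([0.5, 0.1])

# Reverse mode: one backward sweep, but the forward trajectory is needed for it.
g_rev = jax.grad(loss)(theta, y0, targets)
# Forward mode: one tangent pass per entry of theta, no stored trajectory.
g_fwd = jax.jacfwd(loss)(theta, y0, targets)
```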

Do you have any suggestions?

Sorry for the long description, but I thought it was necessary to better understand what's going on.

Thanks for this beautiful framework you are working on!

patrick-kidger commented 10 months ago

Hard to say from what you've said, I'm afraid. Sometimes ODE solves are just expensive. How are you solving for the algebraic constraints? Are you using diffrax.NewtonNonlinearSolver, or something else?
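For readers unfamiliar with what "solving for the algebraic constraints" involves: in a DAE, constraints of the form 0 = g(z, y) are typically solved for the algebraic variables z with a Newton iteration. The sketch below is a plain-JAX Newton loop with a fixed number of iterations and a hypothetical constraint g; it is not the diffrax nonlinear-solver API, whose exact names have changed between versions.

```python
import jax
import jax.numpy as jnp


def newton_solve(g, z0, y, n_iter=20):
    # Solve g(z, y) = 0 for z using a fixed number of Newton iterations.
    def body(z, _):
        J = jax.jacfwd(g)(z, y)  # Jacobian dg/dz
        z = z - jnp.linalg.solve(J, g(z, y))
        return z, None

    z, _ = jax.lax.scan(body, z0, None, length=n_iter)
    return z


def g(z, y):
    # Hypothetical constraints: z0**2 + y0 - 2 = 0 and z1 - z0 * y1 = 0.
    return jnp.array([z[0] ** 2 + y[0] - 2.0, z[1] - z[0] * y[1]])


y = jnp.array([1.0, 3.0])
z = newton_solve(g, jnp.array([1.5, 0.0]), y)
```

Starting from z0 = 1.5, the iteration should converge to the root z = (1, 3).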

lmriccardo commented 10 months ago

Sorry, they are not really algebraic constraints of the form $0 = f(x)$; they are just assignments to elements of the args vector, like args0 = y[1] * k1 + y[0] followed by args = args.at[0].set(args0). As I understand diffrax, the vector field must return only the right-hand side of the differential equations, which is why I created that algebraic function and the whole step-by-step interface: I also need the "trajectory" of args.
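The assignment rule described here can be written as a pure JAX function. In this sketch, storing k1 in `args[1]` is a hypothetical layout. Because the rule is an explicit function of the state and args, one possible way to obtain its values along a saved trajectory is to evaluate it with `jax.vmap` over the saved states; this reproduces the assigned values at the saved points, given those states.

```python
import jax
import jax.numpy as jnp


def apply_algebraic(t, y, args):
    # Assignment rule from the thread: args[0] = y[1] * k1 + y[0].
    # Assumption: k1 is stored in args[1] (hypothetical layout).
    k1 = args[1]
    return args.at[0].set(y[1] * k1 + y[0])


# Evaluate the rule at every saved time point (args shared, not batched).
ts = jnp.array([0.0, 0.5, 1.0])
ys = jnp.array([[1.0, 2.0], [0.5, 1.0], [0.25, 0.5]])
args = jnp.array([0.0, 3.0])
args_traj = jax.vmap(apply_algebraic, in_axes=(0, 0, None))(ts, ys, args)
```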

patrick-kidger commented 10 months ago

Gotcha. If you can reduce this to a MWE demonstrating an issue just with Diffrax then I'll be able to help, but without that it's hard to say whether there's actually an issue here.

lmriccardo commented 10 months ago

@patrick-kidger Hi, in the end I decided to close this issue. In my opinion, it is better to wait until an actual DAE solver implementation with (non-terminating) event handling is available. In the meantime, I have shifted my attention to derivative-free optimization and local search.
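The thread does not say which derivative-free method was used. As one illustrative example of derivative-free local search, here is a simple compass (coordinate pattern) search: it tries a step of plus or minus `step` along each coordinate, moves to the first improvement, and otherwise halves the step. The quadratic test function is hypothetical; in practice `f` would wrap the simulation and the MSE loss.

```python
import jax.numpy as jnp


def compass_search(f, x0, step=0.5, tol=1e-6, max_iter=1000):
    """Derivative-free local search: poll +/- step along each coordinate,
    accept the first improvement, otherwise halve the step."""
    x = jnp.asarray(x0, dtype=jnp.float32)
    fx = f(x)
    n = x.shape[0]
    for _ in range(max_iter):
        if step < tol:
            break
        improved = False
        for i in range(n):
            for sign in (1.0, -1.0):
                cand = x.at[i].add(sign * step)
                fc = f(cand)
                if fc < fx:
                    x, fx, improved = cand, fc, True
                    break
            if improved:
                break
        if not improved:
            step *= 0.5  # no poll point improved: refine the mesh
    return x, fx


def quadratic(x):
    # Hypothetical objective with its minimum at (1, -2).
    return (x[0] - 1.0) ** 2 + (x[1] + 2.0) ** 2


x_best, f_best = compass_search(quadratic, [0.0, 0.0])
```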

You are free to decide whether to remove the issue or not; it is up to you. I have another question, but I will open a separate issue for it.

Thanks.