Closed lmriccardo closed 10 months ago
Hard to say from what you've said, I'm afraid. It can sometimes just be that ODE solves are expensive.
How are you solving for the algebraic constraints? Are you using `diffrax.NewtonNonlinearSolver`, or something else?
Sorry, they are not really algebraic constraints of the form $0 = f(x)$; they are just assignments to elements of the args vector, like `args0 = y[1] * k1 + y[0]` and then `args = args.at[0].set(args0)`. Now, in my understanding of Diffrax, the vector field must return only the result of evaluating the RHS of the differential equations, and for this reason I've created that algebraic function and the entire step-by-step interface: I also need the "trajectory" of args.
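A minimal sketch of one alternative, assuming the assignments depend only on the current state `y` (and not on events or on earlier steps): the assignment can be evaluated inside the vector field, and the "trajectory" of args can be recomputed afterwards from the saved states with `jax.vmap`. The names `algebraic` and `vector_field`, the two-state RHS, and all numeric values are hypothetical; only the rule `args0 = y[1] * k1 + y[0]` comes from the post.

```python
import jax
import jax.numpy as jnp


def algebraic(y, args, k1):
    # Assignment rule quoted in the post: args[0] = y[1] * k1 + y[0].
    args0 = y[1] * k1 + y[0]
    return args.at[0].set(args0)


def vector_field(t, y, params):
    # (t, y, args) is the call signature Diffrax expects for an ODE term.
    args, k1 = params
    args = algebraic(y, args, k1)  # refresh derived entries from the current state
    # Hypothetical right-hand side, for illustration only.
    dy0 = -args[0] * y[0]
    dy1 = args[0] * y[0] - args[1] * y[1]
    return jnp.stack([dy0, dy1])


# Stand-in for the saved states of a solve (one row per saved time point).
ys = jnp.array([[1.0, 0.0], [0.5, 0.4], [0.2, 0.6]])
args = jnp.array([0.0, 0.3])
k1 = 2.0

# Recover the args trajectory after the fact, one row per saved state.
args_traj = jax.vmap(lambda y: algebraic(y, args, k1))(ys)
```

If the assignments also depend on event state or on history, this post-hoc reconstruction would not apply as written.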
Gotcha. If you can reduce this to a MWE demonstrating an issue just with Diffrax then I'll be able to help, but without that it's hard to say whether there's actually an issue here.
@patrick-kidger Hi, in the end I decided to close this issue. In my opinion, it is better to wait until an actual DAE solver implementation with (non-terminating) event handling is available. Hence, I have decided to shift my attention to derivative-free optimization and local search.
Feel free to remove the issue or leave it; it is up to you. I have another question, but I will create another issue for it.
Thanks.
Hi, I'm using Diffrax to implement a tool for simulation and parameter estimation of Systems Biology models. These kinds of models are known to be, in general, DAEs (not just ODEs), possibly with events. For this reason I have implemented the simulation in the way you suggested in Issue #261 (so that I can also update the SaveAt and the StepsizeController):
At this point, once I have the simulation results, I have to compute the objective function, a simple Mean Squared Error between some target measurements and some quantities computed from the simulation results. Since I would like to minimize this loss, I need to compute its gradient. The parameters with respect to which I need the gradient, let's call them theta, are substituted into some positions of args, meaning that the gradient computation also depends on the entire simulated trajectory.
Suppose it looks something like this:
The gradient can be computed, since the result from Diffrax is differentiable. However, computing the gradient is really slow, and it can even take more time than the simulation itself. I understand that it is not simple to backpropagate through the entire simulation trajectory. Despite this, I'm trying to find an approach to make the gradient computation faster.
Do you have any suggestions?
Sorry for the long description, but I thought it was necessary to better understand what's going on.
Thanks for this beautiful framework you are working on!!!!