sammccallum / reversible

JAX implementation of the Reversible Solver method
https://arxiv.org/abs/2410.11648

Solution including trajectory at intermediate times #1

Open adam-hartshorne opened 4 days ago

adam-hartshorne commented 4 days ago

As far as I can tell, the solver currently returns only the solution at t=T. Is it possible to return a trajectory of values at intermediate times, e.g. as with https://docs.kidger.site/diffrax/api/saveat/? And would it cause a significant slowdown (as it does in Diffrax)?

sammccallum commented 4 days ago

Hi Adam, thanks for the question.

Yep, you are correct. Currently the solution is returned only at $t=T$, as the most common workflow for Neural ODEs is to define a loss $L(y(T))$ that depends on the terminal state.

Alternatively, if your loss function $L(y(t))$ depends on the full trajectory, then it is possible to calculate the loss as an integral alongside the ODE solve (without saving the full trajectory). This was done for the CNF example in the paper and is more memory-efficient than storing the trajectory.
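To make the "loss as an integral" idea concrete, here is a minimal sketch in plain JAX. It is not the code from this repo or the paper: the dynamics `vector_field`, the integrand `running_cost`, and the hand-written fixed-step RK4 loop are all hypothetical stand-ins. The point is only that the state is augmented with a scalar that accumulates $\int_0^T g(t, y(t))\,dt$, so no trajectory needs to be stored.

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, theta):
    # Hypothetical dynamics dy/dt = -theta * y (stand-in for a neural network).
    return -theta * y


def running_cost(t, y):
    # Hypothetical integrand g(t, y) of a trajectory-dependent loss.
    return jnp.sum(y**2)


def augmented_field(t, state, theta):
    # The state is (y, accumulated_loss); the loss obeys d(loss)/dt = g(t, y).
    y, _ = state
    return vector_field(t, y, theta), running_cost(t, y)


def axpy(a, x, y):
    # Leaf-wise y + a * x for pytrees of matching structure.
    return jax.tree_util.tree_map(lambda xi, yi: yi + a * xi, x, y)


def rk4_step(t, state, dt, theta):
    # Classical fixed-step RK4 on the augmented state (not the reversible solver).
    k1 = augmented_field(t, state, theta)
    k2 = augmented_field(t + dt / 2, axpy(dt / 2, k1, state), theta)
    k3 = augmented_field(t + dt / 2, axpy(dt / 2, k2, state), theta)
    k4 = augmented_field(t + dt, axpy(dt, k3, state), theta)
    return jax.tree_util.tree_map(
        lambda s, a, b, c, d: s + dt / 6 * (a + 2 * b + 2 * c + d),
        state, k1, k2, k3, k4,
    )


def integral_loss(theta, y0, T=1.0, num_steps=100):
    dt = T / num_steps

    def body(state, n):
        return rk4_step(n * dt, state, dt, theta), None

    init = (y0, jnp.zeros(()))  # the accumulated loss starts at zero
    (_, loss), _ = jax.lax.scan(body, init, jnp.arange(num_steps))
    return loss


y0 = jnp.array([1.0])
loss = integral_loss(1.0, y0)
grad_theta = jax.grad(integral_loss)(1.0, y0)
```

With these toy dynamics the exact value is $(1 - e^{-2\theta})/(2\theta)$, which gives a quick sanity check on the accumulated loss. The same augmentation should work with any solver that accepts a pytree state.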

Anyway, it would be possible to add the option to return the full trajectory. This would require tweaking the backpropagation algorithm: we would need to update the intermediate solution tangents $\bar{y}_n$ at each time-step $n$ rather than calculating from zero.
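As a rough illustration of that change to the backward pass, the sketch below uses a plain explicit Euler step and stores the forward states, so it is not the reversible algorithm (which reconstructs states rather than storing them). The names `step`, `per_step_loss` and `manual_backward` are hypothetical. It only shows the bookkeeping: at each step $n$ the gradient of the loss with respect to $y_n$ is added to $\bar{y}_n$ before pulling back through the step.

```python
import jax
import jax.numpy as jnp


def step(y, theta, dt):
    # One explicit Euler step of the hypothetical dynamics dy/dt = -theta * y.
    return y + dt * (-theta * y)


def per_step_loss(y, target):
    # Hypothetical loss contribution from a single saved state.
    return jnp.sum((y - target) ** 2)


def forward(theta, y0, dt, num_steps):
    # Stores every state; a reversible solver would avoid this.
    ys = [y0]
    for _ in range(num_steps):
        ys.append(step(ys[-1], theta, dt))
    return ys


def trajectory_loss(theta, y0, targets, dt):
    ys = forward(theta, y0, dt, len(targets))
    return sum(per_step_loss(y, tgt) for y, tgt in zip(ys[1:], targets))


def manual_backward(theta, y0, targets, dt):
    ys = forward(theta, y0, dt, len(targets))
    y_bar = jnp.zeros_like(y0)
    theta_bar = jnp.zeros_like(theta)
    for n in range(len(targets), 0, -1):
        # Inject the loss gradient for the saved state y_n.
        y_bar = y_bar + jax.grad(per_step_loss)(ys[n], targets[n - 1])
        # Pull the accumulated cotangent back through the step y_{n-1} -> y_n.
        _, vjp_fn = jax.vjp(lambda y, th: step(y, th, dt), ys[n - 1], theta)
        y_bar, theta_step_bar = vjp_fn(y_bar)
        theta_bar = theta_bar + theta_step_bar
    return theta_bar


theta = jnp.asarray(0.7)
y0 = jnp.array([1.0])
targets = jnp.linspace(0.9, 0.5, 5).reshape(5, 1)
dt = 0.1

manual = manual_backward(theta, y0, targets, dt)
reference = jax.grad(trajectory_loss)(theta, y0, targets, dt)
```

Here `manual` and `reference` should agree up to floating-point error. For a terminal-only loss, the injection line would fire just once, at $n = N$.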

As for Diffrax, I believe the solution is interpolated over the (adaptive) time-steps taken and then evaluated at the user-specified times, which may explain the slowdown. If the step sizes are fixed, then saving at intermediate times would just correspond to filling an array of pre-specified shape, which wouldn't be expensive.
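For the fixed-step case, a minimal sketch of "filling an array of pre-specified shape" could look like the following. It uses explicit Euler purely for brevity, and `solve_and_save`, `num_saves` and `steps_per_save` are hypothetical names, not part of this repo or of Diffrax. The outer `jax.lax.scan` stacks one saved state per save point, so the output shape is known before the solve starts.

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, theta):
    # Hypothetical dynamics dy/dt = -theta * y.
    return -theta * y


def solve_and_save(y0, theta, T=1.0, num_saves=10, steps_per_save=10):
    dt = T / (num_saves * steps_per_save)

    def euler_step(_, carry):
        t, y = carry
        return t + dt, y + dt * vector_field(t, y, theta)

    def advance_to_next_save(carry, _):
        # Take a fixed number of steps, then emit the state at the save point.
        carry = jax.lax.fori_loop(0, steps_per_save, euler_step, carry)
        return carry, carry

    init = (jnp.zeros(()), y0)
    _, (ts, ys) = jax.lax.scan(advance_to_next_save, init, None, length=num_saves)
    return ts, ys  # shapes (num_saves,) and (num_saves, *y0.shape)


ts, ys = solve_and_save(jnp.array([1.0]), 1.0)
```

Because the number of saves and the step size are fixed ahead of time, no interpolation is needed and the save cost is just one write per save point.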

adam-hartshorne commented 2 days ago

My use case might be quite niche, but I need to know the locations along the trajectory at predefined intervals, which are then used in a number of different loss functions.

I currently use a probabilistic ODE solver (https://github.com/pnkraemer/probdiffeq), both because I can output the trajectory efficiently and because in some use cases I wish to model the problem so as to incorporate uncertainty from both the solve and the data. However, in other use cases, I don't need any UQ and am just looking for the most efficient solver.