SciML / DiffEqDocs.jl

Documentation for the DiffEq differential equations and scientific machine learning (SciML) ecosystem
https://docs.sciml.ai/DiffEqDocs/stable/

Rewinding integrator #704

Open lgravina1997 opened 8 months ago

lgravina1997 commented 8 months ago

I am looking for a way to rewind the integrator to the state it had at an earlier time. Specifically, I want to use a continuous callback to monitor some condition. When this condition is satisfied, I want to go back Δt in time, modify the solution at that point, and restart the integration from there.

Is there a way to do so?

ChrisRackauckas commented 8 months ago

You can use reinit!
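A minimal sketch of that suggestion using the integrator interface of OrdinaryDiffEq.jl. The right-hand side f(u, p, t) = -u, the checkpoint times, and doubling the state as the "modification" are all assumptions for illustration: step to a checkpoint, record (t, u), keep integrating, then call reinit! to restart from the checkpoint with a modified state.

```julia
using OrdinaryDiffEq

f(u, p, t) = -u                      # hypothetical right-hand side, for illustration only
prob = ODEProblem(f, 1.0, (0.0, 1.0))
integrator = init(prob, Tsit5())

step!(integrator, 0.5, true)         # integrate exactly up to t = 0.5
t_check = integrator.t               # checkpoint time
u_check = integrator.u               # checkpoint state (use copy(integrator.u) for array states)

step!(integrator, 0.25, true)        # keep going to t = 0.75

# Rewind: restart from the checkpoint with a modified state.
# erase_sol = false keeps every point saved so far, including those after t_check.
reinit!(integrator, 2 * u_check; t0 = t_check, erase_sol = false)
solve!(integrator)                   # integrate from t_check to the end of tspan
```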

lgravina1997 commented 7 months ago

Indeed, reinit! serves the purpose, but it has one problem that might be worth addressing.

Consider the following simple problem:

using OrdinaryDiffEq

# f(u, p, t) is the right-hand side of the ODE (not shown)
u0 = 1.0
tl = LinRange(0, 1, 1001)
tspan = (tl[1], tl[end])
prob = ODEProblem(f, u0, tspan)
sol = solve(prob, Tsit5(), saveat=tl)

Now assume at t=t0=0.5 the solution is changed. We can do this with a callback:

t0=0.5
condition(u, t, integrator) = t==t0
affect!(integrator) = integrator.u = 1

cb = DiscreteCallback(condition, affect!, save_positions=(false,false))
prob = ODEProblem(f, u0, tspan)
sol = solve(prob, Tsit5(), callback=cb, tstops=[t0,], saveat=tl)

Now suppose that at t = t1 = 0.75 we discover the solution is wrong and want to go back to t = t0 and redo the evolution correctly. We can do:

t1 = 0.75
u_t0 = sol.u[argmin(abs.(tl .- t0))]  # saved solution at the grid point closest to t0

condition_1(u, t, integrator) = (t == t1) && (integrator.u > 1.2 * u_t0)

function affect_1!(integrator)
    # rewind: restart from (t0, u_t0), keeping the points saved so far
    integrator.u = u_t0
    integrator.t = t0
    reinit!(integrator, integrator.u; t0=t0, erase_sol=false)
end

cb = DiscreteCallback(condition, affect!, save_positions=(false,false))
cb1 = DiscreteCallback(condition_1, affect_1!, save_positions=(false,false))
sol1 = solve(prob, Tsit5(), callback=CallbackSet(cb, cb1), tstops=[t0,t1], saveat=tl)

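The problem is the following: with erase_sol=false, the points saved on the discarded trajectory between t0 and t1 stay in sol1 alongside the new ones, while erase_sol=true erases the whole solution, including the valid part before t0.

It would be worth having erase_sol=true erase only the points saved after the t0 passed as a keyword argument.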

Below is an example of how the problem shows up with erase_sol=false:

[Screenshot from 2023-11-07, 15:40:03]

lgravina1997 commented 7 months ago

I tried to solve this with

function affect_1!(integrator)
    integrator.u = sol.u[argmin(abs.(tl .- t0))]  # saved solution closest to t0
    integrator.t = t0
    reinit!(integrator, integrator.u; t0=t0, erase_sol=false)

    idx = findlast(integrator.sol.t .<= t0)
    resize!(integrator.sol.t, idx)
    resize!(integrator.sol.u, idx)
end

but this throws the following error at the end of the integration (not at the moment of resizing):

BoundsError: attempt to access 1001-element Vector{Float64} at index [1251]

Stacktrace:
  [1] getindex
    @ ./essentials.jl:13 [inlined]
  [2] solution_endpoint_match_cur_integrator!(integrator::OrdinaryDiffEq.OD ...
  [3] _postamble!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), ...

ChrisRackauckas commented 7 months ago

You'll also need to reset saveiter. This relies on internals, so it won't be the most robust approach: indexing saves via saveiter hasn't changed in 6 years, so in theory it's fine, but in practice it goes through an internal, non-public API.
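The truncation from the earlier attempt, with the saveiter reset folded in, can be sketched as a small helper. truncate_saved! is a hypothetical name, not part of the DiffEq API; it writes to the internal saveiter field mentioned above and assumes only sol.t and sol.u need trimming (as in the thread's saveat setup), so dense output would need more care.

```julia
# Hypothetical helper, not part of the DiffEq API: drop every point saved after
# `tcut` and move the save counter back so the next save lands right after them.
# It writes to `integrator.saveiter`, an internal, non-public field.
function truncate_saved!(integrator, tcut)
    sol = integrator.sol
    # keep the prefix that ends just before the first point saved after tcut
    first_after = findfirst(>(tcut), sol.t)
    keep = first_after === nothing ? length(sol.t) : first_after - 1
    resize!(sol.t, keep)
    resize!(sol.u, keep)
    integrator.saveiter = keep   # saveiter counts the points currently saved
    return integrator
end
```

Inside affect_1!, it would be called as truncate_saved!(integrator, t0) right after reinit!(integrator, integrator.u; t0=t0, erase_sol=false). Cutting at the first point saved after t0 keeps only the valid prefix even if something had already been appended after the discarded trajectory.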

lgravina1997 commented 7 months ago

This is perfect, and it works. The only remaining problem is that a PresetTimeCallback interferes with the reset of the integrator. Specifically, taking the example from before:

using DiffEqCallbacks  # provides PresetTimeCallback

t1 = 0.75
u_t0 = sol.u[argmin(abs.(tl .- t0))]  # saved solution at the grid point closest to t0

condition_1(u, t, integrator) = (t == t1) && (integrator.u > 1.2 * u_t0)

function affect_1!(integrator)
    integrator.u = sol.u[argmin(abs.(tl .- t0))]
    reinit!(integrator, integrator.u; t0=t0, erase_sol=false)

    # drop the points saved after t0 and reset the (internal) save counter
    idx = findlast(integrator.sol.t .<= t0)
    resize!(integrator.sol.t, idx)
    resize!(integrator.sol.u, idx)
    integrator.saveiter = idx
end

cb = DiscreteCallback(condition, affect!, save_positions=(false,false))
cb1 = DiscreteCallback(condition_1, affect_1!, save_positions=(false,false))
cb2 = PresetTimeCallback(tl, x->x, save_positions=(false,false))
sol1 = solve(prob, Tsit5(), callback=CallbackSet(cb, cb1, cb2), tstops=[t0,t1], saveat=tl)

This code works perfectly without the new PresetTimeCallback cb2. When I include it, I get the error "Tried to add a tstop that is behind the current time. This is strictly forbidden." This likely comes from the initialisation of cb2, which probably runs again after the integrator is reinitialised.

Interestingly, dropping cb2 and instead passing tstops=vcat(tl, [t0, t1]), i.e.

sol1 = solve(prob, Tsit5(), callback=CallbackSet(cb, cb1), tstops=vcat(tl, [t0, t1]), saveat=tl);

does not cause this problem. Is this something that should be fixed in PresetTimeCallback, or am I mistaken?
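One possible workaround, sketched under assumptions rather than checked against this exact case: a stand-in for PresetTimeCallback whose initialize function only registers preset times that are still ahead of the integrator's current time. It assumes, as the error message suggests, that reinit! re-runs each callback's initialize at the new t0. future_time_callback is a hypothetical name, not a DiffEqCallbacks function, and the sketch only handles forward-in-time integration.

```julia
using OrdinaryDiffEq

# Hypothetical stand-in for PresetTimeCallback (not a library function): it stops
# at each time in `times` and calls `affect!` there, but its initialize function
# only registers times strictly ahead of the current t. Assumes forward integration.
function future_time_callback(times, affect!; save_positions = (false, false))
    condition(u, t, integrator) = t in times
    function init_future!(c, u, t, integrator)
        for ts in times
            ts > t && add_tstop!(integrator, ts)   # never add a tstop behind t
        end
        u_modified!(integrator, false)             # initialization does not change u
    end
    return DiscreteCallback(condition, affect!;
                            initialize = init_future!,
                            save_positions = save_positions)
end
```

In the snippet above, cb2 would become future_time_callback(tl, integrator -> nothing).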

ChrisRackauckas commented 7 months ago

I'd need an MWE for this.