SciML / OrdinaryDiffEq.jl

High performance ordinary differential equation (ODE) and differential-algebraic equation (DAE) solvers, including neural ordinary differential equations (neural ODEs) and scientific machine learning (SciML)
https://diffeq.sciml.ai/latest/

solution after first DiscreteCallback not saved immediately #939

Open goretkin opened 5 years ago

goretkin commented 5 years ago

I expected the first two saved points in the solution to be the states immediately before and immediately after the first callback event, but the state immediately after that event does not seem to be saved.

They are saved for subsequent events, however. Is this the intended behavior?

The callback is definitely called at t_span[1], from what I can tell.

import OrdinaryDiffEq: DEDataArray, ODEProblem, solve, full_cache
import OrdinaryDiffEq
import DiffEqCallbacks: PeriodicCallback

mutable struct SimulationState{T, X, U, S} <: DEDataArray{T, 1}
    x::X        # ODE State Space
    # the following are just to store quantities during integration. The ODE function should not read from these fields.
    control::U  # ODE input, controller output storage.
    time::S     # for debugging
    control_cycle::Int64
end

# SimulationState is supposed to have an eltype because it obeys the AbstractArray interface
SimulationState(x, control, time, control_cycle) = SimulationState{eltype(x), typeof(x), typeof(control), typeof(time)}(x, control, time, control_cycle)

struct SimulationClosure{TP, TC}
    parameters::TP
    control::TC
end

function make_hybrid(continuous_plant!, discrete_controller, u_0, t_span, Δt, parameters)
    # controller happens in discrete time. physics happens in continuous time
    closure = SimulationClosure(
        parameters,
        Ref{Any}()
    )

    function compute_control!(sim_state, closure::SimulationClosure, time)
        action = discrete_controller(sim_state, closure.parameters, time)
        closure.control[] = action
        # TODO: this line makes this function non-idempotent; a true/false flag would not be.
        # Could set the flag to true here and to false at the continuous steps.
        sim_state.control_cycle += 1
    end

    function cb!(integrator)
        time = integrator.t
        simulation_state = integrator.u
        compute_control!(simulation_state, closure, time)
        simulation_state.time = time
    end

    # PeriodicCallback does not modify the state of the continuous system,
    # so there is no need to save more than once at the same time instant;
    # save_positions=(false, false) seems to accomplish this. Here (true, true) is used
    # so that the states immediately before and after each callback event are saved.
    pcb = PeriodicCallback(cb!, Δt; save_positions=(true, true))

    ode_prob = ODEProblem(
        continuous_plant!,
        u_0,
        t_span,
        closure;
        callback=OrdinaryDiffEq.CallbackSet(pcb)
    )
    return ode_prob
end

function plant_dynamics!(dstate, sim_state, closure, time)
    thrust = closure.control[]
    (s, v) = (sim_state[1], sim_state[2])

    acceleration = thrust / closure.parameters.mass + closure.parameters.damping * v

    dstate[1] = v             # ds/dt, written in place (no temporary array)
    dstate[2] = acceleration  # dv/dt
    sim_state.control = thrust
end

function controller(sim_state, parameters, times)
    (s, v) = (sim_state[1], sim_state[2])
    return parameters.k * s
end

t_span = (0.0, 10.0)
Δt = 1.0
parameters = (mass=1, k=-0.1, damping=-0.1)
u_0 = SimulationState([10.0, 0.0], NaN, NaN, 0) #NaNs just to make sure initial condition is not being touched

ode_problem = make_hybrid(plant_dynamics!, controller, u_0, t_span, Δt, parameters)
solution = solve(ode_problem, OrdinaryDiffEq.RK4())

# should be the initial condition, immediately before the first PeriodicCallback
@show solution.u[1].time === NaN
@show solution.u[1] == u_0
@show solution.u[1].control_cycle == 0
@show solution.t[1] == 0.0

# should be immediately after the first PeriodicCallback
@show solution.u[2].time == 0.0
@show solution.u[2].control_cycle == 1
println("expect 0.0:")
@show solution.t[2]

t_0 = t_span[1]
t_1 = t_0 + Δt
i = searchsortedfirst(solution.t, t_1)

@show solution.t[i] == t_1
@show solution.t[i+1] == t_1

@show solution.u[i].time == t_0
@show solution.u[i+1].time == t_1

@show solution.u[i].control_cycle == 1
@show solution.u[i+1].control_cycle == 2

@show solution.u[i] == solution.u[i+1]

output:

(solution.u[1]).time === NaN = true
solution.u[1] == u_0 = true
(solution.u[1]).control_cycle == 0 = true
solution.t[1] == 0.0 = true
(solution.u[2]).time == 0.0 = true
(solution.u[2]).control_cycle == 1 = true
expect 0.0:
solution.t[2] = 0.0009999000099990003
solution.t[i] == t_1 = true
solution.t[i + 1] == t_1 = true
(solution.u[i]).time == t_0 = true
(solution.u[i + 1]).time == t_1 = true
(solution.u[i]).control_cycle == 1 = true
(solution.u[i + 1]).control_cycle == 2 = true
solution.u[i] == solution.u[i + 1] = true
goretkin commented 5 years ago

I see. Is this because the first callback is not called via the tstop mechanism, but in a special-purpose way?
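This hypothesis can be illustrated with a toy model. The sketch below is hypothetical and is not how OrdinaryDiffEq is implemented: `toy_solve` is a fixed-step explicit Euler loop with a discrete callback fired at `t0` and then every `k` steps. Callbacks reached as tstops save the state immediately before and after `affect!`, mimicking `save_positions = (true, true)`. With `initial_via_tstop = false`, the callback at `t0` is instead applied during initialization with no post-callback save, which reproduces the layout in the output above, where the second saved time is the first step rather than `t0`.

```julia
"""
    toy_solve(f, u0, t0, dt, nsteps, k, affect!; initial_via_tstop = true)

Hypothetical toy model (NOT the OrdinaryDiffEq implementation): explicit Euler with
fixed step `dt`; a discrete callback fires at `t0` and then every `k` steps.
"""
function toy_solve(f, u0, t0, dt, nsteps, k, affect!; initial_via_tstop = true)
    u = copy(u0)
    ts = Float64[]
    us = Vector{typeof(u)}()
    save!(t) = (push!(ts, t); push!(us, copy(u)))

    save!(t0)                              # initial condition
    initial_via_tstop || affect!(u, t0)    # "initialization" path: apply, but no save after
    for n in 0:nsteps
        t = t0 + n * dt
        n > 0 && save!(t)                  # regular save (the "before" save at events)
        if n % k == 0 && (n > 0 || initial_via_tstop)
            affect!(u, t)
            save!(t)                       # save immediately after the callback
        end
        n < nsteps && (u .+= dt .* f(u, t))  # Euler step to the next grid point
    end
    return ts, us
end

# Usage: constant state, callback adds 1 to u[1]
f(u, t) = zero(u)
bump!(u, t) = (u[1] += 1.0)
ts, us = toy_solve(f, [0.0], 0.0, 0.5, 4, 2, bump!)
# ts == [0.0, 0.0, 0.5, 1.0, 1.0, 1.5, 2.0, 2.0]
ts2, us2 = toy_solve(f, [0.0], 0.0, 0.5, 4, 2, bump!; initial_via_tstop = false)
# ts2 == [0.0, 0.5, 1.0, 1.0, 1.5, 2.0, 2.0]  (no duplicate entry at t0)
```

In both runs the callback at `t0` is applied; only the save after it differs.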

goretkin commented 5 years ago

I was looking for a way to have only one place where the callback is called. That would involve a tstop at t_span[1] (e.g. 0), but handling that tstop skips the call to loopfooter!, which is where the callbacks are handled.

The following change partially fixes the issue for me, except that the initial condition now appears twice in the solution.

diff --git a/DiffEqCallbacks/src/iterative_and_periodic.jl b/DiffEqCallbacks/src/iterative_and_periodic.jl
index 4a3ada5..b547276 100644
--- a/src/iterative_and_periodic.jl
+++ b/src/iterative_and_periodic.jl
@@ -84,7 +84,7 @@ function PeriodicCallback(f, Δt::Number; initialize = DiffEqBase.INITIALIZE_DEF
         initialize(c, u, t, integrator)
         if initial_affect
             tnext[] = t
-            affect!(integrator)
+            add_tstop!(integrator, tnext[])
         else
             tnext[] = t + Δt
             add_tstop!(integrator, tnext[])
diff --git a/src/solve.jl b/src/solve.jl
index 84565cba..5f54695b 100644
--- a/src/solve.jl
+++ b/src/solve.jl
@@ -364,6 +364,7 @@ end

 function DiffEqBase.solve!(integrator::ODEIntegrator)
   @inbounds while !isempty(integrator.opts.tstops)
+    steps_performed_until_tstop = 0
     while integrator.tdir * integrator.t < top(integrator.opts.tstops)
       loopheader!(integrator)
       if check_error!(integrator) != :Success
@@ -374,8 +375,12 @@ function DiffEqBase.solve!(integrator::ODEIntegrator)
       if isempty(integrator.opts.tstops)
         break
       end
+      steps_performed_until_tstop += 1
     end
     handle_tstop!(integrator)
+    if steps_performed_until_tstop == 0
+      handle_callbacks!(integrator) # TODO really only need to handle discrete
+    end
   end
   postamble!(integrator)
ChrisRackauckas commented 5 years ago

The solve! change shouldn't be needed. Initialization isn't the same as calling the callback: for many callbacks, initialization is very different from the affect!. PeriodicCallback should just call the affect at the start.

What's the issue here?

goretkin commented 5 years ago

Right, so PeriodicCallback does appear to call affect! at tspan[1], and I was expecting the solution to have the values immediately before (the initial condition) and immediately after the callback.

But that doesn't appear to be the case. To highlight what was buried in my initial message: solution.t[2] = 0.0009999000099990003, whereas I expected solution.t[2] = 0.0.
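A quick check for this symptom is to compare the first two saved time points. The helper below is hypothetical (it is not part of DiffEqCallbacks or OrdinaryDiffEq) and could be applied to `solution.t`; here it is run on the values from the output in the first comment.

```julia
# Hypothetical helper: true when the state immediately after a callback at the
# initial time was saved, i.e. the first two saved time points are equal.
post_initial_callback_saved(ts) = length(ts) >= 2 && ts[1] == ts[2]

reported = [0.0, 0.0009999000099990003]  # first two entries of solution.t above
expected = [0.0, 0.0]                    # what save_positions = (true, true) suggests

post_initial_callback_saved(reported)    # false
post_initial_callback_saved(expected)    # true
```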

ChrisRackauckas commented 5 years ago

Interesting. Thanks, I'll have to look into it. I have an awfully crazy month coming up, so feel free to ping me sometime early November to remind me about this.

miromarszal commented 4 years ago

I stumbled upon this or a very similar issue, or perhaps it's just my poor understanding of callbacks. I'm using a PresetTimeCallback to modify the solution at several times, including t = 0. At that particular time point the callback is definitely called, but it has no effect on the integrator unless I include u_modified!(integrator, true) in the callback function. At all other times the callback works as expected without this line. Is this intended behaviour? I could apply my modification to u0 before calling the solver, but I thought it would be more consistent this way.
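A minimal sketch of the workaround described above, assuming current OrdinaryDiffEq and DiffEqCallbacks releases: a `PresetTimeCallback` adds 1 to the state at t = 0 and t = 0.5 and calls `u_modified!(integrator, true)` after changing `integrator.u`. The ODE (du/dt = -u, u0 = 1) and the names `decay!` and `bump!` are illustrative, not taken from the original report. If both modifications take effect, the exact value at t = 1 is 2*exp(-1) + exp(-0.5) ≈ 1.342.

```julia
using OrdinaryDiffEq, DiffEqCallbacks

# Exponential decay du/dt = -u, written in place
decay!(du, u, p, t) = (du .= .-u)

# Add 1 to the state at each preset time. u_modified!(integrator, true) tells the
# integrator that the callback changed u; per this thread it was reported as
# necessary only for the callback at t = 0.
function bump!(integrator)
    integrator.u[1] += 1.0
    u_modified!(integrator, true)
end

prob = ODEProblem(decay!, [1.0], (0.0, 1.0))
cb = PresetTimeCallback([0.0, 0.5], bump!)
sol = solve(prob, Tsit5(); callback = cb)
```

Without the affect at t = 0 taking effect, the final value would instead be near exp(-1) + exp(-0.5) ≈ 0.974, so `sol.u[end]` distinguishes the two cases.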

ChrisRackauckas commented 4 years ago

> unless I include u_modified!(integrator, true) in the function. At all the other times, the callback works as expected without this line. Is this intended behaviour?

That's not intended, and it points exactly to what the bug probably is. Thanks!