SciML / DiffEqCallbacks.jl

A library of useful callbacks for hybrid scientific machine learning (SciML) with augmented differential equation solvers
https://docs.sciml.ai/DiffEqCallbacks/stable/

Offset in PeriodicCallback #192

Closed: vbertret closed this issue 2 months ago

vbertret commented 6 months ago

Is your feature request related to a problem? Please describe.

Hello, I am currently using PeriodicCallback to update a value within my system so that it follows a periodic series of 1 and 0 values, where the two phases have different durations. For instance, the 1 phase lasts 60 minutes and the 0 phase lasts 20 minutes, so I want the value to be 1 for the first 60 minutes, 0 for the next 20 minutes, 1 for the following 60 minutes, 0 for the next 20 minutes, and so on. With the current implementation this is not straightforward, because every PeriodicCallback is anchored at the beginning of the time span: it can only fire at tspan[1] + k * Δt, so the callback that switches the value to 0 cannot simply be made to fire first at tspan[1] + 60 minutes.
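For reference, the target signal can be written as a plain function of time. This is a minimal sketch: `onoff` and its keyword names are hypothetical (not part of DiffEqCallbacks), and it assumes time is measured in minutes from the start of the simulation.

```julia
# Hypothetical helper (not part of DiffEqCallbacks): value of the 1/0 signal at time t,
# assuming time in minutes, 60 minutes "on" followed by 20 minutes "off".
function onoff(t; t_start = 0.0, on = 60.0, off = 20.0)
    period = on + off                  # full cycle: 80 minutes
    phase = mod(t - t_start, period)   # position inside the current cycle
    return phase < on ? 1 : 0
end

# One full cycle sampled every 10 minutes: six 1s followed by two 0s
samples = [onoff(t) for t in 0.0:10.0:70.0]
```

Seen this way, the signal has a single period of 80 minutes; the two switches differ only in where in the cycle they fire (at 0 and at 60 minutes), which is what an offset would express.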

Describe the solution you’d like

I would like to propose adding an offset to PeriodicCallback so that its schedule starts at tspan[1] + offset instead of tspan[1]; with initial_affect = true, the first call would then occur at tspan[1] + offset. For now, I have implemented my own PeriodicOffsetCallback in my library (essentially a modified copy of PeriodicCallback with an added offset parameter), but it would be more convenient if this functionality were built into the existing PeriodicCallback.
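Stated as a formula (a sketch of the intended semantics, assuming initial_affect keeps its current meaning of whether the call at the anchor itself is made), the callback would fire at:

```latex
t_k = t_{\mathrm{start}} + \mathrm{offset} + k\,\Delta t,
\qquad k = k_0,\, k_0 + 1,\, \dots,
\qquad k_0 = \begin{cases} 0 & \text{if initial\_affect} \\ 1 & \text{otherwise} \end{cases}
```

where t_start = tspan[1]. Setting offset = 0 recovers the current PeriodicCallback schedule. In the 60/20 example, both switches would use Δt = 80 minutes, with offset = 0 for the switch to 1 and offset = 60 for the switch to 0.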

Describe alternatives you’ve considered

Below is the code I have implemented to make this work:

using Parameters  # provides @unpack
using SciMLBase: DiscreteCallback, add_tstop!, u_modified!

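# Affect wrapper: calls the user's `affect!`, then schedules the next stop at
# t0 + (index + 1) * Δt, where t0 and index are shared with the condition
struct PeriodicCallbackAffect{A, dT, Ref1, Ref2}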
    affect!::A
    Δt::dT
    t0::Ref1
    index::Ref2
end

function (S::PeriodicCallbackAffect)(integrator)
    @unpack affect!, Δt, t0, index = S

    # Call the user's function
    affect!(integrator)

    # Schedule the next call to `f` with `add_tstop!`, but only if it comes before
    # the last stop already scheduled, so that we do not keep integrating forever
    tnew = t0[] + (index[] + 1) * Δt
    tstops = integrator.opts.tstops
    #=
    The tstops heap always compares with "less than" for type stability, so to
    check the correct direction of integration we multiply by tdir
    =#
    tdir_tnew = integrator.tdir * tnew
    for i in length(tstops):-1:1 # reverse iterate to encounter large elements earlier
        if tdir_tnew < tstops.valtree[i] # TODO: relying on implementation details
            index[] += 1
            add_tstop!(integrator, tnew)
            break
        end
    end
end

function PeriodicOffsetCallback(f, Δt::Number;
                                offset::Number = 0,
                                initial_affect = false,
                                initialize = (cb, u, t, integrator) -> u_modified!(integrator, initial_affect),
                                kwargs...)

    # Anchor of the schedule (set to t + offset in `initialize`) and index of the
    # next call: `f` is called when t == t0 + index * Δt
    t0 = Ref(typemax(Δt))
    index = Ref(0)
    condition = (u, t, integrator) -> t == (t0[] + index[] * Δt)

    # Call f, advance the index, and add a tstop at the time of the next call
    affect! = PeriodicCallbackAffect(f, Δt, t0, index)

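    # Initialization: anchor the schedule at t + offset before any time steps are taken
    initialize_periodic = function (c, u, t, integrator)
        @assert integrator.tdir == sign(Δt)
        initialize(c, u, t, integrator)
        t0[] = t + offset
        if initial_affect && iszero(offset)
            # No offset: call `f` right away, as PeriodicCallback does
            index[] = 0
            affect!(integrator)
        elseif initial_affect
            # With an offset (assumed to have the same sign as Δt):
            # the first call to `f` happens at t + offset
            index[] = 0
            add_tstop!(integrator, t0[])
        else
            # Skip the call at the anchor: the first call happens at t + offset + Δt
            index[] = 1
            add_tstop!(integrator, t0[] + Δt)
        end
    end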

    DiscreteCallback(condition, affect!; initialize = initialize_periodic, kwargs...)
end
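To check the resulting call times without setting up an ODE problem, the t0/index bookkeeping can be replayed in plain Julia. This is a hypothetical helper, not part of DiffEqCallbacks; it assumes the semantics requested in this issue (anchor at tspan[1] + offset, first call at the anchor when initial_affect = true, one Δt later otherwise) and replaces the tstops check inside the affect with a simple t_end cutoff.

```julia
# Hypothetical helper (not part of DiffEqCallbacks): replays the t0/index bookkeeping
# of an offset periodic callback and returns the times at which `f` would be called.
# A plain `t_end` cutoff stands in for the tstops check done inside the affect.
function replay_schedule(t_start, t_end, Δt; offset = 0.0, initial_affect = false)
    t0 = t_start + offset              # anchor of the schedule
    index = initial_affect ? 0 : 1     # index of the next call, as set in initialize
    times = Float64[]
    while t0 + index * Δt <= t_end     # assumes Δt > 0 (forward integration)
        push!(times, t0 + index * Δt)  # the condition t == t0 + index * Δt holds here
        index += 1                     # the affect schedules the following call
    end
    return times
end

# 60 minutes on, 20 minutes off: both switches share Δt = 80 minutes
on_times  = replay_schedule(0.0, 240.0, 80.0; initial_affect = true)                 # switch to 1
off_times = replay_schedule(0.0, 240.0, 80.0; offset = 60.0, initial_affect = true)  # switch to 0
late_off  = replay_schedule(0.0, 240.0, 80.0; offset = 60.0)                         # initial_affect = false
```

The last line shows why initial_affect matters once an offset is present: with initial_affect = false, the first call lands at tspan[1] + offset + Δt (140 minutes here) rather than at tspan[1] + offset.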

Additional context

Thank you for considering this enhancement. If there are any corrections or improvements needed, please let me know.

ChrisRackauckas commented 6 months ago

An offset makes sense. I'd be happy to accept such a contribution.