trixi-framework / Trixi.jl

Trixi.jl: Adaptive high-order numerical simulations of conservation laws in Julia
https://trixi-framework.github.io/Trixi.jl
MIT License

Make custom integrators more similar to OrdinaryDiffEq.jl integrators? #1886

Closed DanielDoehring closed 2 weeks ago

DanielDoehring commented 3 months ago

While looking at `elixir_euleracoustics_co-rotating_vortex_pair.jl`, I noticed that it can only be run with the integrators from OrdinaryDiffEq.jl, because the `EulerAcousticsCouplingCallback` requires a `step!` function

https://github.com/trixi-framework/Trixi.jl/blob/18aaae96035a995e840e4e262964e7b49fdd9325/src/callbacks_step/euler_acoustics_coupling.jl#L189

and an `init` function

https://github.com/trixi-framework/Trixi.jl/blob/18aaae96035a995e840e4e262964e7b49fdd9325/src/callbacks_step/euler_acoustics_coupling.jl#L129

which are not provided by the existing custom time integrators.
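To see why a plain `solve` is not enough here, consider what a coupling callback has to do: it owns a second integrator and must advance it one step at a time, in lockstep with the main one. The following self-contained sketch shows only that calling pattern. `LockstepIntegrator`, `toy_init`, `toy_step!` and `run_coupled` are hypothetical names, the explicit Euler update stands in for a real time integration method, and the actual `EulerAcousticsCouplingCallback` additionally exchanges data between the two solvers.

```julia
# Hypothetical minimal integrator: explicit Euler for du/dt = lambda * u.
mutable struct LockstepIntegrator
    u::Vector{Float64}
    t::Float64
    dt::Float64
    t_end::Float64
    lambda::Float64
    finalstep::Bool
end

# `init` analogue: set up the state (copying u0), but take no step.
toy_init(u0, tspan, lambda; dt) =
    LockstepIntegrator(copy(u0), first(tspan), dt, last(tspan), lambda, false)

# `step!` analogue: advance by exactly one step, shortening the last one
# so that t lands on t_end, and flag the final step.
function toy_step!(integ::LockstepIntegrator)
    @assert !integ.finalstep
    if integ.t + integ.dt > integ.t_end || isapprox(integ.t + integ.dt, integ.t_end)
        integ.dt = integ.t_end - integ.t
        integ.finalstep = true
    end
    integ.u .+= integ.dt .* integ.lambda .* integ.u
    integ.t += integ.dt
    return integ
end

# The primary integrator drives the loop; after each of its steps, the
# secondary integrator is advanced by one step as well (roughly the role of
# the coupling callback). This interleaving needs init/step!, not solve.
function run_coupled(u0_primary, u0_secondary, tspan; dt)
    primary = toy_init(u0_primary, tspan, -1.0; dt = dt)
    secondary = toy_init(u0_secondary, tspan, -2.0; dt = dt)
    while !primary.finalstep
        toy_step!(primary)
        toy_step!(secondary)
    end
    return primary, secondary
end
```

With only a monolithic `solve`, the secondary integrator could be run to completion, but not interleaved step by step like this.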

We could, however, add these functions relatively easily, as exemplified here for `SimpleIntegrator2N`.

Essentially, we would need to provide an `init` function

```julia
function init(ode::ODEProblem, alg::T;
              dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm2N}
    # allocate the integrator state, as in the current `solve`
    u = copy(ode.u0)
    du = similar(u)
    u_tmp = similar(u)
    t = first(ode.tspan)
    iter = 0
    integrator = SimpleIntegrator2N(u, du, u_tmp, t, dt, zero(dt), iter, ode.p,
                                    (prob = ode,), ode.f, alg,
                                    SimpleIntegrator2NOptions(callback, ode.tspan;
                                                              kwargs...), false)

    # initialize callbacks
    if callback isa CallbackSet
        foreach(callback.continuous_callbacks) do cb
            error("unsupported")
        end
        foreach(callback.discrete_callbacks) do cb
            cb.initialize(cb, integrator.u, integrator.t, integrator)
        end
    elseif !isnothing(callback)
        error("unsupported")
    end

    # return the integrator without taking any step; `step!` advances it
    return integrator
end
```

This is essentially the current `solve` function, with only the last line changed (returning the integrator instead of calling `solve!` on it):

https://github.com/trixi-framework/Trixi.jl/blob/f10969548615547e520623a9fb351a41bd952065/src/time_integration/methods_2N.jl#L108-L133
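The callback section of `init` encodes a simple rule: continuous callbacks are rejected, and each discrete callback has its `initialize` function called exactly once with the freshly constructed integrator, before any time step is taken. The sketch below reproduces that control flow with hypothetical stand-in types (`MiniDiscreteCallback`, `MiniCallbackSet`, `CallbackTestIntegrator`) that only mimic the field names used above; they are not the SciML callback types.

```julia
# Hypothetical stand-ins mimicking the callback fields used in `init`.
struct MiniDiscreteCallback{C, A, I}
    condition::C    # (u, t, integrator) -> Bool
    affect!::A      # integrator -> nothing
    initialize::I   # (cb, u, t, integrator) -> nothing
end

struct MiniCallbackSet
    continuous_callbacks::Tuple
    discrete_callbacks::Tuple
end

mutable struct CallbackTestIntegrator
    u::Vector{Float64}
    t::Float64
    events::Vector{String}   # records which callback hooks ran
end

# Same control flow as the callback section of `init`.
function initialize_callbacks!(integrator, callback)
    if callback isa MiniCallbackSet
        foreach(callback.continuous_callbacks) do cb
            error("unsupported")
        end
        foreach(callback.discrete_callbacks) do cb
            cb.initialize(cb, integrator.u, integrator.t, integrator)
        end
    elseif !isnothing(callback)
        # a single callback that is not wrapped in a set is rejected
        error("unsupported")
    end
    return integrator
end
```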

Then, the step! function could be implemented as

```julia
function step!(integrator::SimpleIntegrator2N)
    @unpack prob = integrator.sol
    @unpack alg = integrator
    t_end = last(prob.tspan)
    callbacks = integrator.opts.callback

    @assert !integrator.finalstep
    if isnan(integrator.dt)
        error("time step size `dt` is NaN")
    end

    # if the next iteration would push the simulation beyond the end time, set dt accordingly
    if integrator.t + integrator.dt > t_end ||
       isapprox(integrator.t + integrator.dt, t_end)
        integrator.dt = t_end - integrator.t
        terminate!(integrator)
    end

    # one time step
    integrator.u_tmp .= 0
    for stage in eachindex(alg.c)
        t_stage = integrator.t + integrator.dt * alg.c[stage]
        integrator.f(integrator.du, integrator.u, prob.p, t_stage)

        a_stage = alg.a[stage]
        b_stage_dt = alg.b[stage] * integrator.dt
        @trixi_timeit timer() "Runge-Kutta step" begin
            @threaded for i in eachindex(integrator.u)
                integrator.u_tmp[i] = integrator.du[i] -
                                      integrator.u_tmp[i] * a_stage
                integrator.u[i] += integrator.u_tmp[i] * b_stage_dt
            end
        end
    end
    integrator.iter += 1
    integrator.t += integrator.dt

    # handle callbacks
    if callbacks isa CallbackSet
        foreach(callbacks.discrete_callbacks) do cb
            if cb.condition(integrator.u, integrator.t, integrator)
                cb.affect!(integrator)
            end
            return nothing
        end
    end

    # respect maximum number of iterations
    if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
        @warn "Interrupted. Larger maxiters is needed."
        terminate!(integrator)
    end
end
```

This is almost identical to the body of the main loop in the current `solve!` function:

https://github.com/trixi-framework/Trixi.jl/blob/f10969548615547e520623a9fb351a41bd952065/src/time_integration/methods_2N.jl#L135-L193
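The stage loop in `step!` is a 2N-storage (Williamson-type) Runge-Kutta update: besides `u` and `du`, only one extra register `u_tmp` is kept, and with the update `u_tmp = du - u_tmp * a`, the coefficients `a` are the negatives of the A_j in Williamson's form q_j = A_j q_{j-1} + dt f_j. The standalone sketch below runs exactly that update, plus the end-time clipping from `step!`, on the test problem du/dt = -u. It uses Williamson's classic three-stage, third-order coefficients (A = 0, -5/9, -153/128; B = 1/3, 15/16, 8/15; c = 0, 1/3, 3/4) instead of, e.g., Trixi.jl's `CarpenterKennedy2N54`, simply because they are short; `lsrk2n_step!` and `integrate_2n` are hypothetical names.

```julia
# One step of a 2N-storage Runge-Kutta method, using the same update as the
# stage loop in `step!` (no threading, no timers).
function lsrk2n_step!(u, du, u_tmp, f, t, dt, a, b, c)
    u_tmp .= 0
    for stage in eachindex(c)
        t_stage = t + dt * c[stage]
        f(du, u, nothing, t_stage)          # du = RHS(u, t_stage); p = nothing
        a_stage = a[stage]
        b_stage_dt = b[stage] * dt
        for i in eachindex(u)
            u_tmp[i] = du[i] - u_tmp[i] * a_stage
            u[i] += u_tmp[i] * b_stage_dt
        end
    end
    return u
end

# Williamson's 3-stage, 3rd-order 2N scheme; `a` is negated to match
# the `u_tmp = du - u_tmp * a` convention used above.
const A_WILLIAMSON3 = (0.0, 5 / 9, 153 / 128)
const B_WILLIAMSON3 = (1 / 3, 15 / 16, 8 / 15)
const C_WILLIAMSON3 = (0.0, 1 / 3, 3 / 4)

# Time loop with the same end-time clipping as `step!`.
function integrate_2n(f, u0, tspan, dt)
    u = copy(u0)
    du = similar(u)
    u_tmp = similar(u)
    t, t_end = tspan
    finalstep = false
    while !finalstep
        dt_step = dt
        if t + dt_step > t_end || isapprox(t + dt_step, t_end)
            dt_step = t_end - t   # shorten the last step to hit t_end
            finalstep = true
        end
        lsrk2n_step!(u, du, u_tmp, f, t, dt_step,
                     A_WILLIAMSON3, B_WILLIAMSON3, C_WILLIAMSON3)
        t += dt_step
    end
    return u
end
```

Since the scheme is third order, halving `dt` should reduce the error at `t = 1` by roughly a factor of 2^3 = 8.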

For the version with `init` and `step!`, one could then implement `solve` as

```julia
function solve(ode::ODEProblem, alg::T;
               dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm2N}
    # set up the integrator (allocations, callback initialization) without stepping
    integrator = init(ode, alg; dt = dt, callback = callback, kwargs...)

    @unpack prob = integrator.sol

    integrator.finalstep = false

    # advance one step at a time until `terminate!` sets `finalstep`
    @trixi_timeit timer() "main loop" while !integrator.finalstep
        step!(integrator)
    end # "main loop" timer

    return TimeIntegratorSolution((first(prob.tspan), integrator.t),
                                  (prob.u0, integrator.u),
                                  integrator.sol.prob)
end
```

which behaves as before.
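Stripped of the Trixi.jl specifics, this `solve` is just `init` followed by repeated `step!` calls until `finalstep` is set, so a caller that needs finer control can run the same loop manually and get an identical result. A minimal sketch with a hypothetical explicit Euler integrator (`EulerIntegrator`, `euler_init`, `euler_step!` and `euler_solve` are illustrative names, not Trixi.jl API):

```julia
mutable struct EulerIntegrator
    u::Vector{Float64}
    u0::Vector{Float64}
    t::Float64
    t0::Float64
    t_end::Float64
    dt::Float64
    f::Function            # in-place RHS: f(du, u, t)
    du::Vector{Float64}
    finalstep::Bool
end

# `init`: allocate and copy the state, take no step.
function euler_init(f, u0, tspan; dt)
    return EulerIntegrator(copy(u0), u0, first(tspan), first(tspan), last(tspan),
                           dt, f, similar(u0), false)
end

# `step!`: one explicit Euler step, shortened at the end of the time span.
function euler_step!(integ::EulerIntegrator)
    @assert !integ.finalstep
    if integ.t + integ.dt > integ.t_end || isapprox(integ.t + integ.dt, integ.t_end)
        integ.dt = integ.t_end - integ.t
        integ.finalstep = true
    end
    integ.f(integ.du, integ.u, integ.t)
    integ.u .+= integ.dt .* integ.du
    integ.t += integ.dt
    return integ
end

# `solve` = `init` + step loop; returns (t0, t_final), (u0, u_final),
# mirroring the data passed to TimeIntegratorSolution above.
function euler_solve(f, u0, tspan; dt)
    integ = euler_init(f, u0, tspan; dt = dt)
    while !integ.finalstep
        euler_step!(integ)
    end
    return (integ.t0, integ.t), (integ.u0, integ.u)
end
```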

ranocha commented 3 months ago

I originally implemented it this way to keep it as simple as possible, while retaining the option to use the same functionality we need from the integrators provided by OrdinaryDiffEq.jl. I would be fine with these changes, but @sloede needs to agree as well (since it makes the implementation more complex).

To be able to use custom integrators also for the cases you mentioned, we would need to specialize init, solve! etc. from https://github.com/SciML/CommonSolve.jl. This will be a real change since we use Trixi.solve right now with custom time integrators instead of the common solve version.
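Specializing the CommonSolve.jl functions means adding methods to generic functions owned by another package, instead of defining new ones inside Trixi.jl. To keep the example dependency-free, the stand-in module `MiniCommonSolve` below mimics, as an assumption about its shape, the CommonSolve.jl pattern in which a generic `solve(args...; kwargs...)` falls back to `solve!(init(args...; kwargs...))`; a package then only adds `init`, `step!` and `solve!` methods for its own types. `CountingProblem` and `CountingIntegrator` are hypothetical.

```julia
# Stand-in for CommonSolve.jl (assumed interface shape; not the real package).
module MiniCommonSolve
function init end
function step! end
function solve! end
solve(args...; kwargs...) = solve!(init(args...; kwargs...))
end

# Hypothetical problem: count from 0 up to `n`, one unit per step.
struct CountingProblem
    n::Int
end

mutable struct CountingIntegrator
    count::Int
    n::Int
    finalstep::Bool
end

# Importing the names lets us add methods to the other module's functions,
# so the generic `MiniCommonSolve.solve` dispatches to them.
import .MiniCommonSolve: init, step!, solve!

init(prob::CountingProblem) = CountingIntegrator(0, prob.n, prob.n == 0)

function step!(integ::CountingIntegrator)
    integ.count += 1
    integ.finalstep = integ.count >= integ.n
    return integ
end

function solve!(integ::CountingIntegrator)
    while !integ.finalstep
        step!(integ)
    end
    return integ.count
end
```

Once these three methods exist, the generic `solve` works without any package-specific `solve` method, which is the kind of change that replacing `Trixi.solve` would require.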

DanielDoehring commented 3 months ago

> To be able to use custom integrators also for the cases you mentioned, we would need to specialize init, solve! etc. from https://github.com/SciML/CommonSolve.jl. This will be a real change since we use Trixi.solve right now with custom time integrators instead of the common solve version.

I think if we really wanted to use `solve` from OrdinaryDiffEq.jl, a lot more would have to be implemented, which I consider unnecessary (at least at the moment). Thus, I would stick to `Trixi.solve` with the presented implementation.

ranocha commented 3 months ago

But we would also need the same `step!` function as OrdinaryDiffEq.jl without depending on OrdinaryDiffEq.jl, or some special handling in the functions where we use it.

sloede commented 3 months ago

IIUC, you do not want to make the implementation much more complicated, but just refactor solve and solve! into init and step!, right? Thus, if you were to implement it as described above, maybe with one or two additional in-source comments that make it easier to understand for novices, I wouldn't be opposed.

To some extent, time integration is black magic anyway, and not that many people need to deal with its nitty-gritty details except when they need something special. In that case, more modularity is probably helpful.

DanielDoehring commented 3 months ago

> IIUC, you do not want to make the implementation much more complicated, but just refactor solve and solve! into init and step!, right?

Yes, that is right!

> Thus, if you were to implement it as described above, maybe with one or two additional in-source comments that make it easier to understand for novices, I wouldn't be opposed.

I actually think that the more explicit version could even help illustrate that it is not the ODE algorithm but the ODE integrator that solves the problem.