SciML / DiffEqGPU.jl

GPU-acceleration routines for DifferentialEquations.jl and the broader SciML scientific machine learning ecosystem
https://docs.sciml.ai/DiffEqGPU/stable/
MIT License

Discrete Callbacks for GPUTsit5 #174

Closed by utkarsh530 1 year ago

utkarsh530 commented 2 years ago

Hi,

I have started adding a rudimentary implementation of discrete callbacks. It works for now, but it needs more work before it is fully supported.

This currently works only with fixed time stepping (adaptive = false). It does not push the tstops into the ts time series, because the size of ts is determined in advance. I think allocating CuMatrix{typeof(dt)}(undef, length(original_ts) + length(tstops), length(probs)) could work, but the merged series may then contain duplicates whenever a tstop coincides with an existing time point.
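To illustrate the duplicate problem, here is a minimal CPU sketch in plain Julia (no CUDA needed). The helper `merged_times` is hypothetical, not part of DiffEqGPU: it merges a fixed time grid with the tstops and drops the repeated time where a tstop coincides with a grid point, so the result can be shorter than the preallocated length(ts) + length(tstops).

```julia
# Hypothetical helper (not DiffEqGPU code): build the save times for one
# trajectory by merging a fixed-step grid with user-supplied tstops.
function merged_times(ts::AbstractVector{T}, tstops::AbstractVector{T}) where {T}
    # vcat repeats a time when a tstop coincides with a grid point;
    # unique! drops the repeat and sort! restores chronological order.
    return sort!(unique!(vcat(ts, tstops)))
end

ts     = Float32[0, 1, 2, 3, 4, 5]   # fixed-step grid (dt = 1 for readability)
tstops = Float32[4, 2.5]             # 4 is already on the grid, 2.5 is not

upper  = length(ts) + length(tstops) # size of the preallocated buffer
merged = merged_times(ts, tstops)

println(upper)          # 8
println(length(merged)) # 7
println(merged)
```

On the GPU, one possible approach (an assumption, not what this PR implements) is to keep the upper-bound allocation and record how many rows each trajectory actually fills.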

Next, I'll work on making this more robust and on aligning it with the callbacks implemented in DiffEqBase. I believe correctly modifying the DiffEqBase callback functions would solve these problems.

utkarsh530 commented 2 years ago

MWE:

using DiffEqGPU, SimpleDiffEq, StaticArrays, CUDA, BenchmarkTools, OrdinaryDiffEq
using Plots

# Make scalar indexing of GPU arrays an error instead of a slow fallback
CUDA.allowscalar(false)

# Out-of-place exponential decay: du/dt = -u
function f(u, p, t)
    du1 = -u[1]
    return SVector{1}(du1)
end

u0 = @SVector [10.0f0]
prob = ODEProblem{false}(f, u0, (0.0f0, 10.0f0))
prob_func = (prob, i, repeat) -> remake(prob, p = prob.p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)

# Add 10 to the state when the integrator is exactly at t = 4
condition(u, t, integrator) = t == 4.0f0
affect!(integrator) = integrator.u += @SVector [10.0f0]
cb = GPUDiscreteCallback(condition, affect!)

# Fixed-step solve; tstops makes the integrator step exactly to t = 4
sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(),
            trajectories = 2,
            adaptive = false, dt = 0.1f0, callback = cb, merge_callbacks = true,
            tstops = CuArray([4.0f0]))

plot(sol[1])
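A hypothetical CPU sketch of what one trajectory does in the fixed-step MWE above may help: step with a constant dt, shorten the step that would cross a tstop so it lands on it exactly, check the callback condition after each step, and save the state both before and after the affect. This is not DiffEqGPU's kernel; classical RK4 stands in for GPUTsit5, a scalar Float32 state stands in for the SVector, and `solve_fixed`, `rk4_step`, `decay`, and the two-argument `condition(u, t)` are illustrative names.

```julia
# Hypothetical CPU sketch, NOT DiffEqGPU's kernel: fixed-step integration of
# one trajectory with tstops and a discrete callback. Classical RK4 stands in
# for Tsit5 and a scalar Float32 state stands in for the SVector.
decay(u, p, t) = -u

function rk4_step(f, u, p, t, dt)
    k1 = f(u, p, t)
    k2 = f(u + dt / 2 * k1, p, t + dt / 2)
    k3 = f(u + dt / 2 * k2, p, t + dt / 2)
    k4 = f(u + dt * k3, p, t + dt)
    return u + dt / 6 * (k1 + 2k2 + 2k3 + k4)
end

function solve_fixed(f, u0, tspan, dt, tstops, condition, affect; p = nothing)
    t, u = tspan[1], u0
    ts, us = [t], [u]
    stops = sort(unique(filter(s -> tspan[1] < s <= tspan[2], tstops)))
    i = 1  # index of the next tstop not yet reached
    while t < tspan[2]
        target = i <= length(stops) ? stops[i] : tspan[2]
        if t + dt >= target
            # Shorten the step so it lands on the target, then assign t
            # exactly so that a test like `t == 4.0f0` can succeed.
            u = rk4_step(f, u, p, t, target - t)
            t = target
            if i <= length(stops) && t == stops[i]
                i += 1
            end
        else
            u = rk4_step(f, u, p, t, dt)
            t += dt
        end
        push!(ts, t); push!(us, u)          # save the (pre-affect) state
        if condition(u, t)
            u = affect(u)
            push!(ts, t); push!(us, u)      # save the post-affect state at the same t
        end
    end
    return ts, us
end

ts, us = solve_fixed(decay, 10.0f0, (0.0f0, 10.0f0), 0.1f0, [4.0f0],
                     (u, t) -> t == 4.0f0, u -> u + 10.0f0)
```

Because the step into t = 4 is shortened and t is assigned exactly, `t == 4.0f0` holds and the callback fires; adding 0.1f0 forty times in Float32 need not produce exactly 4.0f0, which is why the condition depends on tstops. The resulting ts contains 4.0f0 twice, matching the duplicated t = 4.0 entries in the OrdinaryDiffEq output in this thread.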
[Screenshot from 2022-08-16, 10:59:44 PM]

utkarsh530 commented 1 year ago

@ChrisRackauckas, could you confirm whether this is the expected behavior in OrdinaryDiffEq? The value at t = 4.0 is saved twice:

julia> bench_sol = solve(prob, Tsit5(),
                   adaptive = true, dt = 0.01f0, callback = cb, merge_callbacks = true,
                   tstops = [4.0f0], saveat = [0.f0,4.0f0])
retcode: Success
Interpolation: 1st order linear
t: 3-element Vector{Float32}:
 0.0
 4.0
 4.0
u: 3-element Vector{SVector{1, Float32}}:
 [10.0]
 [0.18316455]
 [10.183165]

I think the duplicate entry at t = 4.0 comes from savevalues!, when something like push!(integrator.sol.u, integrator.u); push!(integrator.sol.t, integrator.t) is called both before and after affect!.

ChrisRackauckas commented 1 year ago

Yes, it's required for left and right continuity.
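One way to see why the duplicate entry matters: with both values stored at t = 4.0, an interpolant can return either the left limit (before affect!) or the right limit (after affect!) at the jump. The sketch below is a hypothetical helper, not SciMLBase's interpolation; the `interp` name and its `side` keyword are assumptions for illustration, and the saved values are the ones from the OrdinaryDiffEq run above, stored as scalars instead of SVectors.

```julia
# Hypothetical piecewise-linear interpolant (not SciMLBase's implementation)
# over saved values that may repeat a time at a callback. `side` chooses which
# of the two values to return exactly at the jump. Assumes ts is sorted and
# ts[1] <= t <= ts[end].
function interp(ts, us, t; side = :left)
    if side === :left
        i = searchsortedfirst(ts, t)   # first index with ts[i] >= t
        ts[i] == t && return us[i]     # at the jump: pre-affect value
        lo, hi = i - 1, i
    else
        i = searchsortedlast(ts, t)    # last index with ts[i] <= t
        ts[i] == t && return us[i]     # at the jump: post-affect value
        lo, hi = i, i + 1
    end
    θ = (t - ts[lo]) / (ts[hi] - ts[lo])
    return (1 - θ) * us[lo] + θ * us[hi]
end

# Saved series from the OrdinaryDiffEq run in this thread
ts = Float32[0.0, 4.0, 4.0]
us = Float32[10.0, 0.18316455, 10.183165]

left  = interp(ts, us, 4.0f0; side = :left)   # value just before the jump
right = interp(ts, us, 4.0f0; side = :right)  # value just after the jump
```

Without the second t = 4.0 entry, only one of these two limits could be recovered from the saved series.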

utkarsh530 commented 1 year ago

Yes, I've made the changes, and it now appears to be correct:

julia> @test norm(bench_sol(4.0f0) - sol[1](4.0f0)) < 1e-6
Test Passed
  Expression: norm(bench_sol(4.0f0) - (sol[1])(4.0f0)) < 1.0e-6
   Evaluated: 1.7881393f-7 < 1.0e-6

julia> @test norm(bench_sol.u - sol[1].u) < 1e-6
Test Passed
  Expression: norm(bench_sol.u - (sol[1]).u) < 1.0e-6
   Evaluated: 0.0f0 < 1.0e-6
utkarsh530 commented 1 year ago

@ChrisRackauckas, please review.

ChrisRackauckas commented 1 year ago

This looks good.