SciML / DiffEqGPU.jl

GPU-acceleration routines for DifferentialEquations.jl and the broader SciML scientific machine learning ecosystem
https://docs.sciml.ai/DiffEqGPU/stable/
MIT License

Add support for terminate! in EnsembleGPUKernel #199

Closed: utkarsh530 closed this 1 year ago

utkarsh530 commented 1 year ago

MWE:

using DiffEqGPU, OrdinaryDiffEq, StaticArrays, LinearAlgebra, CUDA

function f(u, p, t)
    du1 = -u[1]
    return SVector{1}(du1)
end

u0 = @SVector [10.0f0]
prob = ODEProblem{false}(f, u0, (0.0f0, 10.0f0))
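# Every trajectory reuses the same parameters; pass prob_func so it is actually applied.
prob_func = (prob, i, repeat) -> remake(prob, p = prob.p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)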

condition(u, t, integrator) = t == 2.40f0

function affect!(integrator)
    integrator.u += @SVector[10.0f0]
    terminate!(integrator)
end

cb = DiscreteCallback(condition, affect!; save_positions = (false, false))
sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(),
                trajectories = 2,
                adaptive = false, dt = 1.0f0, callback = cb, merge_callbacks = true,
                tstops = [2.40f0])

bench_sol = solve(prob, Tsit5(),
                  adaptive = false, dt = 1.0f0, callback = cb, merge_callbacks = true,
                  tstops = [2.40f0])

julia> sol[1]
retcode: Default
Interpolation: 1st order linear
t: 12-element view(::Matrix{Float32}, :, 1) with eltype Float32:
 0.0
 1.0
 2.0
 2.4
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
u: 12-element view(::Matrix{SVector{1, Float32}}, :, 1) with eltype SVector{1, Float32}:
 [10.0]
 [3.6809897]
 [1.3549666]
 [0.90826195]
 [0.0]
 [0.0]
 [0.0]
 [0.0]
 [0.0]
 [0.0]
 [0.0]
 [0.0]
julia> bench_sol = solve(prob, Tsit5(),
                         adaptive = false, dt = 1.0f0, callback = cb, merge_callbacks = true,
                         tstops = [2.40f0])
retcode: Terminated
Interpolation: specialized 4th order "free" interpolation
t: 4-element Vector{Float32}:
 0.0
 1.0
 2.0
 2.4
u: 4-element Vector{SVector{1, Float32}}:
 [10.0]
 [3.6809988]
 [1.3549728]
 [0.90826607]

ChrisRackauckas commented 1 year ago

We just need to make sure that behavior is well-documented.

ChrisRackauckas commented 1 year ago

It would probably be good to have a post-processing phase that subsets each trajectory to the entries that were actually saved.

utkarsh530 commented 1 year ago

How should we do that? Return an index from the kernel, or just select the last index whose entry in sol.t is non-zero?

ChrisRackauckas commented 1 year ago

Not non-zero, but not equal to tspan[1].
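
As a rough illustration of the suggestion above, here is a minimal sketch of such a trimming step. The helper `last_valid_index` is hypothetical and not part of DiffEqGPU; it assumes that unused slots in the padded time column hold `tspan[1]` (which is `0.0f0` in the MWE, so it matches the zero padding shown in the output).

```julia
# Hypothetical helper (not a DiffEqGPU function): find the last slot of a
# padded time column that the solver actually wrote to.
# Assumption: unused slots are equal to t0 = tspan[1].
function last_valid_index(ts::AbstractVector, t0)
    # Slot 1 always holds t0 itself, so search for the last entry that differs.
    idx = findlast(!=(t0), ts)
    # If nothing differs, only the initial point was saved.
    return idx === nothing ? 1 : idx
end

# Padded columns copied from the GPU output above (12 slots, 4 used).
ts = Float32[0.0, 1.0, 2.0, 2.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
us = Float32[10.0, 3.6809897, 1.3549666, 0.90826195, 0, 0, 0, 0, 0, 0, 0, 0]

n = last_valid_index(ts, 0.0f0)

# Indexing with a range copies the data, so each trimmed array is a new allocation.
ts_trimmed = ts[1:n]
us_trimmed = us[1:n]
```

Comparing against `tspan[1]` rather than zero keeps the check meaningful when the integration does not start at `t = 0`, as long as the padding assumption holds.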

codecov[bot] commented 1 year ago

Codecov Report

Merging #199 (9077c15) into master (1b67553) will not change coverage. The diff coverage is 0.00%.

@@          Coverage Diff           @@
##           master    #199   +/-   ##
======================================
  Coverage    0.00%   0.00%           
======================================
  Files           9       9           
  Lines        1977    1990   +13     
======================================
- Misses       1977    1990   +13     
Impacted Files                                Coverage Δ
src/DiffEqGPU.jl                              0.00% <0.00%> (ø)
src/integrators/integrator_utils.jl           0.00% <0.00%> (ø)
src/integrators/types.jl                      0.00% <ø> (ø)
src/perform_step/gpu_tsit5_perform_step.jl    0.00% <0.00%> (ø)
src/perform_step/gpu_vern7_perform_step.jl    0.00% <0.00%> (ø)
src/perform_step/gpu_vern9_perform_step.jl    0.00% <0.00%> (ø)
src/solve.jl                                  0.00% <0.00%> (ø)

utkarsh530 commented 1 year ago

The post-processing step does increase CPU allocations, though (147 to 250 allocations, 5.891 KiB to 8.344 KiB); GPU allocations are unchanged:

Before:

julia> CUDA.@time sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(),
                       trajectories = 2,
                       adaptive = false, dt = 1.0f0, callback = cb, merge_callbacks = true,
                       tstops = [2.40f0])
  0.000236 seconds (147 CPU allocations: 5.891 KiB) (4 GPU allocations: 228 bytes, 8.57% memmgmt time)

After:

julia> CUDA.@time sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(),
                       trajectories = 2,
                       adaptive = false, dt = 1.0f0, callback = cb, merge_callbacks = true,
                       tstops = [2.40f0])
  0.000275 seconds (250 CPU allocations: 8.344 KiB) (4 GPU allocations: 228 bytes, 6.38% memmgmt time)
EnsembleSolution Solution of length 2 with uType:
ODESolution

ChrisRackauckas commented 1 year ago

What about using a view?
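
To make the idea concrete, here is a small sketch contrasting a copy with a view on a padded time matrix. The matrix layout (one column per trajectory) mirrors the `view(::Matrix{Float32}, :, 1)` shown in the output above; the variable names and the assumption that the number of valid entries `n` is already known are illustrative only.

```julia
# Padded time storage: 12 slots per trajectory, 2 trajectories (one per column).
ts = zeros(Float32, 12, 2)
ts[1:4, :] .= Float32[0.0, 1.0, 2.0, 2.4]  # only the first 4 slots were written

n = 4  # number of valid entries, assumed to be known from post-processing

# Range indexing allocates a new Vector holding a copy of the data.
t_copy = ts[1:n, 1]

# A view wraps the parent array without copying the underlying data.
t_view = view(ts, 1:n, 1)

# Both give the same values, but the view aliases the original storage,
# so a later write to the parent is visible through the view.
same_values = t_view == t_copy
ts[2, 1] = 1.5f0
view_sees_write = t_view[2] == 1.5f0
copy_unchanged = t_copy[2] == 1.0f0
```

The trade-off is that a view keeps the full padded parent array alive, whereas a copy lets the padding be garbage-collected.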

ChrisRackauckas commented 1 year ago

I think that could be worked on. We may need to double-check that the higher-level code is type-stable.
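
One way to check this kind of type stability is `Test.@inferred` from the standard library. The sketch below uses two hypothetical helpers (`trim_view` and `trim_unstable`, neither of which is in DiffEqGPU) to show a stable case and an unstable one where the return type depends on a runtime flag.

```julia
using Test

# Hypothetical post-processing helper: always returns a view, so the
# return type is fully determined by the argument types.
trim_view(ts::AbstractVector, n::Integer) = view(ts, 1:n)

# Hypothetical unstable variant: returns either a view or a copy depending
# on a runtime value, so the inferred return type is a Union.
trim_unstable(ts::AbstractVector, n::Integer, use_view::Bool) =
    use_view ? view(ts, 1:n) : ts[1:n]

ts = Float32[0.0, 1.0, 2.0, 2.4, 0.0, 0.0]

# @inferred returns the result when the inferred type matches the actual type.
stable = @inferred trim_view(ts, 4)

# For the unstable variant, @inferred throws because the inferred type
# (a Union) does not match the concrete type actually returned.
threw = try
    @inferred trim_unstable(ts, 4, true)
    false
catch
    true
end
```

If the trimmed solution were built from a mix of views and copies, a check like this on the higher-level function would flag it.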