SciML / StochasticDiffEq.jl

Solvers for stochastic differential equations which connect with the scientific machine learning (SciML) ecosystem

Inplace `RODESolution` interpolation has unexpected errors #538

Closed. brainsMAKER closed this issue 11 months ago.

brainsMAKER commented 12 months ago

Issue description

In-place interpolation of an SDE solution (`RODESolution`), `sol(vals, tvals; idxs=...)`, sometimes throws an error and sometimes does not, depending unexpectedly on the type of `vals` (`Vector{Vector{Float64}}` or `Vector{Float64}`).

I believe the current behaviour does not match expectations: some of these errors seem unnecessary, in particular the cases marked `!!` in the example below.
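Why the type of `vals` matters comes down to how Julia stores values into a typed array. A minimal sketch with plain arrays, no solver involved: each slot of a `Vector{Float64}` only accepts something convertible to `Float64`, and each slot of a `Vector{Vector{Float64}}` only accepts a `Vector`, so a write of the wrong shape fails inside `convert` with a `MethodError`.

```julia
# Plain-Julia illustration (no solver involved) of why the element type
# of the output buffer decides which kind of write can succeed.

vals_scalar = zeros(3)                  # Vector{Float64}: each slot holds a Float64
vals_vector = [zeros(1) for _ in 1:3]   # Vector{Vector{Float64}}: each slot holds a Vector

# Storing a Vector into a Float64 slot needs convert(Float64, ::Vector) -> MethodError
err1 = try
    vals_scalar[1] = [1.0]
    nothing
catch e
    e
end

# Storing a Float64 into a Vector slot needs convert(Vector{Float64}, ::Float64) -> MethodError
err2 = try
    vals_vector[1] = 1.0
    nothing
catch e
    e
end

println(typeof(err1), " ", typeof(err2))  # MethodError MethodError

# The matching writes work: a scalar into a Float64 slot,
# and an in-place broadcast into the existing inner vector.
vals_scalar[1] = 1.0
vals_vector[1] .= [1.0]
```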

Related issue: SciML/OrdinaryDiffEq.jl#1991

Working example

This is a non-minimal working example:

using StochasticDiffEq, Test

##################################################
# Definitions

brk2 = true     # true while the StochasticDiffEq.sde_interpolation! bug is present (marks affected tests as broken)

if true         # to prevent unnecessary compilation

    # utils
    "Create an m-Vector of zero-filled n-Arrays"
    vecveczero(m::Integer, n) = map(t->zeros(n), 1:m)
    vecveczero(::Type{T}, m::Integer, n) where {T} = map(t->zeros(T, n), 1:m)
    VF = Vector{Float64}
    VVF = Vector{Vector{Float64}}

    # problem definitions
    tspan = (0.0, 4.0)

    ftest1D(u,p,t) = u+20cos(5t)
    σtest1D(u,p,t) = 0.2u

    ftest3D!(du,u,p,t) = @. (du = (-1,0.5,1)*u + 20*cos(5t))
    σtest3D!(dσ,u,p,t) = setindex!(dσ, 0.2*u[3], 3)         #diagonal noise

    prb_SDE_1D = SDEProblem(ftest1D, σtest1D, 1.0, tspan)
    prb_SDE_3D = SDEProblem(ftest3D!, σtest3D!, ones(3), tspan)

end

# interpolation definition
tt = 0:0.01:3
ntt = length(tt)
out_VF = zeros(ntt)
out_VVF_1 = vecveczero(ntt,1)
out_VVF_2 = vecveczero(ntt,2)

except2 = brk2 ? Exception : MethodError

##################################################
# Testing SDESolution

printstyled("\nTests for SDESolution inplace interpolation:\n"; color=:red, bold=true)
@testset verbose=true "SDESolution" begin

    # solving both
    sol_SDE_1D = solve(prb_SDE_1D, SMEA(); dt=0.1)
    sol_SDE_3D = solve(prb_SDE_3D, SMEA(); dt=0.1)

    # 1D solution interpolation
    printstyled("1D :\n"; color=:light_green)
    @testset "1D" begin
        print("zeros\t\t\t")
        # Next should pass  → ok
        @test @time(sol_SDE_1D(out_VF, tt) isa VF)
        # Next should error → ok (getindex)
        @test_throws except2 sol_SDE_1D(out_VF, tt; idxs=1:1)
        # Next should error → ok
        brk2 || print("vecvec\t\t\t")
        # Next should pass  !!
        @test @time(sol_SDE_1D(out_VVF_1,tt) isa VVF)                   broken=brk2
        # If interpolations succeeded, they should be equal  → ok
        brk2 || @test out_VF == first.(out_VVF_1)
    end

    # 3D solution interpolation
    printstyled("3D :\n"; color=:light_green)
    @testset "3D" begin
        brk2 || printstyled("zeros, idxs=3\t\t"; bold=true)
        # Next should pass  !!
        @test @time(sol_SDE_3D(out_VF, tt; idxs=3) isa VF)              broken=brk2
        # Next should error → ok (setindex!)
        @test_throws except2 sol_SDE_3D(out_VF, tt; idxs=3:3)
        printstyled("vecvec, idxs=3\t\t"; bold=true)
        # Next should pass  → ok
        @test @time(sol_SDE_3D(out_VVF_1, tt; idxs=3) isa VVF)
        print("vecvec, idxs=3:3\t")
        # Next should pass  → ok
        @test @time(sol_SDE_3D(out_VVF_1, tt; idxs=3:3) isa VVF)
        print("vecvec, idxs=2:3\t")
        # Next should pass  → ok
        @test @time(sol_SDE_3D(out_VVF_2, tt; idxs=2:3) isa VVF)
        # If interpolations succeeded, they should be equal  → ok
        @test first.(out_VVF_1) == last.(out_VVF_2)
        brk2 || @test out_VF == first.(out_VVF_1)
    end

end

Here is a screenshot of the broken result in 6.61.1: [image: issue2_broken]

Proposed fix

The problem is located in `sde_interpolation!`.

The errors come from converting a `Vector{Float64}` into a `Float64` (and the reverse), because `sde_interpolant` is wrongly called instead of `sde_interpolant!`. The fix is to check whether the element type of `vals` is an array (`eltype(vals) <: AbstractArray`), rather than checking `eltype(timeseries)` or some other condition.
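The proposed check can be sketched in plain Julia. This is a hypothetical stand-in, not the code from the PR: `interp_at` and `interp_into!` are illustrative names, the stored states are plain `Vector`s, and linear interpolation replaces the solver's own interpolant. What it shows is the branch on `eltype(vals) <: AbstractArray`: array-valued slots are filled in place with `.=`, while scalar slots get plain assignment.

```julia
# Hypothetical sketch of the proposed check; `interp_at` and `interp_into!`
# are illustrative names, not StochasticDiffEq internals. Linear interpolation
# between stored states stands in for the solver's own interpolant.

# Interpolate the stored states `us` (one Vector per time in `ts`) at time `t`,
# keeping only components `idxs`: a Float64 for an Int, a Vector for a range.
function interp_at(ts, us, t, idxs)
    k = clamp(searchsortedlast(ts, t), 1, length(ts) - 1)  # t lies in [ts[k], ts[k+1]]
    θ = (t - ts[k]) / (ts[k+1] - ts[k])
    return (1 - θ) .* us[k][idxs] .+ θ .* us[k+1][idxs]
end

# Fill `vals` at the times `tvals`, choosing the kind of write from eltype(vals)
function interp_into!(vals, ts, us, tvals; idxs = eachindex(first(us)))
    if eltype(vals) <: AbstractArray
        # Array-valued slots: overwrite each existing inner array in place
        # (a scalar result is broadcast into the inner array)
        for (i, t) in enumerate(tvals)
            vals[i] .= interp_at(ts, us, t, idxs)
        end
    else
        # Scalar slots: plain assignment (fails in setindex! if idxs is a range)
        for (i, t) in enumerate(tvals)
            vals[i] = interp_at(ts, us, t, idxs)
        end
    end
    return vals
end

# Usage with a 3-component state stored at three time points
ts = [0.0, 1.0, 2.0]
us = [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]]
tvals = [0.5, 1.5]

out_VF  = interp_into!(zeros(length(tvals)), ts, us, tvals; idxs = 3)        # Float64 slots
out_VVF = interp_into!([zeros(2) for _ in tvals], ts, us, tvals; idxs = 2:3) # Vector slots

println(out_VF)   # [4.5, 7.5]
println(out_VVF)  # [[3.0, 4.5], [5.0, 7.5]]
```

Branching on the buffer's element type, rather than on the stored states, follows from the fact that the buffer is what constrains the write. In this sketch a `Vector{Vector{Float64}}` buffer accepts both `idxs=3` and `idxs=2:3`, while a `Vector{Float64}` buffer with `idxs=3:3` still fails in `setindex!`, matching the expectations written in the test set above.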

Here is a screenshot of the fixed result (incoming PR): [image: issue2_fixed]