SciML / StochasticDiffEq.jl

Solvers for stochastic differential equations which connect with the scientific machine learning (SciML) ecosystem
Other
248 stars 66 forks source link

Non-diagonal noise SDEs with non-vector variables #472

Open tbilitewski opened 2 years ago

tbilitewski commented 2 years ago

Currently, for non-diagonal noise the SDE interface only allows vectors for the state variables and matrices for `noise_rate_prototype`.

Here is a minimal failing example:

using DifferentialEquations

"""
    f(du, u, p, t)

In-place linear drift: overwrite `du` with `1.01 .* u` and return `du`.
`p` and `t` are unused.
"""
function f(du, u, p, t)
    @. du = 1.01 * u
    return du
end
"""
    g(du, u, p, t)

In-place noise rate for the 3×2-state example. For every row `i` in `1:3` and column
`j` in `1:2`, set `du[i, j, i] = 1.0`, so entry `(i, j)` of the state couples only to
slice `i` along the third axis of `du`. All other entries of `du` are left untouched,
so they keep whatever the caller's buffer already held. The return value is `1.0`,
the value of the final assignment.
"""
function g(du, u, p, t)
    for row in 1:3, col in 1:2
        du[row, col, row] = 1.0
    end
    return 1.0
end
# Minimal reproducer: the state u0 is a 3×2 matrix and the noise-rate prototype is a
# 3×2×3 array (third axis presumably the noise channel). `solve` throws a BoundsError
# while initialising the integrator (StochasticDiffEq `__init`).
# NOTE: per the maintainer reply, non-diagonal noise is currently defined as
# g(u,p,t)*dW with a vector dW, so it expects a vector state and a matrix prototype.
prob = SDEProblem(f,g,ones(3,2),(0.0,1.0),noise_rate_prototype=zeros(3,2,3))

sol = solve(prob)

with the stack trace

BoundsError: attempt to access 3×2×3 Array{Float64, 3} at index [1, 1:2]

Stacktrace:
  [1] throw_boundserror(A::Array{Float64, 3}, I::Tuple{Int64, Base.Slice{Base.OneTo{Int64}}})
    @ Base ./abstractarray.jl:691
  [2] checkbounds
    @ ./abstractarray.jl:656 [inlined]
  [3] _getindex
    @ ./multidimensional.jl:838 [inlined]
  [4] getindex
    @ ./abstractarray.jl:1218 [inlined]
  [5] __init(_prob::SDEProblem{Matrix{Float64}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, Nothing, SDEFunction{true, typeof(f), typeof(g), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, typeof(g), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MyArray{Float64, 3, Array{Float64, 3}}}, alg::LambaEM{true}, timeseries_init::Vector{Any}, ts_init::Vector{Any}, ks_init::Type, recompile::Type{Val{true}}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_noise::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmin::Rational{Int64}, qmax::Rational{Int64}, qsteady_min::Int64, qsteady_max::Int64, beta2::Nothing, beta1::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, delta::Rational{Int64}, maxiters::Int64, dtmax::Float64, dtmin::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, force_dtmin::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, initialize_integrator::Bool, seed::UInt64, alias_u0::Bool, alias_jumps::Bool, kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol, Symbol}, NamedTuple{(:default_set, :second_time), Tuple{Bool, Bool}}})
    @ StochasticDiffEq ~/.julia/packages/StochasticDiffEq/KVp0Z/src/solve.jl:296
  [6] #__solve#110
    @ ~/.julia/packages/StochasticDiffEq/KVp0Z/src/solve.jl:6 [inlined]
  [7] __solve(::SDEProblem{Matrix{Float64}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, Nothing, SDEFunction{true, typeof(f), typeof(g), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, typeof(g), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MyArray{Float64, 3, Array{Float64, 3}}}, ::Nothing; default_set::Bool, kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:second_time,), Tuple{Bool}}})
    @ DifferentialEquations ~/.julia/packages/DifferentialEquations/4jfQK/src/default_solve.jl:8
  [8] #__solve#59
    @ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:397 [inlined]
  [9] __solve
    @ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:384 [inlined]
 [10] #solve_call#39
    @ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:152 [inlined]
 [11] solve_call
    @ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:139 [inlined]
 [12] #solve_up#41
    @ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:182 [inlined]
 [13] solve_up
    @ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:170 [inlined]
 [14] #solve#40
    @ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:165 [inlined]
 [15] solve(::SDEProblem{Matrix{Float64}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, Nothing, SDEFunction{true, typeof(f), typeof(g), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, typeof(g), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MyArray{Float64, 3, Array{Float64, 3}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:159
 [16] top-level scope
    @ In[27]:14
 [17] eval
    @ ./boot.jl:373 [inlined]
 [18] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1196

I should mention that it fails in the same way even when using a custom array type that properly defines `g * dW`.

Based on the documentation

noise_rate_prototype: A prototype type instance for the noise rates, that is the output g. It can be any type which overloads A_mul_B! with itself being the middle argument. Commonly, this is a matrix or sparse matrix. If this is not given, it defaults to nothing, which means the problem should be interpreted as having diagonal noise.

Since `A_mul_B!` is the pre-1.0 name of `LinearAlgebra.mul!`, I had hoped that overloading `mul!` as follows would work.

using LinearAlgebra

"""
    MyArray(inner::AbstractArray{T,N})

Thin wrapper around `inner` that forwards the `AbstractArray` interface to it. It is
used as a custom `noise_rate_prototype` so that `g * dW` is computed by the
`LinearAlgebra.mul!` method defined for it below.
"""
struct MyArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    inner::A
end

# Every interface method is restricted to `MyArray`. Defining them for untyped
# arguments or for `AbstractArray` would override Base behaviour for unrelated
# arrays (type piracy).
Base.size(A::MyArray) = size(A.inner)

Base.getindex(A::MyArray{T,N}, I::Vararg{Int,N}) where {T,N} = A.inner[I...]

function Base.setindex!(A::MyArray{T,N}, v, I::Vararg{Int,N}) where {T,N}
    A.inner[I...] = v
    return A
end

# Keep `similar` (and therefore `zero`/`copy`) inside the wrapper type, so arrays
# derived from a `MyArray` prototype still dispatch to the `mul!` method below.
# This matters if a solver allocates its noise-rate cache from the prototype
# (e.g. via `zero`); that is not verified here.
Base.similar(A::MyArray, ::Type{S}, dims::Dims) where {S} =
    MyArray(similar(A.inner, S, dims))

"""
    mul!(k_, c::MyArray, k::AbstractVector[, α, β])

Contract the third axis of `c` (size `n1×n2×n3`) with the vector `k` (length `n3`).
Store `α * (c ⋅ k) + β * k_` in `k_`, which must hold `n1*n2` elements (normally an
`n1×n2` array). Returns `k_`.

Throws `DimensionMismatch` if `k` or `k_` has an incompatible length.
"""
function LinearAlgebra.mul!(k_::AbstractArray, c::MyArray, k::AbstractVector,
                            α::Number, β::Number)
    Base.require_one_based_indexing(k_, k)
    n1, n2, n3 = size(c, 1), size(c, 2), size(c, 3)
    length(k) == n3 || throw(DimensionMismatch(
        "length(k) = $(length(k)) does not match size(c, 3) = $n3"))
    length(k_) == n1 * n2 || throw(DimensionMismatch(
        "length(k_) = $(length(k_)) does not match size(c, 1) * size(c, 2) = $(n1 * n2)"))
    # View c as an (n1*n2)×n3 matrix and k_ as a flat vector. `reshape` shares
    # memory, so the product is written straight into k_ with no temporaries.
    mul!(reshape(k_, n1 * n2), reshape(c.inner, n1 * n2, n3), k, α, β)
    return k_
end

LinearAlgebra.mul!(k_::AbstractArray, c::MyArray, k::AbstractVector) =
    mul!(k_, c, k, true, false)

# Same problem as before, but with the 3×2×3 noise-rate prototype wrapped in MyArray,
# so that g * dW could be routed through the custom `mul!` overload above.
# Reported to fail during initialisation as well.
prob = SDEProblem(f,g,ones(3,2),(0.0,1.0),noise_rate_prototype=MyArray(zeros(3,2,3)))

sol = solve(prob)

but it also fails during initialisation.

ChrisRackauckas commented 2 years ago

Xref https://discourse.julialang.org/t/non-diagonal-sde-with-matrix-variables/82351

The noise term is defined as `g(u,p,t)*dW`, where `dW` is always a vector, so `noise_rate_prototype` always needs to be a matrix. Because a matrix times a vector is a vector, the state `u` (and therefore the output of `f`) has to be a vector in any case of non-diagonal noise. I think we can generalize it, but that's how the linear algebra works out today.