SciML / StochasticDiffEq.jl

Solvers for stochastic differential equations which connect with the scientific machine learning (SciML) ecosystem
Other
248 stars 66 forks source link

Using StochasticDiffEq with custom types #576

Closed apkille closed 1 month ago

apkille commented 1 month ago

Describe the bug 🐞

An `SDEProblem` cannot be solved when its state is a custom array type (say `CustomArray`) that is not a subtype of `AbstractArray`, because StochasticDiffEq's `__init` method (`src/solve.jl:286`, see the stacktrace below) requires `Base.:(/)(x::CustomArray, y::CustomArray)` to be defined.

Minimal Reproducible Example 👇

I modified the test example https://github.com/SciML/StochasticDiffEq.jl/blob/master/test/noindex_tests.jl so that the custom type is not a subtype of `AbstractArray`, which requires defining a few more methods:

using StochasticDiffEq, OrdinaryDiffEq, LinearAlgebra, RecursiveArrayTools

"""
    CustomArray{T, N}

Array-like wrapper around an `Array{T, N}` that deliberately does *not* subtype
`AbstractArray`. The methods below forward to the wrapped array `x` to supply the
part of the array interface this reproducer provides for use with the DiffEq solvers.
"""
struct CustomArray{T, N}
    x::Array{T, N}
end

# Shape and element-type queries forward to the wrapped array
Base.size(x::CustomArray) = size(x.x)
Base.axes(x::CustomArray) = axes(x.x)
Base.ndims(x::CustomArray) = ndims(x.x)
Base.ndims(::Type{<:CustomArray{T, N}}) where {T, N} = N
Base.length(x::CustomArray) = length(x.x)
Base.isempty(x::CustomArray) = isempty(x.x)
Base.eltype(x::CustomArray) = eltype(x.x)

# `zero(::CustomArray)` is defined once only: a second identical definition
# overwrites the first, which Julia warns about (and recent versions reject
# during package precompilation).
Base.zero(x::CustomArray) = CustomArray(zero(x.x))
# NOTE(review): Base defines no `zero(::Type{<:Array})`, so calling this method
# appears to raise a MethodError; kept unchanged for compatibility — confirm intent.
Base.zero(::Type{<:CustomArray{T, N}}) where {T, N} = CustomArray(zero(Array{T, N}))

# Allocation and copying
function Base.similar(x::CustomArray, dims::Union{Integer, AbstractUnitRange}...)
    return CustomArray(similar(x.x, dims...))
end
Base.copy(x::CustomArray) = CustomArray(copy(x.x))

# Mutating methods write into the wrapped array and return the destination
# itself, following Base's convention for `copyto!`, `fill!` and `setindex!`.
function Base.copyto!(x::CustomArray, y::CustomArray)
    copyto!(x.x, y.x)
    return x
end
function Base.fill!(x::CustomArray, y)
    fill!(x.x, y)
    return x
end

# Indexing forwards any number of (linear or Cartesian) indices
Base.getindex(x::CustomArray, i...) = getindex(x.x, i...)
function Base.setindex!(x::CustomArray, v, idx...)
    setindex!(x.x, v, idx...)
    return x
end

# Reductions and predicates over the elements
Base.mapreduce(f, op, x::CustomArray; kwargs...) = mapreduce(f, op, x.x; kwargs...)
Base.any(f::Function, x::CustomArray; kwargs...) = any(f, x.x; kwargs...)
Base.all(f::Function, x::CustomArray; kwargs...) = all(f, x.x; kwargs...)

# Equality compares the wrapped arrays; `isequal` and `hash` are defined to match
# so that equal CustomArrays also hash equally (e.g. as Dict/Set keys).
Base.:(==)(x::CustomArray, y::CustomArray) = x.x == y.x
Base.isequal(x::CustomArray, y::CustomArray) = isequal(x.x, y.x)
Base.hash(x::CustomArray, h::UInt) = hash(x.x, hash(CustomArray, h))

# Scalar arithmetic and norm
Base.:(*)(x::Number, y::CustomArray) = CustomArray(x * y.x)
Base.:(/)(x::CustomArray, y::Number) = CustomArray(x.x / y)
LinearAlgebra.norm(x::CustomArray) = norm(x.x)

# Broadcast style tag carrying the dimensionality `N`.
struct CustomStyle{N} <: Broadcast.BroadcastStyle end
CustomStyle(::Val{N}) where {N} = CustomStyle{N}()
# Rebuild the style for a new dimensionality `N`, mirroring the `Val{N}`
# constructor that `Broadcast.AbstractArrayStyle` subtypes provide. It must return
# a `CustomStyle`: no other style type is defined in this code.
CustomStyle{M}(::Val{N}) where {N, M} = CustomStyle{N}()
# A CustomArray of dimension N broadcasts with style CustomStyle{N}; combining
# it with a 0-dimensional (scalar) operand keeps that style.
function Broadcast.BroadcastStyle(::Type{<:CustomArray{T, N}}) where {T, N}
    return CustomStyle{N}()
end
function Broadcast.BroadcastStyle(
    ::CustomStyle{N}, ::Broadcast.DefaultArrayStyle{0},
) where {N}
    return CustomStyle{N}()
end

# Destination for an allocating broadcast: a CustomArray wrapping a freshly
# allocated Array with element type ElType and the broadcast's axes.
function Base.similar(
    bc::Base.Broadcast.Broadcasted{CustomStyle{N}}, ::Type{ElType},
) where {N, ElType}
    storage = similar(Array{ElType, N}, axes(bc))
    return CustomArray(storage)
end

# Inside broadcast expressions a CustomArray is used as-is (no extrusion or
# wrapping), and element access reads straight from the wrapped array.
function Base.Broadcast._broadcast_getindex(x::CustomArray, i)
    return x.x[i]
end
function Base.Broadcast.extrude(x::CustomArray)
    return x
end
function Base.Broadcast.broadcastable(x::CustomArray)
    return x
end

# In-place broadcast into a CustomArray (`dest .= ...`): evaluate `bc` elementwise
# with linear indices and store the results into the wrapped array of `dest`.
# Throws `DimensionMismatch` when the axes of `dest` and `bc` differ.
@inline function Base.copyto!(dest::CustomArray, bc::Base.Broadcast.Broadcasted{<:CustomStyle})
    # `throwdm` is internal to Base.Broadcast and not visible at this scope
    # (using it here raised an UndefVarError), so throw the error explicitly.
    if axes(dest) != axes(bc)
        throw(DimensionMismatch(
            "destination axes $(axes(dest)) are not compatible with source axes $(axes(bc))",
        ))
    end
    # Base's argument-preparation step before indexing into the broadcast
    bc′ = Base.Broadcast.preprocess(dest, bc)
    dest′ = dest.x
    @simd for I in eachindex(dest′)
        @inbounds dest′[I] = bc′[I]
    end
    return dest
end
# Allocating broadcast (`f.(x, ...)`) producing a new CustomArray. The shape is
# taken from the wrapped Array of the first CustomArray argument (via `find_x`).
@inline function Base.copy(bc::Base.Broadcast.Broadcasted{<:CustomStyle})
    bcf = Broadcast.flatten(bc)
    x = find_x(bcf)
    # The element type must come from the broadcast itself (as Base's own `copy`
    # does), not from `x`. Otherwise e.g. `0.5 .* CustomArray([1, 2])` tries to store
    # 0.5 into an Int array and throws an InexactError.
    ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
    data = similar(x, ElType)  # every entry is overwritten below
    @inbounds @simd for I in eachindex(data)
        data[I] = bcf[I]
    end
    return CustomArray(data)
end
# Return the wrapped `Array` of the first `CustomArray` among a broadcast's
# arguments, scanning the argument tuple left to right.
find_x(bc::Broadcast.Broadcasted) = find_x(bc.args)
find_x(args::Tuple) = find_x(find_x(args[1]), Base.tail(args))
# The argument list ran out without a CustomArray. Report that clearly instead of
# failing with a BoundsError on `()[1]`.
find_x(::Tuple{}) = throw(ArgumentError("no CustomArray found among the broadcast arguments"))
find_x(x) = x
# Two-argument form: (candidate, remaining args) — skip non-CustomArray candidates
find_x(::Any, rest) = find_x(rest)
find_x(x::CustomArray, rest) = x.x

# RecursiveArrayTools hooks: a CustomArray is handled as a flat container of its
# elements, so each recursive operation reduces to its plain Base counterpart.
function RecursiveArrayTools.recursive_unitless_bottom_eltype(x::CustomArray)
    return eltype(x)
end
function RecursiveArrayTools.recursivecopy!(dest::CustomArray, src::CustomArray)
    return copyto!(dest, src)
end
function RecursiveArrayTools.recursivecopy(x::CustomArray)
    return copy(x)
end
function RecursiveArrayTools.recursivefill!(x::CustomArray, a)
    return fill!(x, a)
end

# Display. Vector printing is delegated to the wrapped array. The compact `show`
# prints a "CustomArray" prefix followed by the wrapped array. The text/plain form
# prints a summary header line, then the array contents.
Base.show_vector(io::IO, x::CustomArray) = Base.show_vector(io, x.x)

function Base.show(io::IO, x::CustomArray)
    print(io, "CustomArray")
    return show(io, x.x)
end

function Base.show(io::IO, ::MIME"text/plain", x::CustomArray)
    header = Base.summary(x)
    println(io, header, ":")
    return Base.print_array(io, x.x)
end

With these definitions, an `ODEProblem` with a `CustomArray` state can be solved, but an `SDEProblem` cannot:

# Reproducer: solve with `EM()` an SDE whose state is a CustomArray; both the
# drift and the diffusion function copy `u` into `du`.
ca0 = CustomArray(ones(10))
prob = SDEProblem(
    (du, u, p, t) -> copyto!(du, u),  # drift
    (du, u, p, t) -> copyto!(du, u),  # diffusion
    ca0,
    (0.0, 1.0),
)
sol = solve(prob, EM(); dt = 1 // 2^4)

Error & Stacktrace ⚠️

ERROR: MethodError: no method matching /(::CustomArray{Float64, 1}, ::CustomArray{Float64, 1})

Closest candidates are:
  /(::ChainRulesCore.NotImplemented, ::Any)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/I1EbV/src/tangent_types/notimplemented.jl:42
  /(::Any, ::ChainRulesCore.NotImplemented)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/I1EbV/src/tangent_types/notimplemented.jl:43
  /(::ChainRulesCore.AbstractZero, ::Any)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/I1EbV/src/tangent_types/abstract_zero.jl:31
  ...

Stacktrace:
 [1] __init(_prob::SDEProblem{…}, alg::EM{…}, timeseries_init::Vector{…}, ts_init::Vector{…}, ks_init::Type, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_noise::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Rational{…}, adaptive::Bool, gamma::Int64, abstol::Nothing, reltol::Nothing, qmin::Int64, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta2::Nothing, beta1::Nothing, qoldinit::Int64, controller::Nothing, fullnormalize::Bool, failfactor::Int64, delta::Rational{…}, maxiters::Int64, dtmax::Float64, dtmin::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, force_dtmin::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, initialize_integrator::Bool, seed::UInt64, alias_u0::Bool, alias_jumps::Bool, kwargs::@Kwargs{})
   @ StochasticDiffEq ~/.julia/packages/StochasticDiffEq/M3bKo/src/solve.jl:286
 [2] __solve(prob::SDEProblem{…}, alg::EM{…}, timeseries::Vector{…}, ts::Vector{…}, ks::Nothing, recompile::Type{…}; kwargs::@Kwargs{…})
   @ StochasticDiffEq ~/.julia/packages/StochasticDiffEq/M3bKo/src/solve.jl:6
 [3] solve_call(_prob::SDEProblem{…}, args::EM{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:612
 [4] solve_up(prob::SDEProblem{…}, sensealg::Nothing, u0::CustomArray{…}, p::SciMLBase.NullParameters, args::EM{…}; kwargs::@Kwargs{…})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1080
 [5] solve(prob::SDEProblem{…}, args::EM{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003
 [6] top-level scope
   @ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.

Environment (please complete the following information):

Status `~/Documents/Julia Packages/QuantumJulia/Project.toml`
  [4c88cf16] Aqua v0.8.7
  [4fba245c] ArrayInterface v7.14.0
  [6e4b80f9] BenchmarkTools v1.5.0
  [0c46a032] DifferentialEquations v7.13.0
  [ffbed154] DocStringExtensions v0.9.3
  [e30172f5] Documenter v1.5.0
  [daee34ce] DocumenterCitations v1.3.3
  [7a1cc6ca] FFTW v1.8.0
  [7034ab61] FastBroadcast v0.3.5
  [1a297f60] FillArrays v1.11.0
  [f6369f11] ForwardDiff v0.10.36
⌃ [e9467ef8] GLMakie v0.9.11
  [c3a54625] JET v0.9.7
  [8ac3fa9e] LRUCache v1.6.1
  [23fbe1c1] Latexify v0.16.4
  [16fef848] LiveServer v1.3.1
  [1914dd2f] MacroTools v0.5.13
  [f9640e96] MultiScaleArrays v1.12.0
  [1dea7af3] OrdinaryDiffEq v6.87.0
  [e4faabce] PProf v3.1.0
  [32113eaa] PkgBenchmark v0.2.12
  [d330b81b] PyPlot v2.11.5
  [0525e862] QuantumClifford v0.9.7
  [5717a53b] QuantumInterface v0.3.4 `QuantumInterface.jl`
  [6e0679c1] QuantumOptics v1.1.1 `QuantumOptics.jl`
  [4f57444f] QuantumOpticsBase v0.5.1 `QuantumOpticsBase.jl`
  [efa7fd63] QuantumSymbolics v0.3.4 `QuantumSymbolics.jl`
  [2576dda1] RandomMatrices v0.5.5
  [731186ca] RecursiveArrayTools v3.26.0
  [295af30f] Revise v3.5.17
  [1bc83da4] SafeTestsets v0.1.0
  [2913bbd2] StatsBase v0.34.3
  [789caeaf] StochasticDiffEq v6.67.0
  [5e0ebb24] Strided v2.1.0
  [4db3bf67] StridedViews v0.3.1
⌅ [d1185830] SymbolicUtils v2.1.2
  [0c5d862f] Symbolics v5.34.0
  [ade2ca70] Dates
  [37e2e46d] LinearAlgebra
  [9abbd945] Profile
  [2f01184e] SparseArrays v1.10.0
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`

Additional context

For context, I am defining a broadcast interface for QuantumOptics.jl types (which wrap arrays together with basis information) so that they integrate with SciML.