SciML / StochasticDiffEq.jl

Solvers for stochastic differential equations which connect with the scientific machine learning (SciML) ecosystem

Using CuSparseMatrixCSC for `noise_rate_prototype` fails #557

Open jack-dunham opened 7 months ago

jack-dunham commented 7 months ago

Describe the bug 🐞

I am unsure whether this is a bug or simply a lack of support, but calling `solve` with a `noise_rate_prototype` of type `CuSparseMatrixCSC` appears to fail while constructing the prototype for the noise increment.

Expected behavior

Matrices from `CUDA.CUSPARSE` can be used as the `noise_rate_prototype`.

Minimal Reproducible Example πŸ‘‡

using StochasticDiffEq
using SparseArrays
using CUDA

function main(u0, W)
    prob = SDEProblem(f, g, u0, (0.0, 1.0), noise_rate_prototype = W)
    solve(prob, EM(); dt=0.1)
end

# Drift: du = u
function f(du, u, p, t)
    for (i, u_i) in enumerate(u)
        du[i] = u_i
    end
end

# Noise rates: write u into the first column of the N x N noise-rate matrix
function g(du, u, p, t)
    for (i, u_i) in enumerate(u)
        du[i, 1] = u_i
    end
end
N = 100

W = zeros(Float32, N, N)
W[:,1] .= ones(N)

W_d = CuArray(W)
W_sparse = sparse(W)
W_d_sparse = CUDA.CUSPARSE.CuSparseMatrixCSC(W_sparse)

u0 = fill(1f0,N)
u0_d = CuArray(u0)

main(u0,W) # works
main(u0,W_sparse) # works
main(u0_d, W_d) # works
main(u0_d, W_d_sparse) # fails!

Error & Stacktrace ⚠️

broadcast with sparse arrays is currently only implemented for CSR and CSC matrices

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] copy(bc::Base.Broadcast.Broadcasted{CUDA.CUSPARSE.CuSparseVecStyle, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Bool, CUDA.CUSPARSE.CuSparseVector{Float32, Int32}}})
    @ CUDA.CUSPARSE ~/.julia/packages/CUDA/YIj5X/lib/cusparse/broadcast.jl:473
  [3] materialize
    @ ./broadcast.jl:873 [inlined]
  [4] __init(_prob::SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, Nothing, SDEFunction{true, SciMLBase.FullSpecialize, typeof(f), typeof(g), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(g), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32}}, alg::EM{true}, timeseries_init::Vector{Any}, ts_init::Vector{Any}, ks_init::Type, recompile::Type{Val{true}}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_noise::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, adaptive::Bool, gamma::Int64, abstol::Nothing, reltol::Nothing, qmin::Int64, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta2::Nothing, beta1::Nothing, qoldinit::Int64, controller::Nothing, fullnormalize::Bool, failfactor::Int64, delta::Rational{Int64}, maxiters::Int64, dtmax::Float64, dtmin::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, force_dtmin::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tst...
    @ StochasticDiffEq ~/.julia/packages/StochasticDiffEq/WPfDF/src/solve.jl:302
  [5] __solve(prob::SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, Nothing, SDEFunction{true, SciMLBase.FullSpecialize, typeof(f), typeof(g), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(g), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32}}, alg::EM{true}, timeseries::Vector{Any}, ts::Vector{Any}, ks::Nothing, recompile::Type{Val{true}}; kwargs::Base.Pairs{Symbol, Float64, Tuple{Symbol}, NamedTuple{(:dt,), Tuple{Float64}}})
    @ StochasticDiffEq ~/.julia/packages/StochasticDiffEq/WPfDF/src/solve.jl:6
  [6] solve_call(_prob::SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, Nothing, SDEFunction{true, SciMLBase.FullSpecialize, typeof(f), typeof(g), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(g), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32}}, args::EM{true}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::Base.Pairs{Symbol, Float64, Tuple{Symbol}, NamedTuple{(:dt,), Tuple{Float64}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/SlYdg/src/solve.jl:561
  [7] solve_up(prob::SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, Nothing, SDEFunction{true, SciMLBase.FullSpecialize, typeof(f), typeof(g), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(g), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32}}, sensealg::Nothing, u0::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, p::SciMLBase.NullParameters, args::EM{true}; kwargs::Base.Pairs{Symbol, Float64, Tuple{Symbol}, NamedTuple{(:dt,), Tuple{Float64}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/SlYdg/src/solve.jl:1010
  [8] solve(prob::SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, Nothing, SDEFunction{true, SciMLBase.FullSpecialize, typeof(f), typeof(g), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(g), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32}}, args::EM{true}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{true}, kwargs::Base.Pairs{Symbol, Float64, Tuple{Symbol}, NamedTuple{(:dt,), Tuple{Float64}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/SlYdg/src/solve.jl:933
  [9] main(u0::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, W::CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32})
    @ Main ./In[40]:7
 [10] top-level scope
    @ In[40]:36

Environment (please complete the following information):

Status `~/.julia/environments/v1.9/Project.toml`
  [6e4b80f9] BenchmarkTools v1.3.2
  [052768ef] CUDA v5.1.1
  [7073ff75] IJulia v1.24.2
  [789caeaf] StochasticDiffEq v6.63.2
Status `~/.julia/environments/v1.9/Manifest.toml`
  [47edcb42] ADTypes v0.2.5
  [621f4979] AbstractFFTs v1.5.0
  [7d9f7c33] Accessors v0.1.33
  [79e6a3ab] Adapt v3.7.1
  [ec485272] ArnoldiMethod v0.2.0
  [4fba245c] ArrayInterface v7.6.1
  [a9b6321e] Atomix v0.1.0
  [ab4f0b2a] BFloat16s v0.4.2
  [6e4b80f9] BenchmarkTools v1.3.2
  [62783981] BitTwiddlingConvenienceFunctions v0.1.5
βŒ… [fa961155] CEnum v0.4.2
  [2a0fbf3d] CPUSummary v0.2.4
  [052768ef] CUDA v5.1.1
  [1af6417a] CUDA_Runtime_Discovery v0.2.2
  [49dc2e85] Calculus v0.5.1
  [fb6a15b2] CloseOpenIntervals v0.1.12
  [3da002f7] ColorTypes v0.11.4
  [5ae59095] Colors v0.12.10
  [38540f10] CommonSolve v0.2.4
  [bbf7d656] CommonSubexpressions v0.3.0
  [34da2185] Compat v4.10.1
  [a33af91c] CompositionsBase v0.1.2
  [2569d6c7] ConcreteStructs v0.2.3
  [8f4d0f93] Conda v1.10.0
  [187b0558] ConstructionBase v1.5.4
  [adafc99b] CpuId v0.3.1
  [a8cc5b0e] Crayons v4.1.1
  [9a962f9c] DataAPI v1.15.0
  [a93c6f00] DataFrames v1.6.1
  [864edb3b] DataStructures v0.18.15
  [e2d170a0] DataValueInterfaces v1.0.0
  [8bb1440f] DelimitedFiles v1.9.1
  [2b5f629d] DiffEqBase v6.141.0
  [77a26b50] DiffEqNoiseProcess v5.19.0
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
  [b4f34e82] Distances v0.10.11
  [31c24e10] Distributions v0.25.103
  [ffbed154] DocStringExtensions v0.9.3
  [fa6b7ba4] DualNumbers v0.6.8
  [4e289a0a] EnumX v1.0.4
  [f151be2c] EnzymeCore v0.6.4
  [d4d017d3] ExponentialUtilities v1.25.0
  [e2ba6199] ExprTools v0.1.10
  [7034ab61] FastBroadcast v0.2.8
  [9aa1b823] FastClosures v0.3.2
  [29a986be] FastLapackInterface v2.0.0
  [1a297f60] FillArrays v1.8.0
  [6a86dc24] FiniteDiff v2.21.1
  [53c48c17] FixedPointNumbers v0.8.4
  [f6369f11] ForwardDiff v0.10.36
  [069b7b12] FunctionWrappers v1.1.3
  [77dc65aa] FunctionWrappersWrappers v0.1.3
  [0c68f7d7] GPUArrays v9.1.0
  [46192b85] GPUArraysCore v0.1.5
  [61eb1bfa] GPUCompiler v0.25.0
  [c145ed77] GenericSchur v0.5.3
  [86223c79] Graphs v1.9.0
  [3e5b6fbb] HostCPUFeatures v0.1.16
  [34004b35] HypergeometricFunctions v0.3.23
  [7073ff75] IJulia v1.24.2
  [615f187c] IfElse v0.1.1
  [d25df0c9] Inflate v0.1.4
  [842dd82b] InlineStrings v1.4.0
  [18e54dd8] IntegerMathUtils v0.1.2
  [3587e190] InverseFunctions v0.1.12
  [41ab1584] InvertedIndices v1.3.0
  [92d709cd] IrrationalConstants v0.2.2
  [82899510] IteratorInterfaceExtensions v1.0.0
  [692b3bcd] JLLWrappers v1.5.0
  [682c06a0] JSON v0.21.4
  [ccbc3e58] JumpProcesses v9.8.0
  [ef3ab10e] KLU v0.4.1
  [63c18a36] KernelAbstractions v0.9.13
  [ba0b0d4f] Krylov v0.9.4
  [929cbde3] LLVM v6.4.0
  [8b046642] LLVMLoopInfo v1.0.0
  [b964fa9f] LaTeXStrings v1.3.1
  [73f95e8e] LatticeRules v0.0.1
  [10f19ff3] LayoutPointers v0.1.15
  [50d2b5c4] Lazy v0.15.1
  [2d8b4e74] LevyArea v1.0.0
  [d3d80556] LineSearches v7.2.0
  [7ed4a6bd] LinearSolve v2.20.0
  [2ab3a3ac] LogExpFunctions v0.3.26
  [bdcacae8] LoopVectorization v0.12.166
  [1914dd2f] MacroTools v0.5.11
  [d125e4d3] ManualMemory v0.1.8
  [739be429] MbedTLS v1.1.9
  [e1d29d7a] Missings v1.1.0
  [46d2c3a1] MuladdMacro v0.2.4
  [d41bc354] NLSolversBase v7.8.3
  [2774e3e8] NLsolve v4.5.1
  [5da4648a] NVTX v0.3.3
  [77ba4419] NaNMath v1.0.2
  [8913a72c] NonlinearSolve v2.8.2
  [6fe1bfb0] OffsetArrays v1.12.10
  [429524aa] Optim v1.7.8
  [bac558e1] OrderedCollections v1.6.3
  [1dea7af3] OrdinaryDiffEq v6.59.3
  [90014a1f] PDMats v0.11.30
  [65ce6f38] PackageExtensionCompat v1.0.2
  [d96e819e] Parameters v0.12.3
  [69de0a69] Parsers v2.8.0
  [e409e4f3] PoissonRandom v0.4.4
  [f517fe37] Polyester v0.7.9
  [1d0040c9] PolyesterWeave v0.2.1
  [2dfb63ee] PooledArrays v1.4.3
  [85a6dd25] PositiveFactorizations v0.2.4
  [d236fae5] PreallocationTools v0.4.12
  [aea7be01] PrecompileTools v1.2.0
  [21216c6a] Preferences v1.4.1
  [08abe8d2] PrettyTables v2.3.1
  [27ebfcd6] Primes v0.5.5
  [1fd47b50] QuadGK v2.9.1
  [8a4e6c94] QuasiMonteCarlo v0.3.3
  [74087812] Random123 v1.6.1
  [e6cf234a] RandomNumbers v1.5.3
  [3cdcf5f2] RecipesBase v1.3.4
  [731186ca] RecursiveArrayTools v2.38.10
  [f2c3362d] RecursiveFactorization v0.2.21
  [189a3867] Reexport v1.2.2
  [ae029012] Requires v1.3.0
  [ae5879a3] ResettableStacks v1.1.1
  [79098fc4] Rmath v0.7.1
  [7e49a35a] RuntimeGeneratedFunctions v0.5.12
  [94e857df] SIMDTypes v0.1.0
  [476501e8] SLEEFPirates v0.6.42
  [0bca4576] SciMLBase v2.9.1
  [e9a6253c] SciMLNLSolve v0.1.9
  [c0aeaf25] SciMLOperators v0.3.7
  [6c6a2e73] Scratch v1.2.1
  [91c51154] SentinelArrays v1.4.1
  [efcf1570] Setfield v1.1.1
  [727e6d20] SimpleNonlinearSolve v0.1.25
  [699a6c99] SimpleTraits v0.9.4
  [ce78b400] SimpleUnPack v1.1.0
  [ed01d8cd] Sobol v1.5.0
  [b85f4697] SoftGlobalScope v1.1.0
  [a2af1166] SortingAlgorithms v1.2.0
  [47a9eef4] SparseDiffTools v2.13.0
  [e56a9233] Sparspak v0.3.9
  [276daf66] SpecialFunctions v2.3.1
  [aedffcd0] Static v0.8.8
  [0d7ed370] StaticArrayInterface v1.4.1
  [90137ffa] StaticArrays v1.7.0
  [1e83bf80] StaticArraysCore v1.4.2
  [82ae8749] StatsAPI v1.7.0
  [2913bbd2] StatsBase v0.34.2
  [4c63d2b9] StatsFuns v1.3.0
  [789caeaf] StochasticDiffEq v6.63.2
  [7792a7ef] StrideArraysCore v0.5.2
  [892a3eda] StringManipulation v0.3.4
  [2efcf032] SymbolicIndexingInterface v0.2.2
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.11.1
  [8290d209] ThreadingUtilities v0.5.2
  [a759f4b9] TimerOutputs v0.5.23
  [a2a6695c] TreeViews v0.3.0
  [d5829a12] TriangularSolve v0.1.20
  [410a4b4d] Tricks v0.1.8
  [781d530d] TruncatedStacktraces v1.4.0
  [3a884ed6] UnPack v1.0.2
  [013be700] UnsafeAtomics v0.2.1
  [d80eeb9a] UnsafeAtomicsLLVM v0.1.3
  [3d5dd08c] VectorizationBase v0.21.65
  [81def892] VersionParsing v1.3.0
  [19fa3120] VertexSafeGraphs v0.2.0
  [c2297ded] ZMQ v1.2.2
  [4ee394cb] CUDA_Driver_jll v0.7.0+0
  [76a88914] CUDA_Runtime_jll v0.10.1+0
  [1d5cc7b8] IntelOpenMP_jll v2023.2.0+0
  [9c1d0b0a] JuliaNVTXCallbacks_jll v0.2.1+0
  [dad2f222] LLVMExtra_jll v0.0.27+1
  [856f044c] MKL_jll v2023.2.0+0
  [e98f9f5b] NVTX_jll v3.1.0+2
  [efe28fd5] OpenSpecFun_jll v0.5.5+0
  [f50d1b31] Rmath_jll v0.4.0+0
  [8f1865be] ZeroMQ_jll v4.3.4+0
  [a9144af2] libsodium_jll v1.0.20+0
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8ba89e20] Distributed
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [9fa8497b] Future
  [b77e0a4c] InteractiveUtils
  [4af54fe1] LazyArtifacts
  [b27032c2] LibCURL v0.6.4
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [a63ad114] Mmap
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.9.2
  [de0858da] Printf
  [9abbd945] Profile
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [1a1011a3] SharedArrays
  [6462fe0b] Sockets
  [2f01184e] SparseArrays
  [10745b16] Statistics v1.9.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v1.0.5+0
  [deac9b47] LibCURL_jll v8.4.0+0
  [29816b5a] LibSSH2_jll v1.11.0+1
  [c8ffd9c3] MbedTLS_jll v2.28.2+0
  [14a3606d] MozillaCACerts_jll v2022.10.11
  [4536629a] OpenBLAS_jll v0.3.21+4
  [05823500] OpenLibm_jll v0.8.1+0
  [bea87d4a] SuiteSparse_jll v5.10.1+6
  [83775a58] Zlib_jll v1.2.13+0
  [8e850b90] libblastrampoline_jll v5.8.0+0
  [8e850ede] nghttp2_jll v1.52.0+1
  [3f19e933] p7zip_jll v17.4.0+0
Info Packages marked with βŒ… have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated -m`
Julia Version 1.9.4
Commit 8e5136fa297 (2023-11-14 08:46 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 2 Γ— Intel(R) Xeon(R) CPU @ 2.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake-avx512)
  Threads: 3 on 2 virtual cores
Environment:
  LD_LIBRARY_PATH = /usr/lib64-nvidia
  JULIA_NUM_THREADS = 2

Additional context


CUDA runtime 12.3, artifact installation
CUDA driver 12.3
NVIDIA driver 525.105.17, originally for CUDA 12.0

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+525.105.17

Julia packages: 
- CUDA: 5.1.1
- CUDA_Driver_jll: 0.7.0+0
- CUDA_Runtime_jll: 0.10.1+0

Toolchain:
- Julia: 1.9.4
- LLVM: 14.0.6

1 device:
  0: Tesla T4 (sm_75, 14.596 GiB / 15.000 GiB available)
ChrisRackauckas commented 6 months ago

This needs more from the CUDA library. MWE:

using CUDA, SparseArrays

N = 100
A = zeros(Float32, N, N)
A[:, 1] .= ones(N)
A_d = CUDA.CUSPARSE.CuSparseMatrixCSC(sparse(A))
false .* A_d[1, :]  # errors: broadcast over a CuSparseVector is not implemented
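The slice type is what differs between the working dense-GPU case and the failing sparse-GPU case; a minimal sketch of that difference, relying only on the slicing behaviour shown in the stacktrace above:

using CUDA, SparseArrays

N = 100
A = zeros(Float32, N, N)
A[:, 1] .= ones(N)

A_dense  = CuArray(A)                                   # dense GPU matrix
A_sparse = CUDA.CUSPARSE.CuSparseMatrixCSC(sparse(A))   # sparse GPU matrix

typeof(false .* A_dense[1, :])  # CuVector{Float32, ...}: broadcasting works
typeof(A_sparse[1, :])          # CuSparseVector{Float32, Int32}: `false .* row` is the broadcast that errors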
ChrisRackauckas commented 6 months ago

Tracking in https://github.com/JuliaGPU/CUDA.jl/issues/2209

jack-dunham commented 5 months ago

Hi there, apologies for opening this and then ghosting.

In response to this comment in the tracked issue, https://github.com/JuliaGPU/CUDA.jl/issues/2209#issuecomment-1941992408: why are we necessarily trying to zero a vector of a similar type to the noise matrix? For example, although the noise-rate matrix may well be sparse, the vector of resulting noise increments may not be sparse. Would it not be more sensible to zero a vector of the same type as u0?
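For concreteness, a minimal CPU sketch of what I mean (assuming the noise term enters as a matrix-vector product; the GPU case would be analogous):

using SparseArrays

N  = 4
G  = sparse(rand(Float32, N, N))  # stands in for the evaluated noise-rate matrix
dW = randn(Float32, N)            # dense Wiener increment, same container type as u0
du_noise = G * dW                 # SparseMatrixCSC * Vector returns a dense Vector
typeof(du_noise)                  # Vector{Float32}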

I am almost certainly misunderstanding something here.

Thanks in advance!

ChrisRackauckas commented 5 months ago

That's what's happening internally in the package, and that is what's failing.

jack-dunham commented 5 months ago

Yeah, I don't understand why `rand_prototype` needs to be a vector of a similar type to `noise_rate_prototype`, because I don't understand what the purpose of `rand_prototype` is.

The offending code is this: https://github.com/SciML/StochasticDiffEq.jl/blob/870e062edd2582a4d79b6c0dabe64510c9860c07/src/solve.jl#L298-L303

Line 299 of the linked code will evaluate to false if `u isa CuVector`. I am picturing a solution like this:

if noise_rate_prototype isa CuSparseMatrix
    rand_prototype = CUDA.zeros(randElType, size(noise_rate_prototype, 2))
end

With the important caveat that I don't know what `rand_prototype` does. I see it is used in the construction of `WienerProcess` for certain algorithms and unused otherwise. I gave up trying to figure out what its purpose is within `WienerProcess`!

Cheers.

ChrisRackauckas commented 5 months ago

because I don't understand what the purpose of rand_prototype is.

`rand_prototype` is a prototype of what the random vector is supposed to be, i.e. its size, shape, and type. In theory this should always be dense, but we don't have a way of densifying it from the rows generically. You cannot use `CUDA.zeros` since the prototype needs to generically match the type of a row of `noise_rate_prototype`.
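To illustrate the "densifying from the rows" point with CPU types (a sketch of roughly the recipe the stacktrace shows, not the exact solve.jl code):

using SparseArrays

M = sparse(Float32[0 1; 0 0])
p = false .* M[1, :]  # zeros everything while keeping size and element type
typeof(p)             # SparseVector{Float32, Int64}: still sparse, not the dense vector we would want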

jack-dunham commented 5 months ago

since the prototype needs to generically match the type of a row of `noise_rate_prototype`.

I am curious why this is the case. Should it not match the type of `u` rather than the type of the object you get when you slice the noise matrix? Then, by passing a dense `u::V` and a sparse `noise_rate_prototype::M`, the integrator calls the method

*(mat::M, vec::V)

At least, this is how I assumed it would work from reading the docs.

If we zero a sparse CUDA array for the random vector, surely we then run into a similar problem of trying to add a `CuVector` to a `CuSparseVector`? Does this kernel exist in CUDA.jl?

ChrisRackauckas commented 5 months ago

It needs to match the length of the row, but the type only has to be compatible with computations involving `u`. If you are doing automatic differentiation, `u` would be dual while `rand_prototype` would want to stay non-dual.

As a "most cases" solution, `adapt(DiffEqBase.parameterless_type(u), zeros(randElType, size(noise_rate_prototype, 2)))` probably works for this case as well, so maybe that branch can just be removed.
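A self-contained sketch of that expression; everything other than `adapt` and `DiffEqBase.parameterless_type` is an illustrative stand-in for what `__init` would see:

using Adapt, CUDA, SparseArrays
import DiffEqBase

u = CUDA.ones(Float32, 100)   # dense GPU state, standing in for the solver's u
W = zeros(Float32, 100, 100); W[:, 1] .= 1
Wp = CUDA.CUSPARSE.CuSparseMatrixCSC(sparse(W))   # sparse noise_rate_prototype
randElType = Float32

rand_prototype = adapt(DiffEqBase.parameterless_type(u),
                       zeros(randElType, size(Wp, 2)))
typeof(rand_prototype)  # expected: a dense CuVector{Float32}, regardless of the sparse prototype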