SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/
Other
329 stars 71 forks source link

Reverse-mode AD for SDEs #1084

Open elgazzarr opened 1 month ago

elgazzarr commented 1 month ago

Reverse-mode adjoints for SDEs only work with `TrackerAdjoint()`, and only on CPU. 🐞

Training large (e.g., neural) SDEs on GPUs fails. The only working option is `TrackerAdjoint()`, and it currently works only on CPU. None of the continuous adjoint methods, e.g. `InterpolatingAdjoint()` or `BacksolveAdjoint()`, work on either CPU or GPU.

MWE


using DifferentialEquations, Lux, ComponentArrays, Random, SciMLSensitivity, Zygote, BenchmarkTools, LuxCUDA, CUDA,
OptimizationOptimisers

# Global setup for the reproducer: target device, adjoint choice, synthetic
# data, save times and the two Lux sub-networks. Statement order matters:
# `data` and `x₀` are drawn from the global RNG in this order.
dev = gpu_device()
# Adjoint used when `solve` is differentiated. Per this report it is the only
# one that works, and only when `dev` is the CPU.
sensealg = TrackerAdjoint()  # works only on CPU

# Synthetic targets of size (32, 100, 512) and a (32, 512) initial state.
# NOTE(review): presumably (state, time, batch) to match the permuted output
# of the NeuralSDE forward pass — confirm.
data = rand32(32,100,512) |> dev
x₀ = rand32(32,512) |> dev
# 100 save times on [0, 1]; must equal the second dimension of `data`.
ts = range(0.0f0, 1.0f0, length=100)
# Drift network: dense 32 -> 32 layer with tanh activation.
drift = Dense(32, 32, tanh)
# Diffusion network: elementwise `Scale` over 32 dims with sigmoid activation
# (presumably a diagonal-noise diffusion — confirm).
diffusion = Scale(32, sigmoid)

"""
    basic_tgrad(u, p, t)

Time-gradient callback for the SDE drift. The drift used here never reads `t`,
so its partial derivative with respect to time is identically zero; return an
all-zero array with the same type and shape as `u`.
"""
function basic_tgrad(u, p, t)
    return zero(u)
end

"""
    NeuralSDE(drift, diffusion, solver, tspan, sensealg)

Lux container layer for a neural SDE whose drift and diffusion terms are Lux
layers.

The container's layer tuple names `drift` and `diffusion`, so parameters and
states are split into `p.drift`/`p.diffusion` and `st.drift`/`st.diffusion`.
The remaining fields are plain solver configuration forwarded to `solve`:

- `solver`: SDE algorithm (e.g. `EM()`).
- `tspan`: integration interval `(t0, t1)`.
- `sensealg`: sensitivity algorithm used when the solve is differentiated.

Every field is a type parameter. Untyped fields would be stored as `Any` and
make each `model.solver`/`model.tspan`/`model.sensealg` access in the forward
pass type-unstable. The default positional constructor infers all parameters,
so existing `NeuralSDE(drift, diffusion, solver, tspan, sensealg)` calls are
unchanged.
"""
struct NeuralSDE{D, F, S, T, A} <: Lux.AbstractExplicitContainerLayer{(:drift, :diffusion)}
    drift::D
    diffusion::F
    solver::S
    tspan::T
    sensealg::A
end

"""
    (model::NeuralSDE)(x₀, ts, p, st)

Solve the neural SDE from initial state `x₀` over `model.tspan` and return the
trajectory saved at `ts`.

The drift and diffusion networks are called with their own parameter and state
sub-trees (`p.drift`/`st.drift`, `p.diffusion`/`st.diffusion`). Only the first
element of each Lux `(output, state)` return is used.

The saved states are concatenated along a new third axis and then permuted with
`(1, 3, 2)`. For a matrix `x₀` of size `(n, batch)` the result therefore has
size `(n, number_of_saves, batch)`.
"""
function (model::NeuralSDE)(x₀, ts, p, st)
    # Out-of-place drift/diffusion. `ps` is the parameter object the solver
    # passes back in; each network reads only its own sub-tree of it.
    drift_fn(u, ps, t) = first(model.drift(u, ps.drift, st.drift))
    diffusion_fn(u, ps, t) = first(model.diffusion(u, ps.diffusion, st.diffusion))

    sde_fn = SDEFunction{false}(drift_fn, diffusion_fn; tgrad = basic_tgrad)
    problem = SDEProblem{false}(sde_fn, x₀, model.tspan, p)

    # `dt` is the step size handed to the solver. `sensealg` selects the adjoint
    # used when this solve is differentiated.
    solution = solve(problem, model.solver;
                     saveat = ts, dt = 0.01f0, sensealg = model.sensealg)

    # NOTE(review): splatting every saved state into `cat` works for ~100 saves
    # but scales poorly. Consider `stack` once its GPU/AD support is confirmed.
    stacked = cat(solution.u...; dims = 3)
    return permutedims(stacked, (1, 3, 2))
end

"""
    loss!(p, data)

Sum of squared differences between `data` and the model prediction made with
parameters `p`.

Reads the globals `model`, `x₀`, `ts` and `st`, and mutates nothing despite the
`!` in its name. Returns the tuple `(loss, st, prediction)`.

NOTE(review): presumably only the first element is used as the objective by
Optimization.jl — confirm for the installed version.
"""
function loss!(p, data)
    prediction = model(x₀, ts, p, st)
    residual = data .- prediction
    squared_error = sum(abs2, residual)
    return squared_error, st, prediction
end

# Build the model, initialise its Lux parameters/states, and move the
# parameters to `dev` as a flat Float32 ComponentArray.
# NOTE(review): `default_rng()` is not seeded here, so initial weights differ
# between runs.
rng = Random.default_rng()
model = NeuralSDE(drift, diffusion, EM(), (0.0f0, 1.0f0), sensealg)
p, st = Lux.setup(rng, model)
p = p |> ComponentArray{Float32} |> dev

# Optimise with Zygote-computed gradients. The objective's second argument
# (optimization hyperparameters) is ignored.
adtype = AutoZygote()
optf = OptimizationFunction((p, _ ) -> loss!(p, data), adtype)
optproblem = OptimizationProblem(optf, p)
# Reported failure point on GPU: during the Tracker backward pass, a gradient
# for a CuArray parameter arrives as a CPU `Matrix{Float32}`. Broadcasting the
# two together then raises the non-isbits KernelError shown below.
result = Optimization.solve(optproblem, ADAMW(5e-4), maxiters=10)

Error & Stacktrace

ERROR: LoadError: GPU compilation of MethodInstance for (::GPUArrays.var"#35#37")(::CUDA.CuKernelContext, ::CuDeviceMatrix{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.
    .2 is of type Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}} which is not isbits.
      .x is of type Matrix{Float32} which is not isbits.

Stacktrace:
    [1] check_invocation(job::GPUCompiler.CompilerJob)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/validation.jl:92
    [2] macro expansion
      @ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:128 [inlined]
    [3] macro expansion
      @ ~/.julia/packages/TimerOutputs/Lw5SP/src/TimerOutput.jl:253 [inlined]
    [4] 
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:126
    [5] 
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:111
    [6] compile
      @ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:103 [inlined]
    [7] #1145
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:254 [inlined]
    [8] JuliaContext(f::CUDA.var"#1145#1148"{GPUCompiler.CompilerJob{…}}; kwargs::@Kwargs{})
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:52
    [9] JuliaContext(f::Function)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:42
   [10] compile(job::GPUCompiler.CompilerJob)
      @ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:253
   [11] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:237
   [12] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:151
   [13] macro expansion
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:369 [inlined]
   [14] macro expansion
      @ ./lock.jl:267 [inlined]
   [15] cufunction(f::GPUArrays.var"#35#37", tt::Type{Tuple{…}}; kwargs::@Kwargs{})
      @ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:364
   [16] cufunction
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:361 [inlined]
   [17] macro expansion
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:112 [inlined]
   [18] #launch_heuristic#1204
      @ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:17 [inlined]
   [19] launch_heuristic
      @ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:15 [inlined]
   [20] _copyto!
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:78 [inlined]
   [21] copyto!
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:44 [inlined]
   [22] copy
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:29 [inlined]
   [23] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{…}, Nothing, typeof(+), Tuple{…}})
      @ Base.Broadcast ./broadcast.jl:903
   [24] accum!(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/params.jl:46
   [25] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:134
   [26] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [27] #64
      @ ./tuple.jl:628 [inlined]
   [28] BottomRF
      @ ./reduce.jl:86 [inlined]
   [29] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
      @ Base ./reduce.jl:58
   [30] foldl_impl
      @ ./reduce.jl:48 [inlined]
   [31] mapfoldl_impl
      @ ./reduce.jl:44 [inlined]
   [32] mapfoldl
      @ ./reduce.jl:175 [inlined]
   [33] foldl
      @ ./reduce.jl:198 [inlined]
   [34] foreach
      @ ./tuple.jl:628 [inlined]
   [35] back_(g::Tracker.Grads, c::Tracker.Call{Tracker.var"#583#584"{…}, Tuple{…}}, Δ::Vector{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
   [36] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 1, CUDA.DeviceMemory}}, Δ::Vector{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
   [37] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [38] #64
      @ ./tuple.jl:628 [inlined]
   [39] BottomRF
      @ ./reduce.jl:86 [inlined]
   [40] _foldl_impl
      @ ./reduce.jl:58 [inlined]
   [41] foldl_impl
      @ ./reduce.jl:48 [inlined]
   [42] mapfoldl_impl(f::typeof(identity), op::Base.var"#64#65"{…}, nt::Nothing, itr::Base.Iterators.Zip{…})
      @ Base ./reduce.jl:44
   [43] mapfoldl(f::Function, op::Function, itr::Base.Iterators.Zip{Tuple{Tuple{…}, Tuple{…}}}; init::Nothing)
      @ Base ./reduce.jl:175
   [44] mapfoldl
      @ ./reduce.jl:175 [inlined]
   [45] foldl
      @ ./reduce.jl:198 [inlined]
   [46] foreach(::Function, ::Tuple{Tracker.Tracked{…}, Tracker.Tracked{…}}, ::Tuple{Vector{…}, Vector{…}})
      @ Base ./tuple.jl:628
   [47] back_(g::Tracker.Grads, c::Tracker.Call{Tracker.var"#552#555"{…}, Tuple{…}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
   [48] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
   [49] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [50] #64
      @ ./tuple.jl:628 [inlined]
   [51] BottomRF
      @ ./reduce.jl:86 [inlined]
   [52] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
      @ Base ./reduce.jl:58
--- the last 12 lines are repeated 98 more times ---
 [1229] foldl_impl
      @ ./reduce.jl:48 [inlined]
 [1230] mapfoldl_impl
      @ ./reduce.jl:44 [inlined]
 [1231] mapfoldl
      @ ./reduce.jl:175 [inlined]
 [1232] foldl
      @ ./reduce.jl:198 [inlined]
 [1233] foreach
      @ ./tuple.jl:628 [inlined]
 [1234] back_(g::Tracker.Grads, c::Tracker.Call{…}, Δ::RODESolution{…})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
 [1235] back(g::Tracker.Grads, x::Tracker.Tracked{…}, Δ::RODESolution{…})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
 [1236] #712
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:155 [inlined]
 [1237] #715
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:164 [inlined]
 [1238] (::SciMLSensitivity.var"#tracker_adjoint_backpass#368"{…})(ybar::RODESolution{…})
      @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4hOeN/src/concrete_solve.jl:1319
 [1239] ZBack
      @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [1240] (::Zygote.var"#kw_zpullback#53"{…})(dy::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [1241] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1242] (::Zygote.var"#2169#back#293"{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [1243] #solve#51
      @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
 [1244] (::Zygote.Pullback{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1245] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1246] (::Zygote.var"#2169#back#293"{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [1247] solve
      @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
 [1248] (::Zygote.Pullback{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1249] NeuralSDE
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:31 [inlined]
 [1250] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CuArray{Float32, 3, CUDA.DeviceMemory})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1251] loss!
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:36 [inlined]
 [1252] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float32, Nothing, Nothing})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1253] #39
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:48 [inlined]
 [1254] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1255] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1256] OptimizationFunction
      @ ~/.julia/packages/SciMLBase/rR75x/src/scimlfunctions.jl:3763 [inlined]
 [1257] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1258] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1259] #37
      @ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:94 [inlined]
 [1260] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1261] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1262] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1263] #39
      @ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97 [inlined]
 [1264] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1265] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [1266] gradient(f::Function, args::ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{…}}})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [1267] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
      @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97
 [1268] macro expansion
      @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [1269] macro expansion
      @ ~/.julia/packages/Optimization/fPKIF/src/utils.jl:32 [inlined]
 [1270] __solve(cache::OptimizationCache{…})
      @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [1271] solve!(cache::OptimizationCache{…})
      @ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:188
 [1272] solve(::OptimizationProblem{…}, ::OptimiserChain{…}; kwargs::@Kwargs{…})
      @ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:96
in expression starting at /home/artiintel/ahmelg/code/NeuroDynamics.jl/examples/mwe.jl:50
Some type information was truncated. Use `show(err)` to see complete types.

I am using the latest releases for the packages and Julia 1.10.4.

ChrisRackauckas commented 1 week ago

I think this is related to the ComponentArrays thing we just found. @avik-pal is looking into it.