A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
Reverse-mode adjoints for SDEs only work with `TrackerAdjoint()`, and only on the CPU. 🐞
Training large (e.g., neural) SDEs on GPUs fails. The only working sensitivity algorithm is `TrackerAdjoint()`, and it currently works only on the CPU.
None of the continuous adjoint methods, e.g. `InterpolatingAdjoint()` or `BacksolveAdjoint()`, work on either CPU or GPU.
I suspect the problem with the continuous methods is the shape of the noise during the backwards solve.
As for `TrackerAdjoint()` on GPUs, something is transferred to the CPU during the backwards pass; this also happens for ODEs, by the way.
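One way to make the silent device hop fail loudly (a diagnostic sketch, not part of the reproducer; it assumes a working CUDA setup and the definitions from the MWE below):

```julia
# Diagnostic sketch: with scalar indexing disallowed, any implicit
# device-to-host fallback on a CuArray during the backwards pass throws
# an error instead of quietly running on the CPU.
CUDA.allowscalar(false)

# Take a gradient through the model directly; if Tracker moves data to
# the CPU and then indexes into it scalar-wise, this call will error out
# at the offending operation, pointing to where the transfer happens.
Zygote.gradient(p -> sum(abs2, model(x₀, ts, p, st)), p)
```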
MWE

```julia
using DifferentialEquations, Lux, ComponentArrays, Random, SciMLSensitivity, Zygote,
      BenchmarkTools, LuxCUDA, CUDA, OptimizationOptimisers

dev = gpu_device()
sensealg = TrackerAdjoint() # this works only on the CPU

data = rand32(32, 100, 512) |> dev
x₀ = rand32(32, 512) |> dev
ts = range(0.0f0, 1.0f0, length=100)

drift = Dense(32, 32, tanh)
diffusion = Scale(32, sigmoid)
basic_tgrad(u, p, t) = zero(u)

struct NeuralSDE{D, F} <: Lux.AbstractExplicitContainerLayer{(:drift, :diffusion)}
    drift::D
    diffusion::F
    solver
    tspan
    sensealg
end

function (model::NeuralSDE)(x₀, ts, p, st)
    μ(u, p, t) = model.drift(u, p.drift, st.drift)[1]
    σ(u, p, t) = model.diffusion(u, p.diffusion, st.diffusion)[1]
    func = SDEFunction{false}(μ, σ; tgrad=basic_tgrad)
    prob = SDEProblem{false}(func, x₀, model.tspan, p)
    sol = solve(prob, model.solver; saveat=ts, dt=0.01f0, sensealg=model.sensealg)
    return permutedims(cat(sol.u..., dims=3), (1, 3, 2))
end

function loss!(p, data)
    pred = model(x₀, ts, p, st)
    l = sum(abs2, data .- pred)
    return l, st, pred
end

rng = Random.default_rng()
model = NeuralSDE(drift, diffusion, EM(), (0.0f0, 1.0f0), sensealg)
p, st = Lux.setup(rng, model)
p = p |> ComponentArray{Float32} |> dev

adtype = AutoZygote()
optf = OptimizationFunction((p, _) -> loss!(p, data), adtype)
optproblem = OptimizationProblem(optf, p)
result = Optimization.solve(optproblem, ADAMW(5e-4), maxiters=10)
```
Error & Stacktrace
I am using the latest releases of the packages and Julia 1.10.4.