A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
Describe the bug
Hi,
I am trying to optimize a split ODE with an embedded neural network. Inference runs fine, but the optimization with Zygote throws an UndefRefError. I am not sure whether this is the right place to file the bug report or whether the bug is in Zygote. Since the same optimization works fine with an ODEProblem, I suspect it is not Zygote.
Expected behavior
Not throw an error ;)
Minimal Reproducible Example
using Flux, Optimization, OrdinaryDiffEq, OptimizationOptimisers, SciMLSensitivity
##
N = 2^2
data = randn(Float32, 3, 100, N)
u0 = repeat(Float32[0, 0, 1], 1, N)
##
ann = Chain(Dense(3, 50, tanh), Dense(50, 3))
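# Flux.destructure returns the flat parameter vector and a function that rebuilds the network from it
pann, st = Flux.destructure(ann)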
train_loader = Flux.Data.DataLoader((data, u0); batchsize=N)
##
const H1 = Float32[ 0 0 1; 0 0 0; -1 0 0]
function dudt_rot(u, pann, t)
    return H1 * u
end
function dudt_relax(u, pann, t)
    return st(pann)(u)
end
##
function predict_neuralode(pann, u0)
    sol1 = solve(
        SplitODEProblem{false}(
            SplitFunction{false}(
                ODEFunction{false}(dudt_rot),
                ODEFunction{false}(dudt_relax),
            ), u0, (0f0, 1f0)),
        Rosenbrock23(), p = pann, saveat = collect(range(0f0, 1f0, 100)))
    return permutedims(Array(sol1), (1, 3, 2))
end
predict_neuralode(pann, u0)
##
function loss_neuralode(pann, data, u0)
    pred = predict_neuralode(pann, u0)
    loss = sum(abs2, data .- pred)
    return loss, pred
end
loss_neuralode(pann, data, u0)
##
optf = Optimization.OptimizationFunction((pann, pnothing, data, u0) -> loss_neuralode(pann, data, u0), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, pann)
result_neuralode = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.001), train_loader)
Transferring to SciMLSensitivity. Indeed, this case just hasn't been handled yet. Nothing fundamental blocks it, but we need to write a specialization for SplitFunction.
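Until such a specialization exists, one possible stopgap is to sum the two terms into a single right-hand side and solve it as a plain ODEProblem, which the report says already works with Zygote. The following is only a sketch under that assumption, not the planned fix. The names `H_rot`, `re`, `dudt_full`, and `predict_combined` are illustrative and do not appear in the original example.

```julia
using Flux, OrdinaryDiffEq, SciMLSensitivity

N = 4
u0 = repeat(Float32[0, 0, 1], 1, N)

ann = Chain(Dense(3, 50, tanh), Dense(50, 3))
# re rebuilds the network from a flat parameter vector (called st in the report)
pann, re = Flux.destructure(ann)

# Same rotation matrix as H1 in the report, under a hypothetical name
const H_rot = Float32[0 0 1; 0 0 0; -1 0 0]

# Assumption: adding the two split terms gives the same ODE as the SplitODEProblem
dudt_full(u, p, t) = H_rot * u .+ re(p)(u)

function predict_combined(p, u0)
    prob = ODEProblem{false}(dudt_full, u0, (0f0, 1f0), p)
    sol = solve(prob, Rosenbrock23(); saveat = collect(range(0f0, 1f0, 100)))
    # Reorder to (state, time, batch) to match the layout of data in the report
    return permutedims(Array(sol), (1, 3, 2))
end

pred = predict_combined(pann, u0)
```

Swapping `predict_combined` in for `predict_neuralode` inside `loss_neuralode` leaves the rest of the example unchanged. The trade-off is that the problem no longer carries its split structure, so this does not help if a solver that exploits the split (for example an IMEX method) is needed.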
Error & Stacktrace
Environment (please complete the following information):
using Pkg; Pkg.status()
using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
versioninfo()
Additional context
Thanks for looking into this!