This is strongly related to issue #564, but I can now narrow the error down. See the following MWE, which trains a minimal NeuralODE and uses a FunctionCallingCallback that does nothing:
using Flux, DiffEqFlux
using DifferentialEquations
using DiffEqCallbacks
using DiffEqFlux: ODEFunction, basic_tgrad, ODEProblem, ZygoteVJP, InterpolatingAdjoint, solve
numStates = 2
tspan = (0.0f0, 5.0f0)
saveat = 0.0f0:0.1f0:5.0f0
optim = Adam()
x0 = [0.0f0, 0.0f0]
function f(x, p, t)
    return neuralODE.re(p)(x)
end
function functionCalling(x, t, integrator)
    # @info "Step: $(t)"
end
callback = FunctionCallingCallback(functionCalling;
                                   func_everystep=true,
                                   func_start=true)
### ML
posData = zeros(Float64, length(saveat)) # target position data ("flat line"), not a good training target, but this doesn't matter here ;-)
# MSE between net output and data
function losssum(prob, neuralODE, sensealg)
    solution = solve(prob, neuralODE.args...; sensealg=sensealg, saveat=saveat, callback=CallbackSet(callback), neuralODE.kwargs...)
    posNet = collect(data[1] for data in solution.u)
    return Flux.Losses.mse(posNet, posData)
end
net = Chain(Dense(numStates, numStates, identity),
            Dense(numStates, numStates, identity),
            Dense(numStates, numStates, identity))
sensealgs = [ForwardDiffSensitivity(;convert_tspan=true), InterpolatingAdjoint(autojacvec=ReverseDiffVJP())]
for sensealg in sensealgs
    print("Sensealg: $(sensealg)\n")
    global neuralODE
    neuralODE = NeuralODE(net, tspan; saveat=saveat)
    dudt_op(x, p, t) = f(x, p, t);
    ff = ODEFunction{false}(dudt_op,tgrad=basic_tgrad);
    prob = ODEProblem{false}(ff, x0, tspan, neuralODE.p);
    p_net = Flux.params(neuralODE);
    Flux.train!(()->losssum(prob, neuralODE, sensealg), p_net, Iterators.repeated((), 1), Adam());
end
This fails on current Julia (1.8.5) with the most up-to-date versions of the SciML ecosystem.
Dear @ChrisRackauckas and @frankschae,
Let me know if I can help. Best regards!
The error is: