Yes, here's the situation with callbacks.
ForwardDiffSensitivity works beautifully with Discrete and Continuous callbacks.
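For example, here's a minimal sketch of that path (the decay ODE, callback, and values are illustrative assumptions, not code from this thread):

using DiffEqSensitivity, DifferentialEquations, Zygote

f(u, p, t) = -p[1] .* u
condition(u, t, integrator) = t == 5.0  # fires at the tstop t = 5
bump!(integrator) = (integrator.u .+= 0.1) # parameter-free affect!
cb = DiscreteCallback(condition, bump!)
prob_fd = ODEProblem(f, [1.0], (0.0, 10.0), [0.5])

function loss_fd(p)
    sol = concrete_solve(prob_fd, Tsit5(), [1.0], p,
                         callback = cb, tstops = [5.0], saveat = 0.1,
                         sensealg = ForwardDiffSensitivity())
    sum(abs2, Array(sol))
end

Zygote.gradient(loss_fd, [0.5]) # differentiates through the event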
For TrackerAdjoint, you need to obey Tracker's semantics. Here's a working version with Tracker:
using DiffEqSensitivity, DifferentialEquations, DiffEqFlux, Flux, Tracker
using Plots

sensealg = TrackerAdjoint()

datalength = 100
tspan = (0.0, 100.0)
t = range(tspan[1], tspan[2], length = datalength)
target = 3.0 .* (1:datalength) ./ datalength # some dummy data to fit to
cbinput = rand(1, datalength) # some external ODE contribution

pmodel = Chain(
    Dense(2, 10, initW = zeros),
    Dense(10, 2, initW = zeros))
p, re = Flux.destructure(pmodel)
dudt(u, p, t) = re(p)(u)

# The callback changes the first component of the solution every time
# t is an integer.
function affect!(integrator, cbinput)
    event_index = round(Int, integrator.t)
    x = [integrator.u[1] + 0.2*cbinput[event_index], integrator.u[2]]
    # Tracker semantics: when u is a TrackedArray, rebuild the new state with
    # Tracker.collect so the tracking graph stays intact through the event.
    integrator.u = integrator.u isa Tracker.TrackedArray ? Tracker.collect(x) : x
end
callback = PresetTimeCallback(collect(1:datalength), int -> affect!(int, cbinput))

# ODE with callback
prob = ODEProblem(dudt, [0.0, 1.0], tspan, p, saveat = 2, callback = callback)

function predict_n_ode(p)
    arr = Array(concrete_solve(prob, Tsit5(), [0.0, 1.0], p,
                               sensealg = sensealg))[1, 2:2:end]
    return arr[1:datalength]
end

function loss_n_ode()
    pred = predict_n_ode(p)
    loss = sum(abs2, target .- pred) / datalength
end

cb = function () # callback function to observe training
    pred = predict_n_ode(p)
    display(loss_n_ode())
    pl = plot(1:datalength, target, label = "data")
    plot!(pl, t, pred, label = "prediction")
    display(pl)
end

Flux.train!(loss_n_ode, Flux.params(p), Iterators.repeated((), 10), ADAM(0.01), cb = cb)
We're trying to phase out Tracker as much as possible because it can be difficult to work with, but that'll hold you over in a way that will scale.
As for the hardcoded adjoints: for general callbacks we need to handle the derivative of the callback time and of the discontinuity size whenever a parameter appears inside the condition or affect! function. That is this issue: https://github.com/JuliaDiffEq/DiffEqSensitivity.jl/issues/4. It's not easy, but it's doable; I hope we get around to it this summer.
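For concreteness, the hard case looks something like this (hypothetical sketch; p[2] and the jump form are illustrative). Because the jump size depends on a parameter, the adjoint must also account for the derivative of the discontinuity with respect to that parameter:

# Hypothetical: p[2] sets the jump size, so the gradient w.r.t. p[2]
# must include the sensitivity of the discontinuity itself -- the part
# the hardcoded adjoints don't handle yet.
f(u, p, t) = -p[1] .* u
condition(u, t, integrator) = t == 5.0
jump!(integrator) = (integrator.u .+= integrator.p[2]) # parameter inside affect!
cb_param = DiscreteCallback(condition, jump!)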
Your case doesn't actually have a parameter in the condition or affect!, so it should "just work". It seems like this is an issue with how PeriodicCallback is setting the tstops, though, so I'll investigate that.
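If the original setup looked roughly like this (a guess at the failing configuration, reusing the affect! and cbinput from the MWE above), the problem would lie in how the callback registers its stop times:

# PeriodicCallback pushes its own tstops to fire affect! every `period`
# time units (here period = 1.0); the suspected bug is in how those
# tstops get set.
callback = PeriodicCallback(int -> affect!(int, cbinput), 1.0)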
Together, https://github.com/JuliaDiffEq/DiffEqCallbacks.jl/pull/74 and https://github.com/JuliaDiffEq/DiffEqSensitivity.jl/pull/200 fix InterpolatingAdjoint.
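With those merged, the example should in principle work by just swapping the sensitivity choice (a sketch, assuming the same prob and p as in the MWE above):

# After the two PRs, the adjoint path handles the callback correctly:
sol = concrete_solve(prob, Tsit5(), [0.0, 1.0], p,
                     sensealg = InterpolatingAdjoint())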
Training Neural ODEs with callbacks currently fails for most of the sensitivity algorithms available to concrete_solve: only ForwardDiffSensitivity works for the following MWE. However, that method scales badly with the number of parameters. The other sensitivity methods either do not update the parameters or throw errors. (This issue is based on this discussion on Slack.)
MWE:
Changing sensealg to other algorithms breaks the training.
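For reference, these are the kinds of substitutions that fail (a hypothetical list drawn from the sensitivity algorithms DiffEqSensitivity offered at the time):

# Substituting any of these for ForwardDiffSensitivity() in the MWE
# either leaves the parameters unchanged or throws an error:
sensealg = TrackerAdjoint()
sensealg = BacksolveAdjoint()
sensealg = InterpolatingAdjoint()
sensealg = QuadratureAdjoint()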