SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Neural ODEs with callbacks cannot be trained with most sensitivity algorithms #197

Closed: sdobber closed this issue 3 years ago

sdobber commented 4 years ago

Training Neural ODEs with callbacks currently fails for most of the sensitivity algorithms available to concrete_solve: only ForwardDiffSensitivity works for the following MWE. However, being forward-mode, its cost scales linearly with the number of parameters, so it becomes impractical for larger networks. The other sensitivity methods either leave the parameters unchanged or throw errors. (This issue is based on this discussion on Slack.)

MWE:

using DiffEqSensitivity, DifferentialEquations, DiffEqFlux, Flux, Tracker
using Plots

sensealg = ForwardDiffSensitivity()

datalength = 100
tspan = (0.0,100.0)
t = range(tspan[1],tspan[2],length=datalength)
target = 3.0*(1:datalength)./datalength  # some dummy data to fit to
cbinput = rand(1, datalength) #some external ODE contribution

pmodel = Chain(
    Dense(2, 10, initW=zeros),
    Dense(10, 2, initW=zeros))
p, re = Flux.destructure(pmodel)
dudt(u,p,t) = re(p)(u)

# callback changes the first component of the solution every time
# t is an integer
function affect!(integrator, cbinput)
    event_index = round(Int,integrator.t)
    integrator.u[1] += 0.2*cbinput[event_index]
end
callback = PresetTimeCallback(collect(1:datalength),(int)->affect!(int, cbinput))

# ODE with Callback
prob = ODEProblem(dudt,[0.0, 1.0],tspan, p, saveat=2, callback = callback)

function predict_n_ode(p)
    arr = Array(concrete_solve(prob, Tsit5(), [0.0, 1.0],
            p, sensealg=sensealg ))[1,2:2:end]
    return arr[1:datalength]
end

function loss_n_ode()
    pred = predict_n_ode(p)
    loss = sum(abs2,target .- pred)./datalength
end

cb = function () #callback function to observe training
  pred = predict_n_ode(p)
  display(loss_n_ode())
  pl = plot(1:datalength,target,label="data")
  plot!(pl,t,pred,label="prediction")
  display(pl)
end

Flux.train!(loss_n_ode, Flux.params(p), Iterators.repeated((), 10), ADAM(0.01), cb = cb)

Changing sensealg to any of the other algorithms breaks the training.
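
For concreteness, the alternatives in question look like this (a sketch on my part; the original report does not list exactly which methods were tried):

# assumption: representative alternatives, each swapped into the MWE above
sensealg = TrackerAdjoint()        # Tracker-based reverse-mode AD
sensealg = BacksolveAdjoint()      # continuous adjoint, re-solves the ODE backwards in time
sensealg = InterpolatingAdjoint()  # continuous adjoint using an interpolated forward solution
sensealg = QuadratureAdjoint()     # continuous adjoint, parameter gradient via quadrature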

ChrisRackauckas commented 4 years ago

Yes, here's the situation with callbacks.

ForwardDiffSensitivity works beautifully with both DiscreteCallbacks and ContinuousCallbacks.
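
For example, a continuous callback can be swapped into the MWE as well (a minimal sketch of mine, not from the original thread, reusing dudt, tspan, and p from above):

condition(u,t,integrator) = u[1] - 1.0                  # event fires when u[1] crosses 1.0
reflect!(integrator) = (integrator.u[1] = 2.0 - integrator.u[1])
ccb = ContinuousCallback(condition, reflect!)
prob_cont = ODEProblem(dudt, [0.0, 1.0], tspan, p, saveat = 2, callback = ccb)
# differentiating through the event with forward-mode sensitivities:
sol = Array(concrete_solve(prob_cont, Tsit5(), [0.0, 1.0], p,
        sensealg = ForwardDiffSensitivity()))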

For Tracker, you need to obey Tracker semantics. Here's a working version with Tracker:

using DiffEqSensitivity, DifferentialEquations, DiffEqFlux, Flux, Tracker
using Plots

sensealg = TrackerAdjoint()

datalength = 100
tspan = (0.0,100.0)
t = range(tspan[1],tspan[2],length=datalength)
target = 3.0*(1:datalength)./datalength  # some dummy data to fit to
cbinput = rand(1, datalength) #some external ODE contribution

pmodel = Chain(
    Dense(2, 10, initW=zeros),
    Dense(10, 2, initW=zeros))
p, re = Flux.destructure(pmodel)
dudt(u,p,t) = re(p)(u)

# callback changes the first component of the solution every time
# t is an integer
function affect!(integrator, cbinput)
    event_index = round(Int,integrator.t)
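    # Tracker semantics: indexing a TrackedArray yields TrackedReals, so x below
    # is a plain Vector of tracked scalars. Tracker.collect re-packs it into a
    # TrackedArray so the updated state stays on the AD tape; in-place mutation
    # of a TrackedArray is not differentiable.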
    x = [integrator.u[1]+0.2*cbinput[event_index],integrator.u[2]]
    integrator.u = integrator.u isa Tracker.TrackedArray ? Tracker.collect(x) : x
end
callback = PresetTimeCallback(collect(1:datalength),(int)->affect!(int, cbinput))

# ODE with Callback
prob = ODEProblem(dudt,[0.0, 1.0],tspan, p, saveat=2, callback = callback)

function predict_n_ode(p)
    arr = Array(concrete_solve(prob, Tsit5(), [0.0, 1.0],
            p, sensealg=sensealg ))[1,2:2:end]
    return arr[1:datalength]
end

function loss_n_ode()
    pred = predict_n_ode(p)
    loss = sum(abs2,target .- pred)./datalength
end

cb = function () #callback function to observe training
  pred = predict_n_ode(p)
  display(loss_n_ode())
  pl = plot(1:datalength,target,label="data")
  plot!(pl,t,pred,label="prediction")
  display(pl)
end

Flux.train!(loss_n_ode, Flux.params(p), Iterators.repeated((), 10), ADAM(0.01), cb = cb)

We're trying to phase out Tracker as much as possible because it can be difficult to work with, but that'll hold you over in a way that will scale.

As for the hardcoded adjoints: for general callbacks we need to handle the derivative of the callback time and of the discontinuity size whenever a parameter appears inside the condition or affect! function. That is this issue: https://github.com/JuliaDiffEq/DiffEqSensitivity.jl/issues/4. It's not easy, but it's doable; I hope we get around to it this summer.
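
For context, the correction a hardcoded adjoint needs at an event (my sketch in standard hybrid-system sensitivity notation, not code from that issue): if the state jumps as $u(t_e^+) = a(u(t_e^-), p)$ at event time $t_e$, then sweeping the adjoint $\lambda$ backwards through the event requires

$$\lambda(t_e^-) = \left(\frac{\partial a}{\partial u}\right)^{\top}\lambda(t_e^+), \qquad \frac{dL}{dp} \leftarrow \frac{dL}{dp} + \left(\frac{\partial a}{\partial p}\right)^{\top}\lambda(t_e^+),$$

and when the condition itself depends on $p$, additional terms proportional to $\partial t_e/\partial p$ appear because the event time moves with the parameters.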

Your case doesn't actually have a parameter in the condition or the affect!, so it should "just work". It seems to be an issue with how PresetTimeCallback sets the tstops, though, so I'll investigate that.

ChrisRackauckas commented 4 years ago

Together, https://github.com/JuliaDiffEq/DiffEqCallbacks.jl/pull/74 and https://github.com/JuliaDiffEq/DiffEqSensitivity.jl/pull/200 fix InterpolatingAdjoint.
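
Once those are merged, the expectation is that the original MWE trains with the reverse-mode interpolating adjoint by changing one line (a sketch assuming the rest of the MWE stays as written above):

sensealg = InterpolatingAdjoint()
# everything else in the MWE above is unchanged:
Flux.train!(loss_n_ode, Flux.params(p), Iterators.repeated((), 10), ADAM(0.01), cb = cb)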