SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Derivatives of the neural_ode #7

Closed · ChrisRackauckas closed this issue 5 years ago

ChrisRackauckas commented 5 years ago

I accidentally stopped the tracking in the neural_ode layer, and of course it can't backprop through, since what comes out is just an untracked float array (@MikeInnes told me this would happen). So an MWE of trying to train with the neural_ode is:

using OrdinaryDiffEq, DiffEqFlux, Flux, Plots
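# Lotka-Volterra predator-prey system, used to generate the training data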
function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end
u0 = [1.0,1.0]
tspan = (0.0,10.0)
p = [1.5,1.0,3.0,1.0]
prob = ODEProblem(lotka_volterra,u0,tspan,p)
ode_data = Array(solve(prob,Tsit5(),saveat=0.1))

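# Neural network that stands in for the right-hand side du/dt inside the ODE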
dudt = Chain(Dense(2,50,tanh),Dense(50,2))
tspan = (0.0f0,10.0f0)
n_ode = x->neural_ode(x,dudt,tspan,Tsit5(),saveat=0.1)
pred = n_ode(u0)

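# Visual sanity check: the true data against the untrained prediction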
scatter(0.0:0.1:10.0,ode_data[1,:],label="data")
scatter!(0.0:0.1:10.0,pred[1,:],label="prediction")

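# Forward pass: solve the neural ODE from the initial condition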
function predict_n_ode()
  n_ode(u0)
end

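# L2 loss between the ODE-generated data and the network's prediction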
loss_n_ode() = sum(abs2,ode_data .- predict_n_ode())

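# 100 dummy iterations; the loss closes over ode_data, so no minibatches are needed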
data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function () #callback function to observe training
  display(loss_n_ode())
  # plot current prediction against data
  cur_pred = predict_n_ode()
  pl = scatter(0.0:0.1:10.0,ode_data[1,:],label="data")
  scatter!(pl,0.0:0.1:10.0,cur_pred[1,:],label="prediction")
  plot(pl)
end

# Display the ODE with the initial parameter values.
cb()

ps = Flux.params(dudt) # collect the network's parameters explicitly (bare params is the uncalled function)
Flux.train!(loss_n_ode, ps, data, opt, cb = cb)

ChrisRackauckas commented 5 years ago

Trains!
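
For anyone landing here later: a minimal sketch of checking that gradients actually flow through the solve, assuming the Tracker-based Flux of that era (Flux 0.8-style implicit parameters via Flux.params and Flux.Tracker.gradient); names like gs are illustrative only:

using Flux
ps = Flux.params(dudt)                              # tracked weights of the Chain
gs = Flux.Tracker.gradient(() -> loss_n_ode(), ps)  # errors (or returns zeros) if tracking was dropped inside neural_ode
gs[dudt[1].W]                                       # gradient with respect to the first Dense layer's weight matrix

If the loss comes back as a plain Float64 rather than a tracked scalar, the computation graph was cut somewhere inside the solve, which is exactly the failure mode this issue describes.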