SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Issue with demo code on the README.md for Training a Neural ODE #206

Closed wuphilipp closed 4 years ago

wuphilipp commented 4 years ago

Hi! I'm trying to run the Neural ODE demo code and have run into this issue.

Specifically, when I run the demo code under the section Training a Neural Ordinary Differential Equation (exact code block below, directly copied from the README.md), I get the following error:

MethodError: no method matching isless(::Tuple{Float32,RecursiveArrayTools.DiffEqArray{Float32,2,Array{Array{Float32,1},1},StepRangeLen{Float32,Float64,Float64}}}, ::Float32)
Closest candidates are:
  isless(!Matched::Float32, ::Float32) at float.jl:464
  isless(!Matched::Missing, ::Any) at missing.jl:87
  isless(::Tuple, !Matched::Tuple{}) at tuple.jl:358
  ...

Stacktrace:
 [1] <(::Tuple{Float32,RecursiveArrayTools.DiffEqArray{Float32,2,Array{Array{Float32,1},1},StepRangeLen{Float32,Float64,Float64}}}, ::Float32) at ./operators.jl:268
 [2] macro expansion at /home/phil/.julia/packages/DiffEqFlux/YKKwl/src/train.jl:113 [inlined]
 [3] macro expansion at /home/phil/.julia/packages/ProgressLogging/g8xnW/src/ProgressLogging.jl:328 [inlined]
 [4] (::DiffEqFlux.var"#23#28"{var"#59#61",Int64,Bool,Bool,typeof(loss_n_ode),Array{Float32,1},Zygote.Params})() at /home/phil/.julia/packages/DiffEqFlux/YKKwl/src/train.jl:43
 [5] with_logstate(::DiffEqFlux.var"#23#28"{var"#59#61",Int64,Bool,Bool,typeof(loss_n_ode),Array{Float32,1},Zygote.Params}, ::Base.CoreLogging.LogState) at ./logging.jl:398
 [6] with_logger at ./logging.jl:505 [inlined]
 [7] maybe_with_logger(::DiffEqFlux.var"#23#28"{var"#59#61",Int64,Bool,Bool,typeof(loss_n_ode),Array{Float32,1},Zygote.Params}, ::LoggingExtras.TeeLogger{Tuple{LoggingExtras.EarlyFilteredLogger{ConsoleProgressMonitor.ProgressLogger,DiffEqBase.var"#10#12"},LoggingExtras.EarlyFilteredLogger{Base.CoreLogging.SimpleLogger,DiffEqBase.var"#11#13"}}}) at /home/phil/.julia/packages/DiffEqBase/k3AhB/src/utils.jl:259
 [8] sciml_train(::Function, ::Array{Float32,1}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at /home/phil/.julia/packages/DiffEqFlux/YKKwl/src/train.jl:42
 [9] top-level scope at In[9]:45

I'm using the latest stable version of Julia (1.4.0), running on Ubuntu 16.04.

Any advice?

Really cool work by the way!


Code block I'm running:

using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0,1.5f0)

function trueODEfunc(du,u,p,t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
t = range(tspan[1],tspan[2],length=datasize)
prob = ODEProblem(trueODEfunc,u0,tspan)
ode_data = Array(solve(prob,Tsit5(),saveat=t))

dudt2 = FastChain((x,p) -> x.^3,
            FastDense(2,50,tanh),
            FastDense(50,2))
n_ode = NeuralODE(dudt2,tspan,Tsit5(),saveat=t)

function predict_n_ode(p)
  n_ode(u0,p)
end

function loss_n_ode(p)
    pred = predict_n_ode(p)
    loss = sum(abs2,ode_data .- pred)
    loss,pred
end

loss_n_ode(n_ode.p) # n_ode.p stores the initial parameters of the neural ODE

cb = function (p,l,pred;doplot=false) #callback function to observe training
  display(l)
  # plot current prediction against data
  if doplot
    pl = scatter(t,ode_data[1,:],label="data")
    scatter!(pl,t,pred[1,:],label="prediction")
    display(plot(pl))
  end
  return false
end

# Display the ODE with the initial parameter values.
cb(n_ode.p,loss_n_ode(n_ode.p)...)

res1 = DiffEqFlux.sciml_train(loss_n_ode, n_ode.p, ADAM(0.05), cb = cb, maxiters = 300)
cb(res1.minimizer,loss_n_ode(res1.minimizer)...;doplot=true)
res2 = DiffEqFlux.sciml_train(loss_n_ode, res1.minimizer, LBFGS(), cb = cb)
cb(res2.minimizer,loss_n_ode(res2.minimizer)...;doplot=true)
ChrisRackauckas commented 4 years ago

Just ran it and it ran fine. Possibly the issue is your package versions. If you run `]st`, what does it show you?

wuphilipp commented 4 years ago

Thanks for the reply. This is what I get from `]st`.

(@v1.4) pkg> st
Status `~/.julia/environments/v1.4/Project.toml`
  [31a5f54b] Debugger v0.6.4
  [aae7a2af] DiffEqFlux v1.8.0
  [41bf760c] DiffEqSensitivity v6.10.0
  [0c46a032] DifferentialEquations v6.12.0
  [587475ba] Flux v0.10.3
  [7073ff75] IJulia v1.21.1
  [429524aa] Optim v0.20.5
  [1dea7af3] OrdinaryDiffEq v5.32.1
  [91a5bcdd] Plots v0.29.9
  [37e2e46d] LinearAlgebra

Does this look OK?

ChrisRackauckas commented 4 years ago

Looks like DiffEqFlux 1.8 had an issue, so I just put out a patch 1.8.1. If you update that should be good. Sorry for the hiccup.
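For readers hitting the same error: the update can be done from the Pkg REPL (`]up DiffEqFlux`) or through the `Pkg` API. A minimal sketch of the API route (not part of the original thread; the exact version shown by `status` will depend on your environment):

```julia
using Pkg

# Update DiffEqFlux to the latest compatible release (1.8.1 at the time),
# then confirm which version the environment resolved to.
Pkg.update("DiffEqFlux")
Pkg.status("DiffEqFlux")
```

If the environment still resolves to 1.8.0 after updating, another package in the project is likely holding back the version bound, which `]st` can help diagnose.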

wuphilipp commented 4 years ago

No worries, thanks for the quick patch!