I’m trying to use the DiffEqFlux package to define a machine learning model for an ODE. The model I need is a Neural ODE whose inputs are the initial condition and a continuous-time signal.


So far, I have been able to train the network on a single example with:

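```julia
using DiffEqFlux, OrdinaryDiffEq, Flux

# u0, tspan, Δt, f_t and solution are defined elsewhere
dydt  = FastChain(FastDense(3,  10), FastDense(10,  2))
dydtₚ = initial_params(dydt)

# In-place right-hand side; the signal f_t is passed inside p
function Node(du, u, p, t)
    f_t    = p[1](t)
    params = p[2]

    du[1] = dydt([u; f_t], params)[1]
    du[2] = dydt([u; f_t], params)[2]
end

function prediction(fₜ, θ)
    prob = ODEProblem(Node, u0, tspan, [fₜ, θ])
    concrete_solve(prob, Tsit5(), saveat = Δt)
end

# The loss must be a function of the parameters for sciml_train
loss_fn(θ) = sum(abs2, solution - prediction(f_t, θ))
res = DiffEqFlux.sciml_train(loss_fn, dydtₚ, LBFGS())
```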

That works perfectly, but it uses only one signal for training. When I try to train on a batch of functions (several signals) with:

```julia
function loss_batch(θ)
    n_exam   = length(signals)
    loss_val = 0

    for i in 1:n_exam
        guess     = prediction(signals[i], θ)
        loss_val += sum(abs2, solution[i] - guess)
    end

    return loss_val
end

res = DiffEqFlux.sciml_train(loss_batch, dydtₚ, LBFGS())
```

I get this message:

```
ERROR: LoadError: MethodError: no method matching AbstractFloat(::var"#19#29")
Closest candidates are:
  AbstractFloat(::Bool) at float.jl:258
  AbstractFloat(::Int8) at float.jl:259
  AbstractFloat(::Int16) at float.jl:260
```

I think the error is caused by the loop, but I'm not sure.

ChrisRackauckas commented 4 years ago

The right thing to use is probably a closure. Here's an example of that:

```julia
using DiffEqFlux, OrdinaryDiffEq, Flux

dydt  = FastChain(FastDense(3,  10), FastDense(10,  2))
dydtₚ = initial_params(dydt)

# Out-of-place right-hand side; the signal f_t is a separate argument
function Node(u, p, t, f_t)
    dydt([u; f_t(t)], p)
end

# The closure captures fₜ, so only θ is passed as the ODE parameters
function prediction(fₜ, θ)
    prob = ODEProblem((u,p,t)->Node(u,p,t,fₜ), u0, tspan, θ)
    concrete_solve(prob, Tsit5(), saveat = Δt)
end

solution = ones(2,11)
f_t = (t)->t
u0 = ones(2)
Δt = 0.1
tspan = (0.0,1.0)
loss(dydtₚ) = sum(abs2, solution - prediction(f_t, dydtₚ))
res = DiffEqFlux.sciml_train(loss, dydtₚ, LBFGS())

# Batch version: one target solution and one signal per example
solution = [ones(2,11) for i in 1:10]
signals = [(t)->-i*t for i in 1:10]
function loss_batch(θ)
    n_exam   = length(signals)
    loss_val = 0

    for i in 1:n_exam
        guess     = prediction(signals[i], θ)
        loss_val += sum(abs2, solution[i] - guess)
    end

    return loss_val
end

res = DiffEqFlux.sciml_train(loss_batch, dydtₚ, LBFGS())
```

Let me know if you need anything else. Cheers.
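For readers who want to see the closure-plus-batch pattern without installing DiffEqFlux, here is a minimal dependency-free sketch in plain Julia. It is not the code from the comment above: a hypothetical linear `model` stands in for the `FastChain` network, a hypothetical fixed-step forward-Euler loop `euler_solve` stands in for `concrete_solve(prob, Tsit5(), saveat = Δt)`, and no training step is included. What it does keep is the structure: the signal is captured by a closure, `θ` is the only parameter, and the batch loss sums the per-signal squared errors.

```julia
# Hypothetical stand-in for dydt: maps [u; f(t)] (length 3) to du (length 2).
model(x, θ) = reshape(θ, 2, 3) * x

# Right-hand side with the signal as a separate argument, like Node(u, p, t, f_t).
node(u, θ, t, f) = model([u; f(t)], θ)

# Hypothetical fixed-step forward-Euler integrator (stand-in for the ODE solver).
# Returns a length(u0) × (nsteps + 1) matrix of saved states.
function euler_solve(rhs, u0, tspan, Δt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / Δt)
    us = zeros(length(u0), nsteps + 1)
    u = copy(u0)
    us[:, 1] = u
    for k in 1:nsteps
        t = t0 + (k - 1) * Δt
        u = u + Δt * rhs(u, t)
        us[:, k + 1] = u
    end
    return us
end

# The closure (u, t) -> node(u, θ, t, f) captures the signal f,
# so the parameter vector θ never has to hold a function.
prediction(f, θ; u0 = ones(2), tspan = (0.0, 1.0), Δt = 0.1) =
    euler_solve((u, t) -> node(u, θ, t, f), u0, tspan, Δt)

# Batched loss: sum of squared errors over all (signal, target) pairs.
function loss_batch(θ, signals, targets)
    loss_val = 0.0
    for i in eachindex(signals)
        guess = prediction(signals[i], θ)
        loss_val += sum(abs2, targets[i] - guess)
    end
    return loss_val
end

signals = [t -> -i * t for i in 1:10]    # each closure captures its own i
targets = [ones(2, 11) for _ in 1:10]
θ = zeros(6)
println(loss_batch(θ, signals, targets))  # prints 0.0
```

With all-zero parameters the state never moves away from `u0 = ones(2)`, so every prediction equals its all-ones target and the loss is exactly zero; any nonzero `θ` makes it positive.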

junsebas97 commented 4 years ago

Hi, I tried to implement the closure in my code and also ran your example, but I always get this error:

```
LoadError: MethodError: no method matching iterate(::Val{1})
Closest candidates are:
  iterate(!Matched::Core.SimpleVector) at essentials.jl:603
  iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603
  iterate(!Matched::ExponentialBackOff) at error.jl:253
```

In fact, with the closure, training fails even with a single signal.

ChrisRackauckas commented 4 years ago

Are you saying that example code didn't work for you? Or are you talking about some other code?

junsebas97 commented 4 years ago

Both. I tried to implement the closure in my other code and it didn't work; then I ran your example exactly as written and it didn't work either.

I always get the same message:

```
LoadError: MethodError: no method matching iterate(::Val{1})
Closest candidates are:
  iterate(!Matched::Core.SimpleVector) at essentials.jl:603
  iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603
  iterate(!Matched::ExponentialBackOff) at error.jl:25
```

ChrisRackauckas commented 4 years ago

Can you show the output of `]st` and `]st -m`? I was testing that on Julia v1.4.1.

junsebas97 commented 4 years ago

I use JuliaPro 1.4.0-1


ChrisRackauckas commented 4 years ago

Your Zygote version is far behind. Try updating it and see if that fixes the problem.

junsebas97 commented 4 years ago

Yeah, it works with the current version of Zygote. Thanks!!

One final question: since the signal varies in time, it must be evaluated at every step. Does that happen with the closure?

ChrisRackauckas commented 4 years ago

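> One final question: since the signal varies in time, it must be evaluated at every step. Does that happen with the closure?

Yes, since you're enclosing the function: `Node` calls `f_t(t)` with the current time every time the solver evaluates the right-hand side.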
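The closure does not precompute the signal; it holds a reference to the function, which is called with the current `t` on each right-hand-side evaluation. The plain-Julia sketch below records those calls. The forward-Euler loop `euler_endpoint` and the wrapper `recorded` are hypothetical helpers, not DiffEqFlux API. With this fixed-step loop there is exactly one evaluation per step; an adaptive solver such as `Tsit5` evaluates the right-hand side several times per step, so both the count and the time points would differ there.

```julia
# Hypothetical fixed-step forward-Euler loop standing in for an ODE solver;
# returns only the final state.
function euler_endpoint(rhs, u0, tspan, Δt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / Δt)
    u = copy(u0)
    for k in 1:nsteps
        t = t0 + (k - 1) * Δt
        u = u + Δt * rhs(u, t)   # one right-hand-side evaluation per step
    end
    return u
end

# Hypothetical wrapper: returns a function that behaves like f
# but also records every time point it is called with.
function recorded(f)
    times = Float64[]
    g = t -> (push!(times, t); f(t))
    return g, times
end

signal, times = recorded(t -> sin(t))

# The right-hand side encloses `signal`, as (u,p,t)->Node(u,p,t,fₜ) does.
rhs = (u, t) -> [-u[1] + signal(t)]
u_end = euler_endpoint(rhs, [0.0], (0.0, 1.0), 0.1)

println(length(times))   # prints 10: one signal evaluation per Euler step
```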

junsebas97 commented 4 years ago

Ok thanks!