The right thing to use is probably a closure. Here's an example of that:

using DiffEqFlux, OrdinaryDiffEq, Flux

dydt = FastChain(FastDense(3, 10), FastDense(10, 2))  # input [u; f(t)] (3) -> du/dt (2)
dydtₚ = initial_params(dydt)

function Node(u, p, t, f_t)
    dydt([u; f_t(t)], p)  # append the signal's value at time t to the state
end

function prediction(fₜ, θ)
    # the anonymous function closes over fₜ, giving the (u, p, t) signature ODEProblem expects
    prob = ODEProblem((u, p, t) -> Node(u, p, t, fₜ), u0, tspan, θ)
    concrete_solve(prob, Tsit5(), saveat = Δt)
end

solution = ones(2, 11)  # target trajectory: 2 states at 11 save points
f_t = (t) -> t
u0 = ones(2)
Δt = 0.1
tspan = (0.0, 1.0)

loss(dydtₚ) = sum(abs2, solution - prediction(f_t, dydtₚ))
res = DiffEqFlux.sciml_train(loss, dydtₚ, LBFGS())

solution = [ones(2, 11) for i in 1:10]  # one target trajectory per signal
signals = [(t) -> -i * t for i in 1:10]

function loss_batch(θ)
    n_exam = length(signals)
    loss_val = 0
    for i in 1:n_exam
        guess = prediction(signals[i], θ)
        loss_val += sum(abs2, solution[i] - guess)
    end
    return loss_val
end

res = DiffEqFlux.sciml_train(loss_batch, dydtₚ, LBFGS())

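To see what the closure does without pulling in DiffEqFlux, here is a plain-Julia sketch. An ODE solver only calls functions with the (u, p, t) signature, so the anonymous function fixes the extra signal argument. The linear `stand_in_net` is a hypothetical replacement for the FastChain network, and `make_rhs` and `demo_signals` are illustrative names, not from the code above.

```julia
# Hypothetical stand-in for the FastChain network: a linear map from the
# 3-element input [u; f(t)] to a 2-element derivative. The first 6 entries
# of the flat parameter vector p form a 2x3 weight matrix (column-major).
stand_in_net(x, p) = reshape(p[1:6], 2, 3) * x

# Same shape as `Node`: the input signal is an extra, fourth argument.
node(u, p, t, f_t) = stand_in_net([u; f_t(t)], p)

# The closure fixes f_t, producing the (u, p, t) signature a solver expects.
make_rhs(f_t) = (u, p, t) -> node(u, p, t, f_t)

# One closure per signal; each anonymous function captures its own i.
demo_signals = [t -> -i * t for i in 1:10]
rhs_list = [make_rhs(f) for f in demo_signals]

# Weights [1 3 5; 2 4 6], state [1, 1], signal 3 at t = 0.5 gives -1.5.
du3 = rhs_list[3](ones(2), collect(1.0:6.0), 0.5)
println(du3)  # [-3.5, -3.0]
```

Because each element of `rhs_list` is an ordinary function of (u, p, t), it can be handed to `ODEProblem` exactly like the anonymous function in `prediction`.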
Let me know if you need anything else. Cheers.
Hi, I tried to implement the closure in my code and also ran your example, but I always get this error:
LoadError: MethodError: no method matching iterate(::Val{1})
Closest candidates are:
iterate(!Matched::Core.SimpleVector) at essentials.jl:603
iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603
iterate(!Matched::ExponentialBackOff) at error.jl:253
Actually, with the closure, training fails even with just one signal.
Are you saying that the example code didn't work for you? Or are you talking about some other code?
Both. I tried to implement the closure in other code and it didn't work; then I ran the exact example you gave me, and it didn't work either.
I always get the same message:
LoadError: MethodError: no method matching iterate(::Val{1})
Closest candidates are:
iterate(!Matched::Core.SimpleVector) at essentials.jl:603
iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603
iterate(!Matched::ExponentialBackOff) at error.jl:25
Can you show the output of ]st and ]st -m? I was testing that on Julia v1.4.1.
I use JuliaPro 1.4.0-1
Your Zygote version is far behind. Try updating it and see if that fixes the problem.
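The same information that ]st -m shows can also be checked from a script. This is a minimal sketch assuming Julia 1.4 or newer, where Pkg.dependencies() is available; `installed_version` is a hypothetical helper name.

```julia
using Pkg

# Hypothetical helper: return the version of package `name` resolved in the
# active environment's manifest, or `nothing` if it is not there.
function installed_version(name::AbstractString)
    for (_, info) in Pkg.dependencies()
        info.name == name && return info.version
    end
    return nothing
end

# Prints the full manifest, like `]st -m` in the Pkg REPL.
Pkg.status(mode = Pkg.PKGMODE_MANIFEST)

zygote_version = installed_version("Zygote")
println("Zygote: ", zygote_version === nothing ? "not installed" : zygote_version)
```

If the reported version is old, `Pkg.update("Zygote")` (or `]up Zygote` in the REPL) updates it.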
Yeah, it works with the current version of Zygote. Thanks!!
A final question: since the signal varies in time, it must be evaluated at every step. Does that happen with the closure?
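Yes, since you're enclosing the function: f_t(t) is called inside the right-hand side, so it is evaluated every time the solver evaluates the right-hand side, which happens several times per step.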
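One way to confirm this is to count how often the enclosed signal is called. The sketch below assumes a hand-written fixed-step RK4 integrator (so it runs without OrdinaryDiffEq) and a toy right-hand side; `counting`, `rk4` and `toy_node` are hypothetical names. With RK4 the signal is evaluated four times per step, at the stage times. An adaptive solver such as Tsit5 picks its own stage times, but the signal is still called on every right-hand-side evaluation.

```julia
# Wrap a signal so that every call is counted.
function counting(f)
    calls = Ref(0)
    g = t -> (calls[] += 1; f(t))
    return g, calls
end

# Classic fixed-step RK4 for du/dt = rhs(u, p, t).
function rk4(rhs, u0, p, tspan, h)
    t, u = tspan[1], copy(u0)
    nsteps = round(Int, (tspan[2] - tspan[1]) / h)
    for _ in 1:nsteps
        k1 = rhs(u, p, t)
        k2 = rhs(u .+ (h / 2) .* k1, p, t + h / 2)
        k3 = rhs(u .+ (h / 2) .* k2, p, t + h / 2)
        k4 = rhs(u .+ h .* k3, p, t + h)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += h
    end
    return u
end

# Toy right-hand side with the same shape as `Node`: the signal is an input.
toy_node(u, p, t, f_t) = -p .* u .+ f_t(t)

signal, calls = counting(sin)
u_end = rk4((u, p, t) -> toy_node(u, p, t, signal), ones(2), [1.0, 2.0], (0.0, 1.0), 0.1)
println("signal evaluated ", calls[], " times")  # 40 calls: 10 steps x 4 stages
```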
Ok thanks!
I’m trying to use the DiffEqFlux package to define a machine learning model for an ODE. The model I need is a Neural ODE whose inputs are the initial condition and a continuous-time signal.
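In equation form, this is a neural ODE whose right-hand side receives both the state and the current value of the input signal, matching the Node function used in the answer. Here $\mathrm{NN}_\theta$ stands for the network, $y_j$ for the target trajectory of signal $f_j$, and $t_k$ for the save times; this notation is assumed, not taken from the post. Training minimizes the squared error over the save times, summed over the $N$ signals in the batch case ($N = 1$ for a single signal):

$$
\frac{du_j}{dt} = \mathrm{NN}_\theta\big([\,u_j(t);\ f_j(t)\,]\big), \qquad u_j(t_0) = u_0,
\qquad
L(\theta) = \sum_{j=1}^{N} \sum_{k} \big\| y_j(t_k) - u_j(t_k;\theta) \big\|_2^2
$$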
Currently, I have been able to train the NN with just one example using:
That works perfectly, but it uses only one signal for training. When I try to train on a batch of functions (using several signals) with:
I get this message:
I think the bug is caused by the loop, but I'm not sure.
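For comparison, the loop that accumulates the batch loss computes the same value as a `sum` over a generator. The sketch below checks this in plain Julia, assuming a toy linear right-hand side and a forward-Euler integrator in place of the neural network and Tsit5. `euler_trajectory`, `predict`, `batch_loss_loop`, `batch_loss_sum`, `demo_signals` and `targets` are hypothetical names.

```julia
# Forward-Euler trajectory saved at every step (columns = time points),
# standing in for concrete_solve(prob, Tsit5(), saveat = Δt).
function euler_trajectory(rhs, u0, p, tspan, Δt)
    n = round(Int, (tspan[2] - tspan[1]) / Δt)
    traj = zeros(length(u0), n + 1)
    u, t = copy(u0), tspan[1]
    traj[:, 1] = u
    for k in 1:n
        u = u .+ Δt .* rhs(u, p, t)
        t += Δt
        traj[:, k + 1] = u
    end
    return traj
end

# Toy linear right-hand side standing in for the network:
# p[1] scales the state, p[2] scales the input signal.
toy_node(u, p, t, f_t) = p[1] .* u .+ p[2] * f_t(t)

predict(f_t, θ) = euler_trajectory((u, p, t) -> toy_node(u, p, t, f_t), ones(2), θ, (0.0, 1.0), 0.1)

demo_signals = [t -> -i * t for i in 1:10]
targets = [ones(2, 11) for _ in 1:10]

# Loop form, as in loss_batch.
function batch_loss_loop(θ)
    loss_val = 0.0
    for i in eachindex(demo_signals)
        loss_val += sum(abs2, targets[i] - predict(demo_signals[i], θ))
    end
    return loss_val
end

# Equivalent reduction without an explicit accumulator.
batch_loss_sum(θ) = sum(sum(abs2, targets[i] - predict(demo_signals[i], θ)) for i in eachindex(demo_signals))

θ = [-0.5, 0.2]
println(batch_loss_loop(θ), " ≈ ", batch_loss_sum(θ))
```

Either form is a reasonable way to write the batch loss; the explicit loop just makes the per-signal term easier to inspect.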