000Justin000 closed this issue 4 years ago.
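Rosenbrock23(autodiff=false) works; with that option the solver computes its Jacobian by finite differencing instead of automatic differentiation. For example: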
using DifferentialEquations
using Flux
using DiffEqFlux
using Plots

# Initial condition, number of saved time points, and integration interval
u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0, 1.5f0)

# Ground-truth dynamics used to generate the training data
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

t = range(tspan[1], tspan[2], length=datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))

# Neural network that parameterizes the right-hand side of the neural ODE
dudt = Chain(x -> x.^3,
             Dense(2, 50, tanh),
             Dense(50, 2))

# Stiff solver with autodiff=false: the Jacobian is built by finite
# differences instead of with ForwardDiff inside Flux's tracking
n_ode(x) = neural_ode(dudt, x, tspan, Rosenbrock23(autodiff=false), saveat=t, reltol=1e-7, abstol=1e-9)

function predict_n_ode()
    n_ode(u0)
end

loss_n_ode() = sum(abs2, ode_data .- predict_n_ode())

data = Iterators.repeated((), 1000)
opt = ADAM(0.1)

cb = function ()  # callback function to observe training
    display(loss_n_ode())
    # plot current prediction against data
    cur_pred = Flux.data(predict_n_ode())
    pl = scatter(t, ode_data[1,:], label="data")
    scatter!(pl, t, cur_pred[1,:], label="prediction")
    display(plot(pl))
end

# Display the ODE with the initial parameter values.
cb()

ps = Flux.params(dudt)
Flux.train!(loss_n_ode, ps, data, opt, cb = cb)
So that is at least a workaround for now. Nesting the ADs (the solver's ForwardDiff Jacobian inside Flux's reverse-mode tracking) seems to need some work; maybe @YingboMa or @MikeInnes can pitch in.
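To make the workaround concrete: with autodiff=false the solver approximates the Jacobian by finite differences, so no forward-mode dual numbers have to pass through Flux's tracked parameters. The sketch below is a hand-written forward-difference Jacobian for an in-place right-hand side with the same (du, u, p, t) signature as trueODEfunc. The helper name fd_jacobian and the step rule h = sqrt(eps)·max(|u_j|, 1) are assumptions for illustration; the solver itself uses its own finite-differencing code and step heuristics.

```julia
# Hypothetical helper: forward-difference Jacobian of an in-place ODE
# right-hand side f!(du, u, p, t), the same signature as trueODEfunc.
function fd_jacobian(f!, u::AbstractVector{T}, p, t) where {T<:AbstractFloat}
    n = length(u)
    J = zeros(T, n, n)
    f0 = similar(u)
    f1 = similar(u)
    f!(f0, u, p, t)               # unperturbed right-hand side
    up = copy(u)
    for j in 1:n
        # Assumed step rule: sqrt(eps) scaled by the magnitude of u[j]
        h = sqrt(eps(T)) * max(abs(u[j]), one(T))
        up[j] = u[j] + h
        f!(f1, up, p, t)          # right-hand side with component j perturbed
        J[:, j] .= (f1 .- f0) ./ h
        up[j] = u[j]              # restore before the next column
    end
    return J
end

# Usage with the right-hand side from the example above
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

J = fd_jacobian(trueODEfunc, [2.0, 0.0], nothing, 0.0)
# Analytically du = true_A' * u.^3, so J[i, j] = 3 * u[j]^2 * true_A[j, i]
```

At u = [2, 0] the analytic Jacobian is [-1.2 0; 24 0], which the finite-difference estimate matches to within roughly 1e-7. The cost is one extra right-hand-side evaluation per state component, which is the price paid for avoiding the nested AD.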
This long-standing issue has been fixed. Thanks.
Hello,
I keep getting the "WARNING: Instability detected. Aborting" error during training. I think this might be because the differential equation I set up is stiff, so I need an implicit solver for it. However, when I replace the default Tsit5() solver with Rosenbrock23(), I get a stack-overflow error. Could you help me with this?
Thanks! Junteng
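Why stiffness pushes toward an implicit solver can be seen on a toy problem rather than the network from the thread. On the hypothetical linear system u' = A_stiff*u below, forward Euler with step 0.01 lies outside its stability region for the fast mode (it needs h < 2/1000), so that component grows like 9^n. A linearly implicit step of the kind Rosenbrock methods are built on (solve (I - hJ)k = f(u), then set u + hk) damps it instead. This is the simplest such method (Rosenbrock-Euler), not Rosenbrock23 itself, and A_stiff, explicit_euler and rosenbrock_euler are names made up for illustration.

```julia
using LinearAlgebra

# Hypothetical stiff test problem u' = A_stiff * u:
# one slow mode (rate -1) and one fast mode (rate -1000).
const A_stiff = [-1.0 0.0; 0.0 -1000.0]
f(u) = A_stiff * u

# Explicit (forward) Euler: u_{n+1} = u_n + h * f(u_n)
function explicit_euler(u0, h, nsteps)
    u = copy(u0)
    for _ in 1:nsteps
        u = u + h * f(u)
    end
    return u
end

# Linearly implicit (Rosenbrock-Euler) step:
# solve (I - h*J) k = f(u_n), then u_{n+1} = u_n + h*k.
# J is exact here because the right-hand side is linear (J == A_stiff).
function rosenbrock_euler(u0, h, nsteps)
    u = copy(u0)
    W = I - h * A_stiff
    for _ in 1:nsteps
        k = W \ f(u)
        u = u + h * k
    end
    return u
end

u0 = [1.0, 1.0]
h = 0.01      # well above the explicit stability limit 2/1000 for the fast mode
n = 100       # integrate to t = 1
ue = explicit_euler(u0, h, n)
ui = rosenbrock_euler(u0, h, n)
```

Here ue[2] grows to about 9^100 while ui[2] decays like (1/11)^100, and both methods track the slow mode near exp(-1). An adaptive explicit solver such as Tsit5 avoids the blow-up only by shrinking its step toward that stability limit, which is why stiff problems are usually handed to implicit or Rosenbrock-type methods.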