SciML / OrdinaryDiffEq.jl

High performance ordinary differential equation (ODE) and differential-algebraic equation (DAE) solvers, including neural ordinary differential equations (neural ODEs) and scientific machine learning (SciML)
https://diffeq.sciml.ai/latest/

Eigenvalue/Spectral radius estimate in the ROCK2 #846

Closed: ranjanan closed this issue 5 years ago

ranjanan commented 5 years ago

This line here throws a DomainError whenever the expression under the square root is negative.
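For context, Julia's real `sqrt` does not return `NaN` for a negative `Float64`; it raises a `DomainError`, and a complex result has to be requested explicitly by passing a complex argument. A minimal check:

```julia
# Real sqrt rejects negative Float64 input with a DomainError
err = try
    sqrt(-1.0)
catch e
    e
end
println(err isa DomainError)      # true

# Passing a complex argument opts into the complex branch instead
println(sqrt(complex(-1.0)))      # 0.0 + 1.0im
```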

I get that error when I run the following code:

using OrdinaryDiffEq
using DiffEqFlux
using Flux
using Plots

# Robertson's stiff chemical kinetics problem, used to generate the training data
function f(du, u, p, t)
    du[1] = -p[1]*u[1] + p[2]*u[2]*u[3]
    du[2] = p[1]*u[1] - p[2]*u[2]*u[3] - p[3]*u[2]*u[2]
    du[3] = p[3]*u[2]*u[2]
end
p = [0.04, 1e4, 3e7]
u0 = [1.0, 0.0, 0.0]
tspan = (0.0, 100.0)
prob = ODEProblem(f, u0, tspan, p)
sol = solve(prob, Rosenbrock23(autodiff = false), saveat = 1)

ode_data = Array(sol)

dudt = Chain(
            Dense(3, 16, tanh),
            Dense(16, 3)
            )
# neural ODE solved with ROCK2, the solver that raises the DomainError
n_ode(x) = neural_ode(dudt, x, tspan, ROCK2(), saveat = 1.0)
predict_n_ode() = n_ode(u0)
function loss_n_ode()
  sum(abs2, ode_data .- (predict_n_ode()))
end
loss_n_ode()
data = Iterators.repeated((), 1000)
opt = ADAM(0.1)
cb = function () #callback function to observe training
  t = sol.t
  display(loss_n_ode())
  # plot current prediction against data
  cur_pred = Flux.data(predict_n_ode())
  pl = scatter(t,ode_data[1,:],label="data")
  scatter!(pl,t,collect(cur_pred[1,:]),label="prediction")
  display(plot(pl))
end
cb()

ps = Flux.params(dudt)
Flux.train!(loss_n_ode, ps, data, opt, cb = cb)
ChrisRackauckas commented 5 years ago

@deeepeshthakur

deeepeshthakur commented 5 years ago

I think this happens because dt < 0. It works with abs(dt).

ChrisRackauckas commented 5 years ago

dt < 0 is allowed, BTW. We should make that tdir*dt so that tspan = (1.0, 0.0) works, but that's probably a separate issue.
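The `tdir*dt` idea can be sketched as: take the sign of `tspan[2] - tspan[1]` as the integration direction and apply it to a step magnitude, so a reversed `tspan = (1.0, 0.0)` produces negative steps. The helper names `time_direction` and `signed_step` are hypothetical and only illustrate the idea; they are not OrdinaryDiffEq internals.

```julia
# Hypothetical helpers for illustration (not OrdinaryDiffEq internals)

# Integration direction: 1.0 for forward time, -1.0 for backward time
time_direction(tspan) = sign(tspan[2] - tspan[1])

# Apply the direction to a step magnitude
signed_step(tspan, h) = time_direction(tspan) * abs(h)

println(signed_step((0.0, 1.0), 0.1))   # 0.1
println(signed_step((1.0, 0.0), 0.1))   # -0.1
```

Under this convention, any quantity that should depend only on the step size, such as the square-root argument in this issue, would use `abs(h)` rather than the signed step.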

deeepeshthakur commented 5 years ago

Yes. I think the fix for this particular issue should just be to use abs(dt) in place of dt in the sqrt calculation.
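A minimal sketch of why `abs(dt)` removes the error, assuming a stage-count formula of the form `sqrt((1.5 + dt*λ)/0.811) + 1`, where `λ` is the spectral radius estimate. The constants are chosen to resemble ROCK2-style stage selection but should be treated as illustrative, and both function names are hypothetical. With a negative `dt`, the `sqrt` argument can go negative; with `abs(dt)`, the stage count depends only on the step magnitude.

```julia
# Hypothetical stage-count helpers; the constants are illustrative
# eigen_est is a nonnegative spectral radius estimate

# Naive version: a negative dt can make the sqrt argument negative
stages_naive(dt, eigen_est) = floor(Int, sqrt((1.5 + dt * eigen_est) / 0.811)) + 1

# Fixed version: only the step magnitude enters the formula
stages_fixed(dt, eigen_est) = floor(Int, sqrt((1.5 + abs(dt) * eigen_est) / 0.811)) + 1

println(stages_fixed(0.1, 100.0))    # 4
println(stages_fixed(-0.1, 100.0))   # 4

# The naive version throws for a backward step
try
    stages_naive(-0.1, 100.0)
catch e
    println(e isa DomainError)       # true
end
```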

deeepeshthakur commented 5 years ago

@ranjanan #848 should fix this.

ChrisRackauckas commented 5 years ago

For reference, the negative dt comes into play in the adjoint calculation here.

ranjanan commented 5 years ago

Thanks @deeepeshthakur