SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

How to mix "normal" ODEs and Neural ODEs? #15

Closed scheidan closed 5 years ago

scheidan commented 5 years ago

Thanks for this beautiful package!

From the README and the blog I understood i) how to use a "normal" ODE as a Flux layer (diffeq_rd), and ii) how to use a Flux layer to define an ODE (neural_ode).

I was wondering if an API already exists to mix both approaches. For example, to define something like this:

function dudt(du, u)
    x, y = u
    ann = Chain(Dense(2,10,tanh), Dense(10,1))
    du[1] = ann(u)[1]
    du[2] = -2.0*y + 1.1*x*y
end
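
For reference, this is roughly what I mean by i) and ii) (paraphrasing the README/blog examples with made-up numbers, so the exact signatures may be slightly off):

using DiffEqFlux, Flux, OrdinaryDiffEq

# i) a "normal" ODE as a Flux layer: the parameters p of a known model are tracked
function lotka(du,u,p,t)
  du[1] =  p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
end
p = param([1.5,1.0,3.0,1.0])
prob = ODEProblem(lotka,[1.0,1.0],(0.0,10.0),p)
predict() = diffeq_rd(p,prob,Tsit5(),saveat=0.1)

# ii) a Flux chain as the whole right-hand side of an ODE
dudt_nn = Chain(Dense(2,50,tanh), Dense(50,2))
n_ode(x) = neural_ode(dudt_nn,x,(0.0f0,10.0f0),Tsit5(),saveat=0.1f0)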

ChrisRackauckas commented 5 years ago

We didn't add a high-level function that allows this kind of mixing only because we didn't think of doing it. It wouldn't be difficult though, so if this is a feature request we could put it right in.

scheidan commented 5 years ago

This would be very interesting! Sometimes you have a good understanding of part of a system (e.g. a mass balance), while for other parts you only suspect some influence but can't express it with a parametric equation (which is an ideal use case for an ANN).

Not sure if I can help with the implementation, but I'd be very happy to test it.

aussetg commented 5 years ago

I'm very interested by this feature too.

ChrisRackauckas commented 5 years ago

Sometimes you have a good understanding of part of a system (e.g. a mass balance), while for other parts you only suspect some influence but can't express it with a parametric equation (which is an ideal use case for an ANN).

I think one big use for this is actually differential-algebraic equations. You might want to hardcode the expressions for the conservation laws and let it learn the dynamics.
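
Roughly what I have in mind (untested sketch, made-up three-species model): the network proposes the dynamics for u[1] and u[2], while the hardcoded conservation law u[1] + u[2] + u[3] == 1 enters as an algebraic equation through a singular mass matrix.

using Flux, OrdinaryDiffEq

ann = Chain(Dense(3,16,tanh), Dense(16,2))

function dae_rhs(du,u,p,t)
    du[1:2] .= Flux.data(ann(u))        # learned dynamics (untracked, forward simulation only)
    du[3] = u[1] + u[2] + u[3] - 1      # 0 = conservation law residual, not a derivative
end

M = [1.0 0 0; 0 1.0 0; 0 0 0]           # zero row makes the last equation algebraic
prob = ODEProblem(ODEFunction(dae_rhs,mass_matrix=M),[0.5,0.3,0.2],(0.0,1.0))
sol = solve(prob,Rodas5())              # mass-matrix-capable stiff solver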

scheidan commented 5 years ago

Yes, you're right, mass balance was not the best example.

What I had in mind is to replace part of a "physically based" equation with an ANN. For example, a scientist may have a predator-prey model

drabbit/dt = a*rabbit - b*rabbit*fox
dfox/dt    = c*rabbit*fox - d*fox

that she really likes. However, she has the suspicion that the "hunting part" should include time somehow as well. So she could then fit a model like this:

drabbit/dt = a*rabbit - ANN(rabbit,fox,t)
dfox/dt    = c*rabbit*fox - d*fox

and analyze the trained ANN to construct a new scientific hypothesis.
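
In code the idea would be roughly this (untested, with made-up constants; for an actual fit the ANN call of course has to stay differentiable, which is exactly what I'm asking about):

using Flux, OrdinaryDiffEq

ann = Chain(Dense(3,16,tanh), Dense(16,1))   # inputs: rabbit, fox, t
a, c, d = 1.5, 1.0, 3.0

function dudt(du,u,p,t)
    rabbit, fox = u
    du[1] = a*rabbit - Flux.data(ann([rabbit,fox,t]))[1]  # learned "hunting" term
    du[2] = c*rabbit*fox - d*fox                          # known mechanistic part
end

prob = ODEProblem(dudt,[1.0,1.0],(0.0,10.0))
sol = solve(prob,Tsit5())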

ChrisRackauckas commented 5 years ago

If we want all of that flexibility, it may make more sense to just thoroughly document how to do the embeddings and then let people take full control over it. I'll try and get that done tonight, with examples.

ChrisRackauckas commented 5 years ago

@MikeInnes Let's talk about the backprop here. Do I need to restructure or something like that?

Adjoint:

using DiffEqFlux, Flux, OrdinaryDiffEq

x = Float32[2.; 0.]
tspan = (0.0f0,25.0f0)

ann = Chain(Dense(2,10,tanh), Dense(10,1))

p1 = DiffEqFlux.destructure(ann)
_p2 = Float32[-2.0,1.1]
p2 = param(_p2)
ps = Flux.Params([p1,p2])

function dudt_(du,u::TrackedArray,p,t)
    x, y = u
    du[1] = ann(u)[1]
    du[2] = p2[1]*y + p2[2]*x*y
end
function dudt_(du,u::AbstractArray,p,t)
    x, y = u
    du[1] = Flux.data(ann(u))[1]
    du[2] = _p2[1]*y + _p2[2]*x*y
end
prob = ODEProblem(dudt_,x,tspan,[Flux.data(p1);Flux.data(p2)])
diffeq_adjoint([Flux.data(p1);Flux.data(p2)],prob,Tsit5())

function predict_rd()
  Flux.Tracker.collect(diffeq_rd([Flux.data(p1);Flux.data(p2)],prob,Tsit5()))
end
loss_rd() = sum(abs2,x-1 for x in predict_rd())
loss_rd()

grads = Tracker.gradient(loss_rd, ps, nest=true)
grads[p1]

data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function ()
  display(loss_rd())
  #display(plot(solve(remake(prob,p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6)))
end

# Display the ODE with the current parameter values.
cb()

Flux.train!(loss_rd, ps, data, opt, cb = cb)

Reverse-mode Autodiff:

using DiffEqFlux, Flux, OrdinaryDiffEq

x = Float32[2.; 0.]
tspan = (0.0f0,25.0f0)

ann = Chain(Dense(2,10,tanh), Dense(10,1))

p1 = DiffEqFlux.destructure(ann)
p2 = param(Float32[-2.0,1.1])
ps = Flux.Params([p1,p2])

function dudt_(u::TrackedArray,p,t)
    x, y = u
    Flux.Tracker.collect([ann(u)[1],p2[1]*y + p2[2]*x*y])
end
function dudt_(u::AbstractArray,p,t)
    x, y = u
    [Flux.data(ann(u))[1],p2[1]*y + p2[2]*x*y]
end

prob = ODEProblem(dudt_,x,tspan,ps)
Flux.Tracker.collect(diffeq_rd(p1,prob,Tsit5()))

function predict_rd()
  Flux.Tracker.collect(diffeq_rd(p1,prob,Tsit5()))
end
loss_rd() = sum(abs2,x-1 for x in predict_rd())
loss_rd()

grads = Tracker.gradient(loss_rd, ps, nest=true)
grads[p2]

data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function ()
  display(loss_rd())
  #display(plot(solve(remake(prob,p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6)))
end

# Display the ODE with the current parameter values.
cb()

Flux.train!(loss_rd, ps, data, opt, cb = cb)

The gradients seem off on both.

MikeInnes commented 5 years ago

diffeq_rd should just work with this out of the box. Destructure/restructure are just hacks for neural_ode; with reverse mode they're not necessary and you can just pass nothing for the parameters.

You definitely don't want to use Flux.data to work around errors, because that's going to drop gradients.

If you want to use the adjoint method / neural ODE it's pretty straightforward to make OP's dudt a Flux layer. Something like

struct Foo
  ann::Chain
end

Flux.@treelike Foo

(f::Foo)(u) = [f.ann(u)[1], -2.0*u[2] + 1.1*u[1]*u[2]]

(not tested code but you get the picture.)
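
Then plugging it into the adjoint path would hypothetically just be (same untested caveat; assumes neural_ode is happy with a custom treelike layer like the Foo above):

using DiffEqFlux, Flux, OrdinaryDiffEq

foo = Foo(Chain(Dense(2,10,tanh), Dense(10,1)))
n_ode(x) = neural_ode(foo, x, (0.0f0,25.0f0), Tsit5(), saveat=0.1f0)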

ChrisRackauckas commented 5 years ago

diffeq_rd should just work with this out of the box. Destructure/restructure are just hacks for neural_ode; with reverse mode they're not necessary and you can just pass nothing for the parameters.

We should clean up the neural_ode_rd code then. It should just work, I agree.

(f::Foo)(u) = [f.ann(u)[1], -2.0*u[2] + 1.1*u[1]*u[2]]

I think you need to Flux.Tracker.collect(...) that back into a TrackedArray instead of an Array{TrackedReal}. Then that's what I have above. But something isn't turning out right with that (likely I need to drop the AbstractArray dispatch?)
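
For example, just to illustrate the distinction:

using Flux
v = param([1.0, 2.0])
a = [v[1]^2, v[2]^2]           # Vector of TrackedReal, built from tracked scalars
ta = Flux.Tracker.collect(a)   # a TrackedArray again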

ChrisRackauckas commented 5 years ago

Does Tracker.collect drop previous tape information? Is it then not a good way to do out-of-place here?

Here's where I am at:

using DiffEqFlux, Flux, OrdinaryDiffEq

x = Float32[2.; 0.]
tspan = (0.0f0,25.0f0)

ann = Chain(Dense(2,10,tanh), Dense(10,1))
p = param(Float32[-2.0,1.1])

function dudt_(u::TrackedArray,p,t)
    x, y = u
    Flux.Tracker.collect([ann(u)[1],p[1]*y + p[2]*x*y])
end
function dudt_(u::AbstractArray,p,t)
    x, y = u
    [Flux.data(ann(u))[1],p[1]*y + p[2]*x*y]
end

prob = ODEProblem(dudt_,x,tspan,p)
Flux.Tracker.collect(diffeq_rd(p,prob,Tsit5()))

function predict_rd()
  Flux.Tracker.collect(diffeq_rd(p,prob,Tsit5()))
end
loss_rd() = sum(abs2,x-1 for x in predict_rd())
loss_rd()

data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function ()
  display(loss_rd())
  #display(plot(solve(remake(prob,p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6)))
end

# Display the ODE with the current parameter values.
cb()

Flux.train!(loss_rd, nothing, data, opt, cb = cb)

ChrisRackauckas commented 5 years ago

I got it working for Tracker and adjoints. Here's the code. Added as a test.

using DiffEqFlux, Flux, OrdinaryDiffEq

x = Float32[0.8; 0.8]
tspan = (0.0f0,25.0f0)

ann = Chain(Dense(2,10,tanh), Dense(10,1))
p = param(Float32[-2.0,1.1])

function dudt_(u::TrackedArray,p,t)
    x, y = u
    Flux.Tracker.collect([ann(u)[1],p[1]*y + p[2]*x])
end
function dudt_(u::AbstractArray,p,t)
    x, y = u
    [Flux.data(ann(u)[1]),p[1]*y + p[2]*x]
end

prob = ODEProblem(dudt_,x,tspan,p)
diffeq_rd(p,prob,Tsit5())
_x = param(x)

function predict_rd()
  Flux.Tracker.collect(diffeq_rd(p,prob,Tsit5(),u0=_x))
end
loss_rd() = sum(abs2,x-1 for x in predict_rd())
loss_rd()

data = Iterators.repeated((), 10)
opt = ADAM(0.1)
cb = function ()
  display(loss_rd())
  #display(plot(solve(remake(prob,u0=Flux.data(_x),p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6)))
end

# Display the ODE with the current parameter values.
cb()

Flux.train!(loss_rd, params(ann,p,_x), data, opt, cb = cb)

## Partial Neural Adjoint

u0 = param(Float32[0.8; 0.8])
tspan = (0.0f0,25.0f0)

ann = Chain(Dense(2,10,tanh), Dense(10,1))

p1 = Flux.data(DiffEqFlux.destructure(ann))
p2 = Float32[-2.0,1.1]
p3 = param([p1;p2])
ps = Flux.params(p3,u0)

function dudt_(du,u,p,t)
    x, y = u
    du[1] = DiffEqFlux.restructure(ann,p[1:41])(u)[1]
    du[2] = p[end-1]*y + p[end]*x
end
prob = ODEProblem(dudt_,u0,tspan,p3)
diffeq_adjoint(p3,prob,Tsit5(),u0=u0,abstol=1e-8,reltol=1e-6)

function predict_adjoint()
  diffeq_adjoint(p3,prob,Tsit5(),u0=u0,saveat=0.0:0.1:25.0)
end
loss_adjoint() = sum(abs2,x-1 for x in predict_adjoint())
loss_adjoint()

data = Iterators.repeated((), 10)
opt = ADAM(0.1)
cb = function ()
  display(loss_adjoint())
  #display(plot(solve(remake(prob,p=Flux.data(p3),u0=Flux.data(u0)),Tsit5(),saveat=0.1),ylim=(0,6)))
end

# Display the ODE with the current parameter values.
cb()

Flux.train!(loss_adjoint, ps, data, opt, cb = cb)

Have fun!