Closed scheidan closed 5 years ago
We didn't add a high level function that allows this kind of mixing only because we didn't think of doing it. It wouldn't be difficult though, so if this is a feature request we could put it right in.
This would be very interesting! Sometimes you have a good understanding of a part of a system (e.g. a mass balance), while for other parts you only suspect some influence but can't express it with a parametric equation (which is an ideal use case for an ANN).
Not sure if I can help with the implementation, but I'd be very happy to test it.
I'm very interested by this feature too.
Sometimes you have a good understanding of a part of a system (e.g. a mass balance), while for other parts you only suspect some influence but can't express it with a parametric equation (which is an ideal use case for an ANN).
I think one big use for this is actually differential-algebraic equations. You might want to hardcode the expressions for the conservation laws and let it learn the dynamics.
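As a minimal sketch of that idea (hypothetical, not from this thread): let the network output all but one rate, and hardcode the last rate so the conservation law holds exactly. The network is stubbed by a closure here so the snippet stands alone; in practice it would be something like Chain(Dense(3,10,tanh), Dense(10,2)).

```julia
# Hypothetical sketch: a 3-species system where the total u[1]+u[2]+u[3]
# is conserved. `ann` stands in for a trained network; a closure keeps
# the sketch self-contained.
ann = u -> [0.1*u[1], -0.2*u[2]]
function dudt_(du, u, p, t)
    rates = ann(u)
    du[1] = rates[1]
    du[2] = rates[2]
    du[3] = -rates[1] - rates[2]   # hardcoded so d(u[1]+u[2]+u[3])/dt = 0
end
```

Any drift the network might introduce is ruled out by construction: summing the right-hand side gives zero regardless of what `ann` learns.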
Yes, you're right, mass balance was not the best example.
What I had in mind is to replace part of a "physically based" equation with an ANN. For example, a scientist may have a predator-prey model
drabbit/dt = a*rabbit - b*rabbit*fox
dfox/dt = c*rabbit*fox - d*fox
that she really likes. However, she suspects that the "hunting part" should somehow include time as well. She could then fit a model like this:
drabbit/dt = a*rabbit - ANN(rabbit,fox,t)
dfox/dt = c*rabbit*fox - d*fox
and analyze the trained ANN to construct a new scientific hypothesis.
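That hypothesis can be written down directly as a hybrid right-hand side. The sketch below is illustrative only: the network is stubbed by a closure so the snippet is self-contained, and the parameter values are made up; in practice `ann` would be a small Flux chain taking (rabbit, fox, t) as input.

```julia
# Hybrid predator-prey sketch: known terms stay parametric, the uncertain
# hunting term becomes ANN(rabbit, fox, t). `ann` stands in for something
# like Chain(Dense(3,10,tanh), Dense(10,1)).
ann = x -> [0.05 * x[1] * x[2]]
p = [1.5, 1.0, 3.0]   # a, c, d (made-up values)
function dudt_(du, u, p, t)
    rabbit, fox = u
    du[1] = p[1]*rabbit - ann([rabbit, fox, t])[1]   # a*rabbit - ANN(rabbit,fox,t)
    du[2] = p[2]*rabbit*fox - p[3]*fox               # c*rabbit*fox - d*fox
end
```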
If we want all of that flexibility, it may make more sense to just thoroughly document how to do the embeddings and then let people take full control over it. I'll try and get that done tonight, with examples.
@MikeInnes Let's talk about the backprop here. Do I need to restructure or something like that?
Adjoint:
using DiffEqFlux, Flux, OrdinaryDiffEq
x = Float32[2.; 0.]
tspan = (0.0f0,25.0f0)
ann = Chain(Dense(2,10,tanh), Dense(10,1))
p1 = DiffEqFlux.destructure(ann)
_p2 = Float32[-2.0,1.1]
p2 = param(_p2)
ps = Flux.Params([p1,p2])
function dudt_(du,u::TrackedArray,p,t)
    x, y = u
    du[1] = ann(u)[1]
    du[2] = p2[1]*y + p2[2]*x*y
end
function dudt_(du,u::AbstractArray,p,t)
    x, y = u
    du[1] = Flux.data(ann(u))[1]
    du[2] = _p2[1]*y + _p2[2]*x*y
end
prob = ODEProblem(dudt_,x,tspan,[Flux.data(p1);Flux.data(p2)])
diffeq_adjoint([Flux.data(p1);Flux.data(p2)],prob,Tsit5())
function predict_rd()
    Flux.Tracker.collect(diffeq_rd([Flux.data(p1);Flux.data(p2)],prob,Tsit5()))
end
loss_rd() = sum(abs2,x-1 for x in predict_rd())
loss_rd()
grads = Tracker.gradient(loss_rd, ps, nest=true)
grads[p1]
data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function ()
    display(loss_rd())
    #display(plot(solve(remake(prob,p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6)))
end
# Display the ODE with the current parameter values.
cb()
Flux.train!(loss_rd, ps, data, opt, cb = cb)
Reverse-mode Autodiff:
using DiffEqFlux, Flux, OrdinaryDiffEq
x = Float32[2.; 0.]
tspan = (0.0f0,25.0f0)
ann = Chain(Dense(2,10,tanh), Dense(10,1))
p1 = DiffEqFlux.destructure(ann)
p2 = param(Float32[-2.0,1.1])
ps = Flux.Params([p1,p2])
function dudt_(u::TrackedArray,p,t)
    x, y = u
    Flux.Tracker.collect([ann(u)[1], p2[1]*y + p2[2]*x*y])
end
function dudt_(u::AbstractArray,p,t)
    x, y = u
    [Flux.data(ann(u)[1]), p2[1]*y + p2[2]*x*y]
end
prob = ODEProblem(dudt_,x,tspan,ps)
Flux.Tracker.collect(diffeq_rd(p1,prob,Tsit5()))
function predict_rd()
    Flux.Tracker.collect(diffeq_rd(p1,prob,Tsit5()))
end
loss_rd() = sum(abs2,x-1 for x in predict_rd())
loss_rd()
grads = Tracker.gradient(loss_rd, ps, nest=true)
grads[p2]
data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function ()
    display(loss_rd())
    #display(plot(solve(remake(prob,p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6)))
end
# Display the ODE with the current parameter values.
cb()
Flux.train!(loss_rd, ps, data, opt, cb = cb)
The gradients seem off on both.
diffeq_rd should just work with this out of the box. Destructure/restructure are just hacks for neural_ode; with reverse mode it's not necessary and you can just pass nothing for the parameters.
You definitely don't want to use Flux.data to work around errors, because that's going to drop gradients.
If you want to use the adjoint method / neural ODE it's pretty straightforward to make the OP's dudt a Flux layer. Something like
struct Foo
    ann::Chain
end
Flux.@treelike Foo
(f::Foo)(x) = [f.ann(x)[1], -2.0*x[2] + 1.1*x[1]*x[2]]
(not tested code but you get the picture.)
diffeq_rd should just work with this out of the box. Destructure/restructure are just hacks for neural_ode; with reverse mode it's not necessary and you can just pass nothing for the parameters.
We should clean up the neural_ode_rd code then. It should just work, I agree.
(f::Foo)(x) = [f.ann(x)[1], -2.0x[2] + 1.1x[1]*x[2]]
I think you need to Flux.Tracker.collect(...) that back into a TrackedArray instead of an Array{TrackedReal}. Then that's what I have above, but something isn't turning out right with it (likely I need to drop the AbstractArray dispatch?).
Does Tracker.collect drop previous tape information? If so, is it not a good way to do the out-of-place form here?
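Putting those points together, a corrected form of the earlier Foo sketch might look like this (untested in the thread's setting; the network field is stubbed with a closure so the snippet stands alone, and in the tracked path the output would additionally be wrapped in Flux.Tracker.collect as suggested above):

```julia
# Corrected layer sketch: use the struct's own field (f.ann) and its
# argument (x) instead of the free names from the earlier snippet.
struct Foo
    ann   # e.g. Chain(Dense(2,10,tanh), Dense(10,1)); left untyped for the stub
end
(f::Foo)(x) = [f.ann(x)[1], -2.0*x[2] + 1.1*x[1]*x[2]]
# With Flux one would also declare Flux.@treelike Foo and wrap the result
# in Flux.Tracker.collect so it stays a TrackedArray, not Array{TrackedReal}.
```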
Here's where I am at:
using DiffEqFlux, Flux, OrdinaryDiffEq
x = Float32[2.; 0.]
tspan = (0.0f0,25.0f0)
ann = Chain(Dense(2,10,tanh), Dense(10,1))
p = param(Float32[-2.0,1.1])
function dudt_(u::TrackedArray,p,t)
    x, y = u
    Flux.Tracker.collect([ann(u)[1], p[1]*y + p[2]*x*y])
end
function dudt_(u::AbstractArray,p,t)
    x, y = u
    [Flux.data(ann(u)[1]), p[1]*y + p[2]*x*y]
end
prob = ODEProblem(dudt_,x,tspan,p)
Flux.Tracker.collect(diffeq_rd(p,prob,Tsit5()))
function predict_rd()
    Flux.Tracker.collect(diffeq_rd(p,prob,Tsit5()))
end
loss_rd() = sum(abs2,x-1 for x in predict_rd())
loss_rd()
data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function ()
    display(loss_rd())
    #display(plot(solve(remake(prob,p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6)))
end
# Display the ODE with the current parameter values.
cb()
Flux.train!(loss_rd, nothing, data, opt, cb = cb)
I got it working for Tracker and adjoints. Here's the code. Added as a test.
using DiffEqFlux, Flux, OrdinaryDiffEq
x = Float32[0.8; 0.8]
tspan = (0.0f0,25.0f0)
ann = Chain(Dense(2,10,tanh), Dense(10,1))
p = param(Float32[-2.0,1.1])
function dudt_(u::TrackedArray,p,t)
    x, y = u
    Flux.Tracker.collect([ann(u)[1], p[1]*y + p[2]*x])
end
function dudt_(u::AbstractArray,p,t)
    x, y = u
    [Flux.data(ann(u)[1]), p[1]*y + p[2]*x]
end
prob = ODEProblem(dudt_,x,tspan,p)
diffeq_rd(p,prob,Tsit5())
_x = param(x)
function predict_rd()
    Flux.Tracker.collect(diffeq_rd(p,prob,Tsit5(),u0=_x))
end
loss_rd() = sum(abs2,x-1 for x in predict_rd())
loss_rd()
data = Iterators.repeated((), 10)
opt = ADAM(0.1)
cb = function ()
    display(loss_rd())
    #display(plot(solve(remake(prob,u0=Flux.data(_x),p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6)))
end
# Display the ODE with the current parameter values.
cb()
Flux.train!(loss_rd, params(ann,p,_x), data, opt, cb = cb)
## Partial Neural Adjoint
u0 = param(Float32[0.8; 0.8])
tspan = (0.0f0,25.0f0)
ann = Chain(Dense(2,10,tanh), Dense(10,1))
p1 = Flux.data(DiffEqFlux.destructure(ann))
p2 = Float32[-2.0,1.1]
p3 = param([p1;p2])
ps = Flux.params(p3,u0)
function dudt_(du,u,p,t)
    x, y = u
    du[1] = DiffEqFlux.restructure(ann,p[1:41])(u)[1]
    du[2] = p[end-1]*y + p[end]*x
end
prob = ODEProblem(dudt_,u0,tspan,p3)
diffeq_adjoint(p3,prob,Tsit5(),u0=u0,abstol=1e-8,reltol=1e-6)
function predict_adjoint()
    diffeq_adjoint(p3,prob,Tsit5(),u0=u0,saveat=0.0:0.1:25.0)
end
loss_adjoint() = sum(abs2,x-1 for x in predict_adjoint())
loss_adjoint()
data = Iterators.repeated((), 10)
opt = ADAM(0.1)
cb = function ()
    display(loss_adjoint())
    #display(plot(solve(remake(prob,p=Flux.data(p3),u0=Flux.data(u0)),Tsit5(),saveat=0.1),ylim=(0,6)))
end
# Display the ODE with the current parameter values.
cb()
Flux.train!(loss_adjoint, ps, data, opt, cb = cb)
Have fun!
Thanks for this beautiful package!
From the Readme and the blog I understood i) how to use a "normal" ODE as a Flux layer (diffeq_rd), and ii) how to use a Flux layer to define an ODE (neural_ode). I was wondering if an API already exists to mix both approaches, for example to define something like this: