SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

ODE_layer #155

gabrevaya closed this issue 4 years ago

gabrevaya commented 4 years ago

First of all, thanks a lot for this amazing package :)

I'm trying to implement a model that chains standard Flux layers with a differential equation layer. It would be like a NeuralODE in between regular NN layers, but instead of letting the NeuralODE approximate the DE, I want to give the DE some structure. I can't get it to work even in the simplest case. Here is some sample code:

using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, DiffEqSensitivity

abstract type NeuralDELayer <: Function end

struct ODE_Layer{M,P} <: NeuralDELayer
    model::M
    p::P

    function ODE_Layer(model,p)
        new{typeof(model),typeof(p)}(
            model,p)
    end
end

Flux.@functor ODE_Layer

function (n::ODE_Layer)(x::AbstractArray)
    tspan = (0.0f0,10.0f0)
    prob = ODEProblem{false}(n.model,x,tspan,n.p)
    last(concrete_solve(prob,Tsit5(),x,n.p,save_everystep=false,saveat=tspan[2],sensealg=TrackerAdjoint()))
end

struct NeuralStructODE{M,P,RE} <: NeuralDELayer
    model::M
    p::P
    re::RE

    function NeuralStructODE(model)
        p,re = Flux.destructure(model)
        new{typeof(model),typeof(p),typeof(re)}(
            model,p,re)
    end
end

Flux.@functor NeuralStructODE

function (n::NeuralStructODE)(x,p=n.p)
    return n.re(p)(x)
end

# Initially I had defined ODE_Layer with tspan, solver, args and kwargs, as in the
# NeuralODE struct, but then calling NeuralStructODE(u0) gave an error. After a lot of
# digging I found that Flux.destructure was messing up the arguments of ODE_Layer when
# calling n.re(p)(x). This is another problem I'd like to solve too, but first I'd like
# to find some working version of the code, even if it requires changing everything!

function lotka_volterra(u,p,t)
  x, y = u
  α, β, δ, γ = p
  dx = α*x - β*x*y
  dy = -δ*y + γ*x*y
  return [dx, dy]
end
u0 = Float32[1.0,1.0]
tspan = (0.0f0,10.0f0)
p = Float32[1.5,1.0,3.0,1.0]

ode_layer = ODE_Layer(lotka_volterra,p)

model = Chain(Dense(2,10,tanh),
              Dense(10,2,tanh),
              ode_layer,
              Dense(2,10,tanh),
              Dense(10,2,tanh))

model_struct = NeuralStructODE(model)

model_struct(u0)
model_struct(u0, model_struct.p)

function predict(p)
  model_struct(u0,p)
end

function loss(p)
    pred = predict(p)
    loss = sum(abs2, Float32[1.0,1.0] .- pred)
    loss,pred
end

loss(model_struct.p)

cb = function (p,l,pred)
  display(l)
  return false
end

cb(model_struct.p,loss(model_struct.p)...)

res1 = DiffEqFlux.sciml_train(loss, model_struct.p, ADAM(0.05), cb = cb, maxiters = 300)

I get the following error:

ERROR: Mutating arrays is not supported
Stacktrace:
 [1] (::Zygote.var"#1007#1008")(::Nothing) at /Users/ger/.julia/packages/Zygote/tJj2w/src/lib/array.jl:49
 [2] (::Zygote.var"#2682#back#1009"{Zygote.var"#1007#1008"})(::Nothing) at /Users/ger/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [3] getindex at ./array.jl:369 [inlined]
 [4] (::typeof(∂(getindex)))(::Array{Float32,1}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [5] (::typeof(∂(loss)))(::Tuple{Float32,Nothing}) at /Users/ger/Documents/ODE_layer/for_issue.jl:76
 [6] #18 at /Users/ger/.julia/dev/DiffEqFlux/src/train.jl:24 [inlined]
 [7] (::typeof(∂(λ)))(::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [8] (::Zygote.var"#46#47"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:101
 [9] gradient(::Function, ::Zygote.Params) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:47
 [10] macro expansion at /Users/ger/.julia/dev/DiffEqFlux/src/train.jl:23 [inlined]
 [11] macro expansion at /Users/ger/.julia/packages/Juno/oLB1d/src/progress.jl:119 [inlined]
 [12] #sciml_train#16(::Function, ::Int64, ::typeof(DiffEqFlux.sciml_train), ::typeof(loss), ::Array{Float32,1}, ::ADAM) at /Users/ger/.julia/dev/DiffEqFlux/src/train.jl:22
 [13] (::DiffEqFlux.var"#kw##sciml_train")(::NamedTuple{(:cb, :maxiters),Tuple{var"#15#16",Int64}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Array{Float32,1}, ::ADAM) at ./none:0
 [14] top-level scope at none:0

Using BFGS gives a similar error. I can't find which array mutation it is complaining about.

Using sensealg = ForwardDiffSensitivity() in the concrete_solve leads to a different error:

ERROR: MethodError: no method matching (::Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}})(::Float32)
Closest candidates are:
  Any(::AbstractArray{T,N} where N) where {T<:Union{Float32, Float64}, W<:(AbstractArray{T,N} where N)} at /Users/ger/.julia/packages/Flux/2i5P1/src/layers/basic.jl:113
  Any(::AbstractArray{#s106,N} where N where #s106<:AbstractFloat) where {T<:Union{Float32, Float64}, W<:(AbstractArray{T,N} where N)} at /Users/ger/.julia/packages/Flux/2i5P1/src/layers/basic.jl:116
  Any(::AbstractArray) at /Users/ger/.julia/packages/Flux/2i5P1/src/layers/basic.jl:101
Stacktrace:
 [1] macro expansion at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(::Zygote.Context, ::Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}}, ::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:7
 [3] applychain at /Users/ger/.julia/packages/Flux/2i5P1/src/layers/basic.jl:30 [inlined]
 [4] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}}}, ::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [5] applychain at /Users/ger/.julia/packages/Flux/2i5P1/src/layers/basic.jl:30 [inlined]
 [6] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{ODE_Layer{typeof(lotka_volterra),Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}}}, ::Array{Float32,1}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [7] applychain at /Users/ger/.julia/packages/Flux/2i5P1/src/layers/basic.jl:30 [inlined]
 [8] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},ODE_Layer{typeof(lotka_volterra),Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}}}, ::Array{Float32,1}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [9] applychain at /Users/ger/.julia/packages/Flux/2i5P1/src/layers/basic.jl:30 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},ODE_Layer{typeof(lotka_volterra),Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}}}, ::Array{Float32,1}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [11] Chain at /Users/ger/.julia/packages/Flux/2i5P1/src/layers/basic.jl:32 [inlined]
 [12] _pullback(::Zygote.Context, ::Chain{Tuple{Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},ODE_Layer{typeof(lotka_volterra),Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}}}}, ::Array{Float32,1}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [13] NeuralStructODE at /Users/ger/Documents/ODE_layer/for_issue.jl:43 [inlined]
 [14] _pullback(::Zygote.Context, ::NeuralStructODE{Chain{Tuple{Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},ODE_Layer{typeof(lotka_volterra),Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}}}},Array{Float32,1},Flux.var"#12#14"{Chain{Tuple{Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},ODE_Layer{typeof(lotka_volterra),Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}}}}}}, ::Array{Float32,1}, ::Array{Float32,1}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [15] predict at /Users/ger/Documents/ODE_layer/for_issue.jl:71 [inlined]
 [16] _pullback(::Zygote.Context, ::typeof(predict), ::Array{Float32,1}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [17] loss at /Users/ger/Documents/ODE_layer/for_issue.jl:75 [inlined]
 [18] _pullback(::Zygote.Context, ::typeof(loss), ::Array{Float32,1}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [19] #18 at /Users/ger/.julia/dev/DiffEqFlux/src/train.jl:24 [inlined]
 [20] _pullback(::Zygote.Context, ::DiffEqFlux.var"#18#24"{typeof(loss),Array{Float32,1}}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [21] pullback(::Function, ::Zygote.Params) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:96
 [22] gradient(::Function, ::Zygote.Params) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:46
 [23] macro expansion at /Users/ger/.julia/dev/DiffEqFlux/src/train.jl:23 [inlined]
 [24] macro expansion at /Users/ger/.julia/packages/Juno/oLB1d/src/progress.jl:119 [inlined]
 [25] #sciml_train#16(::Function, ::Int64, ::typeof(DiffEqFlux.sciml_train), ::typeof(loss), ::Array{Float32,1}, ::ADAM) at /Users/ger/.julia/dev/DiffEqFlux/src/train.jl:22
 [26] (::DiffEqFlux.var"#kw##sciml_train")(::NamedTuple{(:cb, :maxiters),Tuple{var"#13#14",Int64}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Array{Float32,1}, ::ADAM) at ./none:0
 [27] top-level scope at none:0

I would truly appreciate it if you could shed some light on this. What am I doing wrong? Is there a simpler way to implement a DE layer between regular Flux layers and make it all work?

I'm using Julia 1.3.1 and the following package versions:

[aae7a2af] DiffEqFlux v1.3.1
[41bf760c] DiffEqSensitivity v6.6.1
[0c46a032] DifferentialEquations v6.10.1
[587475ba] Flux v0.10.1
[429524aa] Optim v0.20.1
[1dea7af3] OrdinaryDiffEq v5.28.1
[9f7883ad] Tracker v0.2.6

ChrisRackauckas commented 4 years ago

Why ODE_Layer instead of just using NeuralODE here?

I think the issue might be restructure/destructure. If you use FastChain/FastDense, do you still have this issue?

gabrevaya commented 4 years ago

Thanks for your quick response! NeuralODE converts a Flux model or a FastChain model into a differential equation and then integrates it. As currently defined, NeuralODE doesn't let you provide a DE function or an ODEProblem directly for the integration. Maybe we could add that method to its definition. But since the name "NeuralODE" refers to an ODE approximated by a neural network, I thought it would be more appropriate to call this ODE_layer or DE_layer. Anyway, I still can't get it to work!

Following your suggestion, I tried using FastChain/FastDense and rewrote ODE_layer a bit to make it compatible:

using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, DiffEqSensitivity
import DiffEqFlux.initial_params
import DiffEqFlux.paramlength

abstract type FastLayer <: Function end

struct ODE_Layer{M,F} <: FastLayer
    model::M
    initial_params::F

    function ODE_Layer(model,p)
        initial_params() = vcat(p)
        new{typeof(model),typeof(initial_params)}(
            model,initial_params)
    end
end

Flux.@functor ODE_Layer

function (n::ODE_Layer)(x::AbstractArray,p)
    tspan = (0.0f0,10.0f0)
    prob = ODEProblem{false}(n.model,x,tspan,p)
    last(concrete_solve(prob,Tsit5(),x,p,save_everystep=false,saveat=tspan[2]))
end

paramlength(n::ODE_Layer) = length(n.initial_params())
initial_params(n::ODE_Layer) = n.initial_params()

function lotka_volterra(u,p,t)
  x, y = u
  α, β, δ, γ = p
  dx = α*x - β*x*y
  dy = -δ*y + γ*x*y
  return [dx, dy]
end
u0 = Float32[1.0,1.0]
p = Float32[1.5,1.0,3.0,1.0]
ode_layer = ODE_Layer(lotka_volterra,p)

model = FastChain(FastDense(2,10,tanh),
              FastDense(10,2),
              ode_layer,
              FastDense(2,10,tanh),
              FastDense(10,2))

function loss(p)
  pred = model(u0,p)
  loss = sum(abs2, Float32[1.0,0.0] .- pred)
  loss,pred
end

loss(initial_params(model)) # OK

cb = function (p,l,pred)
  display(l)
  return false
end

res1 = DiffEqFlux.sciml_train(loss, initial_params(model), ADAM(0.05), cb = cb, maxiters = 300)

But I still get the Mutating arrays error:

ERROR: Mutating arrays is not supported
Stacktrace:
 [1] (::Zygote.var"#1007#1008")(::Nothing) at /Users/ger/.julia/packages/Zygote/tJj2w/src/lib/array.jl:49
 [2] (::Zygote.var"#2682#back#1009"{Zygote.var"#1007#1008"})(::Nothing) at /Users/ger/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [3] getindex at ./array.jl:369 [inlined]
 [4] (::typeof(∂(getindex)))(::Array{Float32,1}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [5] (::typeof(∂(loss)))(::Tuple{Float32,Nothing}) at /Users/ger/Documents/ODE_layer/for_issue3.jl:48
 [6] #18 at /Users/ger/.julia/packages/DiffEqFlux/4NwUK/src/train.jl:24 [inlined]
 [7] (::typeof(∂(λ)))(::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [8] (::Zygote.var"#46#47"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:101
 [9] gradient(::Function, ::Zygote.Params) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:47
 [10] macro expansion at /Users/ger/.julia/packages/DiffEqFlux/4NwUK/src/train.jl:23 [inlined]
 [11] macro expansion at /Users/ger/.julia/packages/Juno/oLB1d/src/progress.jl:119 [inlined]
 [12] #sciml_train#16(::Function, ::Int64, ::typeof(DiffEqFlux.sciml_train), ::typeof(loss), ::Array{Float32,1}, ::ADAM) at /Users/ger/.julia/packages/DiffEqFlux/4NwUK/src/train.jl:22
 [13] (::DiffEqFlux.var"#kw##sciml_train")(::NamedTuple{(:cb, :maxiters),Tuple{var"#17#18",Int64}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Array{Float32,1}, ::ADAM) at ./none:0
 [14] top-level scope at none:0

I noticed that if I use only FastDense layers

model = FastChain(FastDense(2,10,tanh),
              FastDense(10,2),
              FastDense(2,10,tanh),
              FastDense(10,2))

I get exactly the same error.

ChrisRackauckas commented 4 years ago

Aha, that last little tidbit led to the MWE:

using Zygote
f(p) = sum(Float32[1.0,0.0] - p)
Zygote.gradient(f,Float32[1.0,1.0])

So the issue is just the literal Float32[1.0,0.0] in the loss function. Write your loss as:

my_data = Float32[1.0,0.0]
function loss(p)
  pred = model(u0,p)
  loss = sum(abs2, my_data .- pred)
  loss,pred
end

and you should be good. Looks like another case of https://github.com/FluxML/Zygote.jl/issues/513, so we'll get this fixed up in Zygote!
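Why would a literal cause a "Mutating arrays" error at all? The getindex frame (./array.jl:369) in the stack traces above is the hint: in Julia, a typed array literal such as Float32[1.0,0.0] is syntax for getindex(Float32, 1.0, 0.0), and Base implements that by allocating a vector and filling it element by element. That in-place fill is presumably what the Zygote version in this thread refused to differentiate. The sketch below assumes a recent Zygote, in which the original error may no longer reproduce; it only illustrates the lowering and the hoisted-constant workaround, and the names target and g are illustrative.

```julia
using Zygote

# A typed array literal is syntax for getindex(ElementType, values...),
# which builds the array by writing into it element by element.
@assert Float32[1.0, 0.0] == getindex(Float32, 1.0, 0.0)

# Workaround from this thread: construct the data once, outside the
# function Zygote differentiates, and only read it inside the loss.
const target = Float32[1.0, 0.0]
g(p) = sum(abs2, target .- p)

# d/dp sum((target - p).^2) = 2 .* (p .- target)
grad = Zygote.gradient(g, Float32[1.0, 1.0])[1]
```

With p = [1, 1] and target = [1, 0], the analytic gradient 2(p - target) is [0, 2], which is what the last line should return.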

gabrevaya commented 4 years ago

Thanks a lot Chris!! :)

Actually, it was still throwing an error when the model contained the ode_layer, but I realized the problem was applying last(...) to the concrete_solve(...) output. For some reason, Zygote doesn't seem to like the function last. Using Array(...)[:,end] instead solved the issue!
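A sketch of that workaround on the Lotka-Volterra example from this thread. It is written against current SciML packages rather than the 2020 versions listed above: concrete_solve has since been superseded by solve, and SciMLSensitivity (the successor of DiffEqSensitivity) provides the adjoint when it is loaded. The function name final_state_sum and the constant p0 are illustrative.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Out-of-place Lotka-Volterra right-hand side, as in the original example.
function lotka_volterra(u, p, t)
    x, y = u
    α, β, δ, γ = p
    return [α*x - β*x*y, -δ*y + γ*x*y]
end

const u0 = Float32[1.0, 1.0]
const tspan = (0.0f0, 10.0f0)
const p0 = Float32[1.5, 1.0, 3.0, 1.0]

function final_state_sum(p)
    prob = ODEProblem(lotka_volterra, u0, tspan, p)
    sol = solve(prob, Tsit5(); saveat = tspan[2])
    # Array(sol) is a (states × saved time points) matrix, so [:, end]
    # selects the state at the final saved time instead of using last(sol).
    return sum(Array(sol)[:, end])
end

grad = Zygote.gradient(final_state_sum, p0)[1]
```

The same Array(sol)[:, end] expression could replace last(...) inside the ODE_Layer call method above.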

ChrisRackauckas commented 4 years ago

Can you try and make an MWE for that?

using Zygote
f(p) = last(p)
Zygote.gradient(f,Float32[1.0,1.0])

This works, so I can't reproduce the issue.

gabrevaya commented 4 years ago

I think the problem is not with last alone but with using last together with saveat in concrete_solve:

using Zygote, OrdinaryDiffEq

function lotka_volterra(u,p,t)
  x, y = u
  α, β, δ, γ = p
  dx = α*x - β*x*y
  dy = -δ*y + γ*x*y
  return [dx, dy]
end

u0 = Float32[1.0,1.0]
tspan = (0.0f0,10.0f0)
p = Float32[1.5,1.0,3.0,1.0]

function f(p)
  prob = ODEProblem(lotka_volterra,u0,tspan,p)
  sum(last(concrete_solve(prob,Tsit5(),u0,p,saveat=tspan[2])))
end

Zygote.gradient(f,p)

This breaks, but the following works:

function f(p)
  prob = ODEProblem(lotka_volterra,u0,tspan,p)
  sum(last(concrete_solve(prob,Tsit5(),u0,p)))
end

Zygote.gradient(f,p)

even though

prob = ODEProblem(lotka_volterra,u0,tspan,p)
last(concrete_solve(prob,Tsit5(),u0,p,saveat=tspan[2])) == last(concrete_solve(prob,Tsit5(),u0,p)) # true

ChrisRackauckas commented 4 years ago

Turned it into an MWE here: https://github.com/JuliaDiffEq/DiffEqFlux.jl/issues/160. Thanks!

gabrevaya commented 4 years ago

Thanks to you!