Why ODE_Layer instead of just using NeuralODE here?
I think the issue might be restructure/destructure. If you use FastChain/FastDense, do you still have this issue?
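(By restructure/destructure I mean the flatten-and-rebuild step used for plain Flux models, roughly the sketch below, assuming Flux.destructure; FastChain/FastDense avoid it by working with a flat parameter vector directly.)
using Flux
m = Chain(Dense(2,10,tanh), Dense(10,2))
p, re = Flux.destructure(m)   # p: flat parameter vector, re: rebuilds the model from p
re(p)(rand(Float32, 2))       # restructure the model and apply it to an input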
Thanks for your quick response! NeuralODE converts a Flux model or a FastChain model into a differential equation and then integrates it. As currently defined, NeuralODE doesn't allow providing a diff eq function or an ODEProblem directly for use in the integration. Maybe we could add that method to its definition. But since the name "NeuralODE" refers to an ODE approximated by a neural network, I thought it might be more appropriate to call it ODE_Layer or DE_Layer. Anyway, I still can't get it to work!
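For reference, the usual NeuralODE usage I'm contrasting with looks roughly like this (a sketch based on the docs, not my actual code):
using DiffEqFlux, OrdinaryDiffEq
nn = FastChain(FastDense(2,10,tanh), FastDense(10,2))
node = NeuralODE(nn, (0.0f0,10.0f0), Tsit5(), save_everystep=false, saveat=10.0f0)
node(Float32[1.0,1.0])   # integrates the neural network as the ODE right-hand side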
Following your suggestion, I tried using FastChain/FastDense and rewrote ODE_Layer a bit to make it compatible:
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, DiffEqSensitivity
import DiffEqFlux.initial_params
import DiffEqFlux.paramlength

abstract type FastLayer <: Function end

# "fast" layer wrapping an arbitrary ODE right-hand side and its parameters
struct ODE_Layer{M,F} <: FastLayer
    model::M
    initial_params::F
    function ODE_Layer(model,p)
        initial_params() = vcat(p)
        new{typeof(model),typeof(initial_params)}(model,initial_params)
    end
end
Flux.@functor ODE_Layer

# calling the layer integrates the ODE from x over a fixed tspan and returns the final state
function (n::ODE_Layer)(x::AbstractArray,p)
    tspan = (0.0f0,10.0f0)
    prob = ODEProblem{false}(n.model,x,tspan,p)
    last(concrete_solve(prob,Tsit5(),x,p,save_everystep=false,saveat=tspan[2]))
end

paramlength(n::ODE_Layer) = length(n.initial_params())
initial_params(n::ODE_Layer) = n.initial_params()
# Lotka-Volterra as the structured ODE inside the layer
function lotka_volterra(u,p,t)
    x, y = u
    α, β, δ, γ = p
    dx = α*x - β*x*y
    dy = -δ*y + γ*x*y
    return [dx, dy]
end

u0 = Float32[1.0,1.0]
p = Float32[1.5,1.0,3.0,1.0]
ode_layer = ODE_Layer(lotka_volterra,p)

# regular FastDense layers chained around the ODE layer
model = FastChain(FastDense(2,10,tanh),
                  FastDense(10,2),
                  ode_layer,
                  FastDense(2,10,tanh),
                  FastDense(10,2))

function loss(p)
    pred = model(u0,p)
    loss = sum(abs2, Float32[1.0,0.0] .- pred)
    loss,pred
end

loss(initial_params(model)) # OK

cb = function (p,l,pred)
    display(l)
    return false
end

res1 = DiffEqFlux.sciml_train(loss, initial_params(model), ADAM(0.05), cb = cb, maxiters = 300)
But I still get the Mutating arrays error:
ERROR: Mutating arrays is not supported
Stacktrace:
[1] (::Zygote.var"#1007#1008")(::Nothing) at /Users/ger/.julia/packages/Zygote/tJj2w/src/lib/array.jl:49
[2] (::Zygote.var"#2682#back#1009"{Zygote.var"#1007#1008"})(::Nothing) at /Users/ger/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[3] getindex at ./array.jl:369 [inlined]
[4] (::typeof(∂(getindex)))(::Array{Float32,1}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
[5] (::typeof(∂(loss)))(::Tuple{Float32,Nothing}) at /Users/ger/Documents/ODE_layer/for_issue3.jl:48
[6] #18 at /Users/ger/.julia/packages/DiffEqFlux/4NwUK/src/train.jl:24 [inlined]
[7] (::typeof(∂(λ)))(::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
[8] (::Zygote.var"#46#47"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:101
[9] gradient(::Function, ::Zygote.Params) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:47
[10] macro expansion at /Users/ger/.julia/packages/DiffEqFlux/4NwUK/src/train.jl:23 [inlined]
[11] macro expansion at /Users/ger/.julia/packages/Juno/oLB1d/src/progress.jl:119 [inlined]
[12] #sciml_train#16(::Function, ::Int64, ::typeof(DiffEqFlux.sciml_train), ::typeof(loss), ::Array{Float32,1}, ::ADAM) at /Users/ger/.julia/packages/DiffEqFlux/4NwUK/src/train.jl:22
[13] (::DiffEqFlux.var"#kw##sciml_train")(::NamedTuple{(:cb, :maxiters),Tuple{var"#17#18",Int64}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Array{Float32,1}, ::ADAM) at ./none:0
[14] top-level scope at none:0
I noticed that if I use only FastDense layers,
model = FastChain(FastDense(2,10,tanh),
                  FastDense(10,2),
                  FastDense(2,10,tanh),
                  FastDense(10,2))
I get exactly the same error.
Aha, that last little tidbit led to the MWE:
using Zygote
f(p) = sum(Float32[1.0,0.0] - p)
Zygote.gradient(f,Float32[1.0,1.0])
So the issue is just the literal Float32[1.0,0.0] in the loss function. Write your loss as:
my_data = Float32[1.0,0.0]
function loss(p)
    pred = model(u0,p)
    loss = sum(abs2, my_data .- pred)
    loss,pred
end
and you should be good. Looks like another case of https://github.com/FluxML/Zygote.jl/issues/513, so we'll get this fixed up in Zygote!
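For the record, here is the same check with the array hoisted out of the function being differentiated, which sidesteps the failing adjoint (a quick sanity check, not the eventual Zygote fix):
using Zygote
my_data = Float32[1.0, 0.0]      # constructed outside the differentiated code
g(p) = sum(my_data - p)
Zygote.gradient(g, Float32[1.0, 1.0])   # no mutating-arrays error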
Thanks a lot Chris!! :)
Actually, it was still throwing an error when the model contained the ode_layer, but I realized the problem was caused by applying last(...) to the concrete_solve(...) output. For some reason, Zygote doesn't seem to like the last function. Using Array(...)[:,end] solved the issue!
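For completeness, the layer call now looks roughly like this (same tspan and solver as before; a sketch of my change, not a polished version):
function (n::ODE_Layer)(x::AbstractArray, p)
    tspan = (0.0f0, 10.0f0)
    prob = ODEProblem{false}(n.model, x, tspan, p)
    sol = concrete_solve(prob, Tsit5(), x, p, save_everystep=false, saveat=tspan[2])
    Array(sol)[:, end]   # take the final state without calling `last`
end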
Can you try and make an MWE for that?
using Zygote
f(p) = last(p)
Zygote.gradient(f,Float32[1.0,1.0])
this works, so I can't recreate this issue.
I think that the problem is not with last alone but with the usage of last together with saveat in concrete_solve:
using Zygote, OrdinaryDiffEq

function lotka_volterra(u,p,t)
    x, y = u
    α, β, δ, γ = p
    dx = α*x - β*x*y
    dy = -δ*y + γ*x*y
    return [dx, dy]
end

u0 = Float32[1.0,1.0]
tspan = (0.0f0,10.0f0)
p = Float32[1.5,1.0,3.0,1.0]

function f(p)
    prob = ODEProblem(lotka_volterra,u0,tspan,p)
    sum(last(concrete_solve(prob,Tsit5(),u0,p,saveat=tspan[2])))
end
Zygote.gradient(f,p)
This breaks, but the following works:
function f(p)
    prob = ODEProblem(lotka_volterra,u0,tspan,p)
    sum(last(concrete_solve(prob,Tsit5(),u0,p)))
end
Zygote.gradient(f,p)
even though
prob = ODEProblem(lotka_volterra,u0,tspan,p)
last(concrete_solve(prob,Tsit5(),u0,p,saveat=tspan[2])) == last(concrete_solve(prob,Tsit5(),u0,p)) # true
Turned it into an MWE here: https://github.com/JuliaDiffEq/DiffEqFlux.jl/issues/160. Thanks!
Thanks to you!
First of all, thanks a lot for this amazing package :)
I'm trying to implement a model with ordinary Flux layers chained with a differential equation layer. It would be like a NeuralODE in between regular NN layers, but instead of letting the NeuralODE approximate the DE, I want to give it some structure. I can't get it to work even in the simplest case. This is a sample of the code:
I get the following error:
Using BFGS gives a similar error. I can't figure out which mutating array it is complaining about.
Using sensealg = ForwardDiffSensitivity() in the concrete_solve leads to a different error.
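For clarity, the sensealg keyword was passed inside the layer's concrete_solve call roughly like this (names as in the ODE_Layer shown earlier in the thread; the original code may have differed slightly):
sol = concrete_solve(prob, Tsit5(), x, p,
                     sensealg = ForwardDiffSensitivity(),
                     save_everystep = false, saveat = tspan[2])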
I would truly appreciate it if you could shed some light on this. What am I doing wrong? Do you have in mind another, simpler way of implementing a DE layer in between regular Flux layers that makes it all work?
I'm using Julia 1.3.1 and the following package versions:
[aae7a2af] DiffEqFlux v1.3.1
[41bf760c] DiffEqSensitivity v6.6.1
[0c46a032] DifferentialEquations v6.10.1
[587475ba] Flux v0.10.1
[429524aa] Optim v0.20.1
[1dea7af3] OrdinaryDiffEq v5.28.1
[9f7883ad] Tracker v0.2.6