Doesn't a GRU have state though, so the model wouldn't be well-defined?
I think @avik-pal and @DhairyaLGandhi have mentioned something about destructure giving different outputs when a layer has state, which is a bit weird. Could one of you give some input on that? I think that would be something to fix up on the Flux side; even if it would be a breaking change, making the outputs of that function uniform and documented would fix issues like this.
If it helps, learning-long-term-irregular-ts shows (starting on line 566) code for the GRU-ODE written in Python.
Note that this method is only going to be compatible with adaptive=false, because otherwise the state makes the ODE undefined. I think all you need is to turn off adaptivity and use whatever that different destructure is.
@avik-pal and @DhairyaLGandhi, note that the solve function runs properly and produces the anticipated output; the DimensionMismatch error occurs later, when gradients are taken. Also, the same error occurs when using adaptive=false, i.e. with OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, adaptive=false, dt=0.5) instead of the solve() in the code above.
The destructure issue Chris mentioned above should not lead to a dimension mismatch error. It just makes the GRU work without any recurrence, as the state is overwritten every time we do re(p). (@DhairyaLGandhi, do you know how to fix this?)
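To illustrate the point, here is a minimal sketch (the model and input are made up; the equality holds under the destructure behavior described in this thread):

import Flux
m = Flux.GRU(2, 2)
p, re = Flux.destructure(m)
x = rand(Float32, 2)
# Each call to re(p) rebuilds the model, resetting the hidden state, so no
# recurrence is carried across calls and both evaluations give the same output:
re(p)(x) == re(p)(x)   # true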
The exact source of the error you encounter seems to be the sensitivity algorithm. A quick fix would be:
res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, sensealg=InterpolatingAdjoint(autojacvec=false)))
We don't close over a number of arguments in destructure, which may be necessary for our restructure case as well. Adding those back to our cache, which can be passed around to the restructure, could do it.
function destructure(m; cache = IdDict())
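# Collect array leaves into xs; stash non-array leaves in cache so _restructure can restore them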
xs = Zygote.Buffer([])
fmap(m) do x
if x isa AbstractArray
push!(xs, x)
else
cache[x] = x
end
return x
end
return vcat(vec.(copy(xs))...), p -> _restructure(m, p, cache = cache)
end
function _restructure(m, xs; cache = IdDict())
i = 0
fmap(m) do x
x isa AbstractArray || return cache[x]
x = reshape(xs[i.+(1:length(x))], size(x))
i += length(x)
return x
end
end
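For reference, the intended usage would be something like this (model here stands for any Flux model):

p, re = destructure(model)   # flat vector of the array parameters, plus a rebuild closure
model2 = re(p)               # array leaves come from p; non-array leaves are restored from the cache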
This is untested currently, @avik-pal would something like this solve the specific issue you're talking about?
Thanks. I can verify that the following works with Flux.GRU. I used autojacvec=true rather than false, and it seems to run a bit faster that way.
module TestDiffeq3bb
using Revise
using Infiltrator
using Formatting
import DiffEqFlux
import OrdinaryDiffEq
import DiffEqSensitivity
import Flux
import Optim
import Plots
import Zygote
import Functors
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
iter = 0
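# True dynamics used to generate the training data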
function trueODEfunc(du, u, p, t)
true_A = Float32[-0.1 2.0; -2.0 -0.1]
du .= ((u.^3)'true_A)'
end
prob_trueode = OrdinaryDiffEq.ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(OrdinaryDiffEq.solve(prob_trueode, OrdinaryDiffEq.Tsit5(), saveat = tsteps))
dudt2 = Flux.Chain(
x -> x.^3,
Flux.Dense(2, 20, tanh),
Flux.GRU(20, 20),
Flux.Dense(20, 2),
)
function destructure(m; cache = IdDict())
xs = Zygote.Buffer([])
Functors.fmap(m) do x
if x isa AbstractArray
push!(xs, x)
else
cache[x] = x
end
return x
end
return vcat(vec.(copy(xs))...), p -> _restructure(m, p, cache = cache)
end
function _restructure(m, xs; cache = IdDict())
i = 0
Functors.fmap(m) do x
x isa AbstractArray || return cache[x]
x = reshape(xs[i .+ (1:length(x))], size(x))
i += length(x)
return x
end
end
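# Flatten the model into a parameter vector p; re(p) rebuilds the model from such a vector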
p, re = destructure(dudt2)
function neural_ode_f(u, p, t)
return re(p)(u)
end
prob = OrdinaryDiffEq.ODEProblem(neural_ode_f, u0, tspan, p)
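# Solve with the current parameter vector; InterpolatingAdjoint(autojacvec=true)
# selects the sensitivity method used when gradients are taken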
function predict_neuralode(p)
tmp_prob = OrdinaryDiffEq.remake(prob,p=p)
res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, sensealg=DiffEqSensitivity.InterpolatingAdjoint(autojacvec=true)))
return res
end
function loss_neuralode(p)
pred = predict_neuralode(p)
loss = sum(abs2, ode_data .- pred)
return loss, pred
end
callback = function (p, l, pred; doplot = true)
global iter
iter += 1
@show iter, l
# plot current prediction against data
plt = Plots.scatter(tsteps, ode_data[1,:], label = "data", title=string(iter))
Plots.scatter!(plt, tsteps, pred[1,:], label = "prediction")
if doplot
display(Plots.plot(plt))
end
return false
end
result_neuralode = DiffEqFlux.sciml_train(
loss_neuralode,
p,
Flux.ADAM(0.05),
cb = callback,
maxiters = 60
)
end # ------------------------------- module -----------------------------------
The fixed restructure/destructure works:
import DiffEqFlux
import OrdinaryDiffEq
import Flux
import Optim
import Plots
import Zygote
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
function trueODEfunc(du, u, p, t)
true_A = [-0.1 2.0; -2.0 -0.1]
du .= ((u.^3)'true_A)'
end
prob_trueode = OrdinaryDiffEq.ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(OrdinaryDiffEq.solve(prob_trueode, OrdinaryDiffEq.Tsit5(), saveat = tsteps))
dudt2 = Flux.Chain(
x -> x.^3,
Flux.Dense(2, 50, tanh),
#Flux.Dense(50, 2)
Flux.GRU(50, 2)
)
function destructure(m; cache = IdDict())
xs = Zygote.Buffer([])
Flux.fmap(m) do x
if x isa AbstractArray
push!(xs, x)
else
cache[x] = x
end
return x
end
return vcat(vec.(copy(xs))...), p -> _restructure(m, p, cache = cache)
end
function _restructure(m, xs; cache = IdDict())
i = 0
Flux.fmap(m) do x
x isa AbstractArray || return cache[x]
x = reshape(xs[i.+(1:length(x))], size(x))
i += length(x)
return x
end
end
p, re = destructure(dudt2)
neural_ode_f(u, p, t) = re(p)(u)
prob = OrdinaryDiffEq.ODEProblem(neural_ode_f, u0, tspan, p)
function predict_neuralode(p)
tmp_prob = OrdinaryDiffEq.remake(prob,p=p)
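# Fixed-step solve: adaptivity is turned off because the layer state makes the ODE ill-defined for adaptive stepping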
res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, dt=0.01, adaptive=false))
return res
end
function loss_neuralode(p)
pred = predict_neuralode(p) # (2,30)
loss = sum(abs2, ode_data .- pred) # scalar
return loss, pred
end
callback = function (p, l, pred; doplot = true)
display(l)
# plot current prediction against data
plt = Plots.scatter(tsteps, ode_data[1,:], label = "data")
Plots.scatter!(plt, tsteps, pred[1,:], label = "prediction")
if doplot
display(Plots.plot(plt))
end
return false
end
result_neuralode = DiffEqFlux.sciml_train(
loss_neuralode,
p,
Flux.ADAM(0.05),
cb = callback,
maxiters = 3000
)
result_neuralode2 = DiffEqFlux.sciml_train(
loss_neuralode,
result_neuralode.minimizer,
Flux.ADAM(0.05),
cb = callback,
maxiters = 1000,
)
The method isn't very good, but it does what you asked for.
Excellent, @ChrisRackauckas. I see that sensealg=DiffEqSensitivity.InterpolatingAdjoint() is not needed, and that dt=0.01, adaptive=false can be used.
The method is a means to an end (eventually, GRU-ODE), but it does work reasonably well as is if Flux.GRU(50, 2) is replaced by Flux.GRU(50, 50), Flux.Dense(50, 2). At 200 iterations of just SGD, the error is about 0.18. If I switch in my custom GRU (below, as per the GRU-ODE paper) and reduce the hidden layer size to 20 from 50, the error is about 0.02 at 200 iterations. Both are less than the error of about 0.48 achieved with a chain of Flux.Dense(2, 50, tanh), Flux.Dense(50, 2).
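Concretely, the replacement described above corresponds to a chain like the following sketch (reconstructed from the description; not code from the original post):

dudt2 = Flux.Chain(
    x -> x.^3,
    Flux.Dense(2, 50, tanh),
    Flux.GRU(50, 50),
    Flux.Dense(50, 2),
)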
The custom GRU is as follows; a usage sketch appears after the module. In this problem, I use Flux2.GRU2(20, true, x -> x) when defining the chain.
module Flux2
import Flux
using Infiltrator
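# GRU cell with an is_dhdt flag: when true, the cell outputs dh/dt (GRU-ODE);
# when false, it applies the standard GRU update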
mutable struct GRUCell2{W,U,B,L,TF,F}
update_W::W
update_U::U
update_b::B
reset_W::W
reset_U::U
reset_b::B
out_W::W
out_U::U
out_b::B
H::L
is_dhdt::TF
fx::F
end
GRUCell2(L, is_output_dhdt, fx; init = Flux.glorot_uniform) =
GRUCell2(
init(L, L),
init(L, L),
init(L,1),
init(L, L),
init(L, L),
init(L,1),
init(L, L),
init(L, L),
init(L,1),
zeros(Float32, (L,1)),
is_output_dhdt,
fx
)
function (m::GRUCell2)(H, X)
update_gate = Flux.sigmoid.(
(m.update_W * X)
.+ (m.update_U * m.H)
.+ m.update_b)
reset_gate = Flux.sigmoid.(
(m.reset_W * X)
.+ (m.reset_U * m.H)
.+ m.reset_b)
output_gate = m.fx.(
(m.out_W * X)
.+ (m.out_U * (reset_gate .* m.H))
.+ m.out_b)
if m.is_dhdt == true
# output is dhdt
output = (Float32(1) .- update_gate) .* (output_gate .- H)
else
# standard GRU output
output = ((Float32(1) .- update_gate) .* output_gate) .+ (update_gate .* H)
end
H = output
return H, H
end
Flux.hidden(m::GRUCell2) = m.H
Flux.@functor GRUCell2
Base.show(io::IO, l::GRUCell2) =
print(io, "GRUCell2(", size(l.update_W, 2), ", ", size(l.update_W, 1), ")")
GRU2(a...; ka...) = Flux.Recur(GRUCell2(a...; ka...))
end # ----------------------------- module
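For reference, a chain using this custom cell might look like the following sketch (hidden size 20 as described above; the surrounding layer sizes are an assumption):

dudt2 = Flux.Chain(
    x -> x.^3,
    Flux.Dense(2, 20, tanh),
    Flux2.GRU2(20, true, x -> x),
    Flux.Dense(20, 2),
)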
Cool, yeah. The other thing to try is sensealg=ReverseDiffAdjoint(). Using direct reverse-mode AD might be better with a fixed time step, since it would not have the same possibility of adjoint error as the continuous adjoints, which would be more of an issue if the reverse pass is not adaptive.
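A minimal sketch of that suggestion, reusing tmp_prob and tsteps from the fixed-step code above (this exact combination is untested here):

res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps,
    dt=0.01, adaptive=false,
    sensealg=DiffEqSensitivity.ReverseDiffAdjoint()))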
It would be good to turn this into a tutorial when all is said and done. @DhairyaLGandhi could you add that restructure/destructure patch to Flux and then tag a release? @John-Boik would you be willing to contribute a tutorial?
Or @mkg33 might be able to help out here.
Sure, I would be happy to help if I can.
Of course, I'll add it to my tasks.
Has this "fix" been released to FluxML yet?
ODE-LSTM implementations: @John-Boik, are you or others still working on an ODE-LSTM implementation in FluxML?
I'm working on similar models, which also use restructure/destructure. The fix to restructure/destructure has been released, and both functions are working fine as far as I know.
I think @DhairyaLGandhi didn't merge the fix yet https://github.com/FluxML/Flux.jl/pull/1353
It's still a bad model though.
This was fixed by https://github.com/FluxML/Flux.jl/pull/1901, and one can now use Lux, which makes the state explicit. Cheers!
As a first step leading up to GRU-ODE or ODE-LSTM implementations, I'd like to switch out the Dense layer in the neural_ode_sciml example with a GRU layer. However, doing so raises the error LoadError: DimensionMismatch("array could not be broadcast to match destination"). I don't understand where, exactly, the problem is occurring, or how to fix it. Any ideas?

The code is as follows, with the main differences from the original example being:
- the using statements have been changed to import statements (for clarity)
- FastChain has been changed to Chain
- the code is run with include("./TestDiffEq3b.jl")
- the Dense layer has been changed to a GRU layer

This issue is loosely related to Training of UDEs with recurrent networks #391 and Flux.destructure doesn't preserve RNN state #1329. See also ODE-LSTM layer #422.
The code is as follows, with the Dense layer commented out and replaced by the GRU layer:
The error message is: LoadError: DimensionMismatch("array could not be broadcast to match destination")