Please help me understand the cause of the error I get when running the DEQ example from the Julia blog post "Deep Equilibrium Models".
This is the code:
using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using OrdinaryDiffEq
#using CUDA
using Plots
using LinearAlgebra
#CUDA.allowscalar(false)
"""
    DeepEquilibriumNetwork{M,P,RE,A,K}

Deep-equilibrium layer wrapping a Flux `model`. Calling it on `x` solves for a
steady state `u` of the residual `model(u .+ x) .- u` with `solve`, using the
stored `args`/`kwargs` as solver options.

# Fields
- `model`: the wrapped model, stored as passed to the constructor.
- `p`: parameters produced by `Flux.destructure(model)`.
- `re`: the restructure function from `Flux.destructure`; `re(p)` rebuilds the model.
- `args`: positional arguments forwarded to `solve` (e.g. a `DynamicSS` algorithm).
- `kwargs`: keyword arguments forwarded to `solve`.
"""
struct DeepEquilibriumNetwork{M,P,RE,A,K}
    model::M    # original model
    p::P        # flat parameters from Flux.destructure
    re::RE      # rebuilds a model from a parameter vector
    args::A     # positional solver arguments
    kwargs::K   # keyword solver arguments
end
# Register the type with Flux's functor machinery so Flux can walk its fields;
# which fields are actually trained is narrowed via `Flux.trainable`.
Flux.@functor DeepEquilibriumNetwork
"""
    DeepEquilibriumNetwork(model, args...; kwargs...)

Wrap `model` in a deep-equilibrium layer. The model is flattened with
`Flux.destructure`; `args` and `kwargs` are stored and later forwarded to `solve`.
"""
DeepEquilibriumNetwork(model, args...; kwargs...) =
    DeepEquilibriumNetwork(model, Flux.destructure(model)..., args, kwargs)

# Expose only the flat parameter vector to the optimiser.
Flux.trainable(layer::DeepEquilibriumNetwork) = tuple(layer.p)
# Forward pass: find u with f(u + x) - u = 0, where f = deq.re(p).
# The steady state is computed by `solve` using the algorithm/options stored
# in `deq.args` / `deq.kwargs`; the equilibrium array `.u` is returned.
function (deq::DeepEquilibriumNetwork)(x::AbstractArray{T}, p = deq.p) where {T}
    # One plain forward pass provides the initial guess.
    u_init = deq.re(p)(x)

    # Out-of-place residual du = f(u + x) - u; it vanishes at the fixed point.
    residual(u, θ, _t) = deq.re(θ)(u .+ x) .- u

    ode = ODEProblem(residual, u_init, (zero(T), one(T)), p)
    steady = SteadyStateProblem(ode)
    sol = solve(steady, deq.args...; u0 = u_init, deq.kwargs...)
    return sol.u
end
# Toy problem: fit the linear map y = 2x with a DEQ around a small dense chain.
X = reshape(collect(Float32, 1:10), 1, :)   # 1×10 Float32 inputs 1..10
Y = 2 .* X

ann = Chain(Dense(1, 5), Dense(5, 1))
steady_alg = DynamicSS(Tsit5(); abstol = 1.0f-2, reltol = 1.0f-2)
deq = DeepEquilibriumNetwork(ann, steady_alg)

loss(x, y) = sum(abs2, y .- deq(x))
opt = ADAM(0.05)

# A single batch containing the full dataset; one call = one optimisation step.
batches = ((X, Y),)
Flux.train!(loss, Flux.params(deq), batches, opt)
It throws the following error on the line `Flux.train!(loss, Flux.params(deq), ((X, Y),), opt)`:
Please help me understand the cause of the error I get when running the DEQ example from the Julia blog post "Deep Equilibrium Models".
This is the code:
It throws the following error on the line `Flux.train!(loss, Flux.params(deq), ((X, Y),), opt)`.

Environment: Windows 10, Julia 1.6.5, VS Code 1.63.2. `Pkg.status` output: