SciML / DeepEquilibriumNetworks.jl

Implicit Layer Machine Learning via Deep Equilibrium Networks, O(1) backpropagation with accelerated convergence.
https://docs.sciml.ai/DeepEquilibriumNetworks/stable/
MIT License
50 stars 5 forks source link

TypeError in DEQ example: non-boolean (Nothing) used in boolean context #36

Closed yadmtr closed 2 years ago

yadmtr commented 2 years ago

Please help me understand the cause of the error I get when running the DEQ example from the Julia blog post (Deep Equilibrium Models).

This code:

using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using OrdinaryDiffEq
#using CUDA
using Plots
using LinearAlgebra
#CUDA.allowscalar(false)

"""
    DeepEquilibriumNetwork{M, P, RE, A, K}

Bundle of everything needed to evaluate a deep equilibrium layer.

# Fields
- `model`: the wrapped network.
- `p`: parameter object used as the default parameters when the layer is called.
- `re`: callable that turns a parameter object back into a network (`re(p)`).
- `args`: positional arguments splatted into `solve` after the problem.
- `kwargs`: keyword arguments splatted into `solve`.
"""
struct DeepEquilibriumNetwork{M, P, RE, A, K}
    model::M
    p::P
    re::RE
    args::A
    kwargs::K
end

# Register the struct with Flux's functor machinery.
# NOTE(review): presumably this lets Flux walk the struct's fields (e.g. for
# parameter collection) — confirm against the Flux version in use.
Flux.@functor DeepEquilibriumNetwork

# Convenience constructor: wrap `model` and remember the solver configuration.
#
# `args...` and `kwargs...` are stored untouched and later forwarded to `solve`
# when the layer is called. The model is destructured once here so that its
# parameters and the matching rebuild function are kept alongside it.
function DeepEquilibriumNetwork(model, args...; kwargs...)
    flat_params, rebuild = Flux.destructure(model)
    return DeepEquilibriumNetwork(model, flat_params, rebuild, args, kwargs)
end

# Report only the stored parameter field `p` as trainable.
function Flux.trainable(deq::DeepEquilibriumNetwork)
    return (deq.p,)
end

# Forward pass of the DEQ layer.
#
# `x` is the input array and `p` the parameters to use (defaults to `deq.p`).
# Returns the `u` field of the steady-state solution.
function (deq::DeepEquilibriumNetwork)(x::AbstractArray{T}, p = deq.p) where {T}
    # One plain pass through the rebuilt network gives the initial guess.
    u_init = deq.re(p)(x)

    # Residual of the fixed-point equation: f(u) - u = du = 0.
    # As in NeuralODEs, the network is rebuilt from the parameters inside the
    # right-hand side.
    residual(u, params, t) = deq.re(params)(u .+ x) .- u

    time_span = (zero(T), one(T))
    ode_problem = ODEProblem(residual, u_init, time_span, p)
    steady_problem = SteadyStateProblem(ode_problem)

    solution = solve(steady_problem, deq.args...; u0 = u_init, deq.kwargs...)
    return solution.u
end

# Inner network of the DEQ: two dense layers, 1 -> 5 -> 1.
ann = Chain(Dense(1, 5), Dense(5, 1))

# Wrap it in a DEQ whose steady state is found with DynamicSS on top of Tsit5.
deq = DeepEquilibriumNetwork(
    ann,
    DynamicSS(Tsit5(); abstol = 1.0f-2, reltol = 1.0f-2),
)

# Linear-regression toy problem: fit y = 2x on x = 1, ..., 10 (one row, ten columns).
X = reshape(Float32.(1:10), 1, :)
Y = 2 .* X
opt = ADAM(0.05)

# Sum of squared errors between the targets and the DEQ output.
function loss(x, y)
    return sum(abs2, y .- deq(x))
end

# Train on the single (X, Y) pair.
Flux.train!(loss, Flux.params(deq), ((X, Y),), opt)

throws the following error on the line `Flux.train!(loss, Flux.params(deq), ((X, Y),), opt)`:

ERROR: LoadError: TypeError: non-boolean (Nothing) used in boolean context
Stacktrace:
  [1] _concrete_solve_adjoint(::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}, ::Nothing, ::Matrix{Float32}, ::Vector{Float32}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqSensitivity C:\Users\D\.julia\packages\DiffEqSensitivity\Kg0cc\src\concrete_solve.jl:92
  [2] _concrete_solve_adjoint
    @ C:\Users\D\.julia\packages\DiffEqSensitivity\Kg0cc\src\concrete_solve.jl:72 [inlined]
  [3] #_solve_adjoint#56
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:347 [inlined]
  [4] _solve_adjoint
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:322 [inlined]
  [5] #rrule#54
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:310 [inlined]
  [6] rrule
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:310 [inlined]
  [7] rrule
    @ C:\Users\D\.julia\packages\ChainRulesCore\oBjCg\src\rules.jl:134 [inlined]
  [8] chain_rrule
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\chainrules.jl:216 [inlined]
  [9] macro expansion
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(DiffEqBase.solve_up), ::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Nothing, ::Matrix{Float32}, ::Vector{Float32}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:9
 [11] _apply
    @ .\boot.jl:804 [inlined]
 [12] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [13] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [14] _pullback
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:73 [inlined]
 [15] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#38", ::Nothing, ::Matrix{Float32}, ::Nothing, ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(solve), ::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, 
NamedTuple{(), Tuple{}}}}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [16] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core .\boot.jl:804
 [17] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [18] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [19] _pullback
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:68 [inlined]
 [20] _pullback(::Zygote.Context, ::CommonSolve.var"#solve##kw", ::NamedTuple{(:u0,), Tuple{Matrix{Float32}}}, ::typeof(solve), ::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [21] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core .\boot.jl:804
 [22] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [23] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [24] _pullback
    @ c:\Users\D\w7d\test_flux_e[ample.jl:33 [inlined]
 [25] _pullback(::Zygote.Context, ::DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Matrix{Float32}, 
::Vector{Float32})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [26] _pullback
    @ c:\Users\D\w7d\test_flux_e[ample.jl:28 [inlined]
 [27] _pullback(ctx::Zygote.Context, f::DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, args::Matrix{Float32})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [28] _pullback
    @ c:\Users\D\w7d\test_flux_e[ample.jl:45 [inlined]
 [29] _pullback(::Zygote.Context, ::typeof(loss), ::Matrix{Float32}, ::Matrix{Float32})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [30] _apply
    @ .\boot.jl:804 [inlined]
 [31] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [32] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [33] _pullback
    @ C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:105 [inlined]
 [34] _pullback(::Zygote.Context, ::Flux.Optimise.var"#39#45"{typeof(loss), Tuple{Matrix{Float32}, Matrix{Float32}}})   
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [35] pullback(f::Function, ps::Zygote.Params)
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface.jl:352
 [36] gradient(f::Function, args::Zygote.Params)
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface.jl:75
 [37] macro expansion
    @ C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:104 [inlined]
 [38] macro expansion
    @ C:\Users\D\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
 [39] train!(loss::Function, ps::Zygote.Params, data::Tuple{Tuple{Matrix{Float32}, Matrix{Float32}}}, opt::ADAM; cb::Flux.Optimise.var"#40#46")
    @ Flux.Optimise C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:102
 [40] train!(loss::Function, ps::Zygote.Params, data::Tuple{Tuple{Matrix{Float32}, Matrix{Float32}}}, opt::ADAM)
    @ Flux.Optimise C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:100
 [41] top-level scope
    @ c:\Users\D\w7d\test_flux_e[ample.jl:47
in expression starting at c:\Users\D\w7d\test_flux_e[ample.jl:47

Operating System: Windows 10, Julia 1.6.5, VS Code 1.63.2. Output of Pkg.status:

  [052768ef] CUDA v3.6.4
  [31a5f54b] Debugger v0.7.0
  [2b5f629d] DiffEqBase v6.81.0
  [41bf760c] DiffEqSensitivity v6.68.0
  [587475ba] Flux v0.12.8
  [5903a43b] Infiltrator v1.1.2
  [98e50ef6] JuliaFormatter v0.21.2
  [aa1ae85d] JuliaInterpreter v0.9.1
  [eb30cadb] MLDatasets v0.5.14
  [2774e3e8] NLsolve v4.5.1
  [1dea7af3] OrdinaryDiffEq v6.4.2
  [91a5bcdd] Plots v1.25.6
  [ee283ea6] Rebugger v0.2.2
  [9672c7b4] SteadyStateDiffEq v1.6.6
  [c3572dad] Sundials v4.9.1
  [e88e6eb3] Zygote v0.6.33
  [37e2e46d] LinearAlgebra
  [8dfed614] Test
ChrisRackauckas commented 2 years ago

@avik-pal

ChrisRackauckas commented 2 years ago

This works now, and for DEQs we have a whole library https://github.com/SciML/FastDEQ.jl