FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/
Other
4.54k stars 608 forks source link

TypeError in DEQ example: non-boolean (Nothing) used in boolean context #677 #1846

Closed yadmtr closed 2 years ago

yadmtr commented 2 years ago

Please help me understand the cause of the error I get when running the Deep Equilibrium Models (DEQ) example from the Julia blog.

This code:

using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using OrdinaryDiffEq
#using CUDA
using Plots
using LinearAlgebra
#CUDA.allowscalar(false)

struct DeepEquilibriumNetwork{M,P,RE,A,K}  # DEQ layer: a Flux model plus the pieces needed to solve for its steady state
    model::M   # the original model; not returned by `Flux.trainable` below, so training does not update it
    p::P       # flat parameter vector produced by `Flux.destructure(model)`
    re::RE     # rebuilder from `Flux.destructure`: `re(p)` returns a model using parameters `p`
    args::A    # positional arguments forwarded to `solve` (here: the steady-state algorithm)
    kwargs::K  # keyword arguments forwarded to `solve`
end

Flux.@functor DeepEquilibriumNetwork  # register the struct with Flux's functor machinery; `trainable` below restricts what is trained

function DeepEquilibriumNetwork(model, args...; kwargs...)  # convenience constructor: flatten `model` and capture solver args/kwargs
    p, re = Flux.destructure(model)  # p: flat parameter vector, re: rebuilds the model from such a vector
    return DeepEquilibriumNetwork(model, p, re, args, kwargs)
end

Flux.trainable(deq::DeepEquilibriumNetwork) = (deq.p,)  # only the flat vector `p` is exposed to `Flux.params`/`train!`

function (deq::DeepEquilibriumNetwork)(x::AbstractArray{T}, p = deq.p) where {T}  # forward pass: return the steady state u of du/dt = f(u + x) - u
    z = deq.re(p)(x)  # initial guess / u0: one ordinary pass of the rebuilt model
    # Solving the equation f(u) - u = du = 0
    # The key part of DEQ is similar to that of NeuralODEs
    dudt(u, _p, t) = deq.re(_p)(u .+ x) .- u  # closure over x: the input is re-injected at every evaluation
    ssprob = SteadyStateProblem(ODEProblem(dudt, z, (zero(T), one(T)), p))  # tspan in x's eltype; p is passed as the problem parameter (the vector the solve adjoint receives, per the trace)
    return solve(ssprob, deq.args...; u0 = z, deq.kwargs...).u  # file line 33 in the trace: the TypeError is raised inside this call's adjoint
end

ann = Chain(Dense(1, 5), Dense(5, 1))  # tiny 1 -> 5 -> 1 model (identity activations, per the types in the trace)

deq = DeepEquilibriumNetwork(ann, DynamicSS(Tsit5(), abstol = 1.0f-2, reltol = 1.0f-2))  # DynamicSS(Tsit5(), ...) is stored in `args` and forwarded to `solve`

# Let's run a DEQ model on linear regression for y = 2x
X = reshape(Float32[1; 2; 3; 4; 5; 6; 7; 8; 9; 10], 1, :)  # 1×10 Float32 matrix: one input row, ten columns
Y = 2 .* X  # targets for y = 2x
opt = ADAM(0.05)  # Adam optimiser with learning rate 0.05

loss(x, y) = sum(abs2, y .- deq(x))  # sum of squared errors (file line 45 in the trace)

Flux.train!(loss, Flux.params(deq), ((X, Y),), opt)  # one update over a single (X, Y) batch; this is where the TypeError surfaces (file line 47)

throws the following error on the line `Flux.train!(loss, Flux.params(deq), ((X, Y),), opt)`:

ERROR: LoadError: TypeError: non-boolean (Nothing) used in boolean context
Stacktrace:
  [1] _concrete_solve_adjoint(::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}, ::Nothing, ::Matrix{Float32}, ::Vector{Float32}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqSensitivity C:\Users\D\.julia\packages\DiffEqSensitivity\Kg0cc\src\concrete_solve.jl:92
  [2] _concrete_solve_adjoint
    @ C:\Users\D\.julia\packages\DiffEqSensitivity\Kg0cc\src\concrete_solve.jl:72 [inlined]
  [3] #_solve_adjoint#56
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:347 [inlined]
  [4] _solve_adjoint
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:322 [inlined]
  [5] #rrule#54
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:310 [inlined]
  [6] rrule
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:310 [inlined]
  [7] rrule
    @ C:\Users\D\.julia\packages\ChainRulesCore\oBjCg\src\rules.jl:134 [inlined]
  [8] chain_rrule
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\chainrules.jl:216 [inlined]
  [9] macro expansion
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(DiffEqBase.solve_up), ::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Nothing, ::Matrix{Float32}, ::Vector{Float32}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:9
 [11] _apply
    @ .\boot.jl:804 [inlined]
 [12] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [13] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [14] _pullback
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:73 [inlined]
 [15] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#38", ::Nothing, ::Matrix{Float32}, ::Nothing, ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(solve), ::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, 
NamedTuple{(), Tuple{}}}}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [16] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core .\boot.jl:804
 [17] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [18] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [19] _pullback
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:68 [inlined]
 [20] _pullback(::Zygote.Context, ::CommonSolve.var"#solve##kw", ::NamedTuple{(:u0,), Tuple{Matrix{Float32}}}, ::typeof(solve), ::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [21] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core .\boot.jl:804
 [22] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [23] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [24] _pullback
    @ c:\Users\D\w7d\test_flux_e[ample.jl:33 [inlined]
 [25] _pullback(::Zygote.Context, ::DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Matrix{Float32}, 
::Vector{Float32})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [26] _pullback
    @ c:\Users\D\w7d\test_flux_e[ample.jl:28 [inlined]
 [27] _pullback(ctx::Zygote.Context, f::DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, args::Matrix{Float32})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [28] _pullback
    @ c:\Users\D\w7d\test_flux_e[ample.jl:45 [inlined]
 [29] _pullback(::Zygote.Context, ::typeof(loss), ::Matrix{Float32}, ::Matrix{Float32})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [30] _apply
    @ .\boot.jl:804 [inlined]
 [31] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [32] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [33] _pullback
    @ C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:105 [inlined]
 [34] _pullback(::Zygote.Context, ::Flux.Optimise.var"#39#45"{typeof(loss), Tuple{Matrix{Float32}, Matrix{Float32}}})   
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [35] pullback(f::Function, ps::Zygote.Params)
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface.jl:352
 [36] gradient(f::Function, args::Zygote.Params)
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface.jl:75
 [37] macro expansion
    @ C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:104 [inlined]
 [38] macro expansion
    @ C:\Users\D\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
 [39] train!(loss::Function, ps::Zygote.Params, data::Tuple{Tuple{Matrix{Float32}, Matrix{Float32}}}, opt::ADAM; cb::Flux.Optimise.var"#40#46")
    @ Flux.Optimise C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:102
 [40] train!(loss::Function, ps::Zygote.Params, data::Tuple{Tuple{Matrix{Float32}, Matrix{Float32}}}, opt::ADAM)
    @ Flux.Optimise C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:100
 [41] top-level scope
    @ c:\Users\D\w7d\test_flux_e[ample.jl:47
in expression starting at c:\Users\D\w7d\test_flux_e[ample.jl:47

Operating system: Windows 10; Julia 1.6.5; VS Code 1.63.2. Output of `Pkg.status()`:

  [052768ef] CUDA v3.6.4
  [31a5f54b] Debugger v0.7.0
  [2b5f629d] DiffEqBase v6.81.0
  [41bf760c] DiffEqSensitivity v6.68.0
  [587475ba] Flux v0.12.8
  [5903a43b] Infiltrator v1.1.2
  [98e50ef6] JuliaFormatter v0.21.2
  [aa1ae85d] JuliaInterpreter v0.9.1
  [eb30cadb] MLDatasets v0.5.14
  [2774e3e8] NLsolve v4.5.1
  [1dea7af3] OrdinaryDiffEq v6.4.2
  [91a5bcdd] Plots v1.25.6
  [ee283ea6] Rebugger v0.2.2
  [9672c7b4] SteadyStateDiffEq v1.6.6
  [c3572dad] Sundials v4.9.1
  [e88e6eb3] Zygote v0.6.33
  [37e2e46d] LinearAlgebra
  [8dfed614] Test
mcabbott commented 2 years ago

This line https://github.com/SciML/DiffEqSensitivity.jl/blob/70f61d8a69d8cfd05beb257b93fceef9462c5f91/src/concrete_solve.jl#L72 seems to call `isgpu`, defined here https://github.com/SciML/DiffEqSensitivity.jl/blob/3a6cad542aa8143a76e2a6e928ad90e98f361a55/src/require.jl#L1, which looks like it always returns a Bool.

Running the example locally, I find the error is actually raised at a different line: line 92, which is also the top frame in the trace above.

julia> Flux.train!(loss, Flux.params(deq), ((X, Y),), opt)
ERROR: TypeError: non-boolean (Nothing) used in boolean context
Stacktrace:
  [1] _concrete_solve_adjoint(::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#8"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#64#66"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}, ::Nothing, ::Matrix{Float32}, ::Vector{Float32}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/Kg0cc/src/concrete_solve.jl:92
  [2] _concrete_solve_adjoint
    @ ~/.julia/packages/DiffEqSensitivity/Kg0cc/src/concrete_solve.jl:72 [inlined]

That is this line: https://github.com/SciML/DiffEqSensitivity.jl/blob/70f61d8a69d8cfd05beb257b93fceef9462c5f91/src/concrete_solve.jl#L92. There, `ez` is defined just above by an `if` expression with no `else` branch, which evaluates to `nothing` when its condition is false:

julia> ez = if false
         try
                 Enzyme.autodiff(Enzyme.Duplicated(du, du),
                                 u0,p,prob.tspan[1]) do out,u,_p,t
                   f(out, u, _p, t)
                   nothing
                 end
                 true
               catch
                 false
               end
           end

julia> ez === nothing
true

So that's the issue, I presume.

ChrisRackauckas commented 2 years ago

Thanks for the report. This was fixed in https://github.com/SciML/DiffEqSensitivity.jl/pull/551 with a bunch of new tests.