FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

Type instability in `Flux.setup` #162

Open Vilin97 opened 11 months ago

Vilin97 commented 11 months ago
using Flux

function test_setup(opt, s)
    state = Flux.setup(opt, s)
    return state
end
s = Chain(
        Dense(2 => 100, softsign),
        Dense(100 => 2)
    )
opt = Adam(0.1)
@code_warntype test_setup(opt, s) # type unstable

Output:

MethodInstance for GradientFlows.test_setup(::Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from test_setup(opt, s) @ GradientFlows c:\Users\Math User\.julia\dev\GradientFlows\src\solvers\sbtm.jl:106
Arguments
  #self#::Core.Const(GradientFlows.test_setup)
  opt::Adam
  s::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  state::Any
Body::Any
1 ─ %1 = Flux.setup::Core.Const(Flux.Train.setup)
│        (state = (%1)(opt, s))
└──      return state

Julia version 1.9.3 and Flux version 0.14.6:

(@v1.9) pkg> st Flux
Status `C:\Users\Math User\.julia\environments\v1.9\Project.toml`
  [587475ba] Flux v0.14.6
ToucheSir commented 11 months ago

setup is defined in Optimisers.jl, and it's inherently type unstable because it uses a cache to detect + handle shared parameters. Usually I would mark this as a WONTFIX, but there might be some fancy method and/or newer version of Julia which lets us make setup more type stable.
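
For intuition, here is a minimal, self-contained toy (stand-in types, not the real Optimisers internals) showing why a shared-parameter cache defeats inference: anything read back out of an IdDict{Any,Any} is typed Any, so the recursive walk's return type collapses to Any as well.

struct ToyRule end
struct ToyLeaf{R,S}
    rule::R
    state::S
end
toy_init(::ToyRule, x::AbstractArray) = zero(x)   # stand-in for per-rule state init

function toy_setup(rule, x, cache = IdDict())
    haskey(cache, x) && return cache[x]            # cache hit: IdDict{Any,Any} getindex is ::Any
    if x isa AbstractArray{<:Number}
        return cache[x] = ToyLeaf(rule, toy_init(rule, x))   # leaf: record state, remember tied arrays
    else
        return map(xi -> toy_setup(rule, xi, cache), x)      # recurse into containers
    end
end

w = rand(Float32, 3, 3)
tied = (layer1 = (weight = w,), layer2 = (weight = w,))  # the same array appears twice
@code_warntype toy_setup(ToyRule(), tied)                # Body::Any, as in the report above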

mcabbott commented 11 months ago

Values from the cache are used when an object x is === some previously seen x. They should therefore always have the type of the Leaf constructed from what init(rule, x) returns. If this type can be inferred, we can probably tell the compiler what to expect, and this may make the whole setup type-stable? Haven't tried though.

ToucheSir commented 11 months ago

We could use _return_type or friends to do that, yes. One thing I'd like to try to make that easier is to delegate what Functors.CachedWalk currently does to the callback passed into the maps. Then it should be easier to swap in/out different implementations of caching and memoization by simply switching the callback.
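
As a hedged sketch of that idea (illustrative only, not Functors.jl's actual API): keep the walk itself cache-free and let a wrapper around the callback do the memoisation by object identity, so swapping the wrapper swaps the caching/memoisation policy.

memoise_by_identity(f, cache = IdDict()) = x -> get!(() -> f(x), cache, x)

calls = Ref(0)
cb = memoise_by_identity(x -> (calls[] += 1; sum(x)))   # any leaf callback, here just summing

w = rand(Float32, 3)
cb(w); cb(w)   # the second call is answered from the cache
calls[]        # 1 — the wrapped callback ran once per distinct (===) array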

mcabbott commented 11 months ago
function _setup(rule, x; cache)
  if haskey(cache, x)
    # x is a previously seen (tied) array, so the cached value is the Leaf built
    # for it earlier; assert the type inference predicts for that Leaf.
    T1 = Base._return_type(init, Tuple{typeof(rule), typeof(x)})
    T2 = Base._return_type(Leaf, Tuple{typeof(rule), T1})
    return cache[x]::T2
  end
  if isnumeric(x)
    ℓ = Leaf(rule, init(rule, x))
    # as before...

gives

julia> @code_warntype test_setup(opt, s)
MethodInstance for test_setup(::Optimisers.Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from test_setup(opt, s) @ Main REPL[5]:1
Arguments
  #self#::Core.Const(test_setup)
  opt::Optimisers.Adam
  s::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  state::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
Body::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
1 ─ %1 = Flux.setup::Core.Const(Flux.Train.setup)
│        (state = (%1)(opt, s))
└──      return state

julia> @code_warntype Optimisers.setup(opt, s)
MethodInstance for Optimisers.setup(::Optimisers.Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from setup(rule::AbstractRule, model) @ Optimisers ~/.julia/dev/Optimisers/src/interface.jl:29
Arguments
  #self#::Core.Const(Optimisers.setup)
  rule::Optimisers.Adam
  model::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  tree::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
  cache::IdDict{Any, Any}
  msg::String
  kwargs::@NamedTuple{}
  line::Int64
  file::String
  id::Symbol
  logger::Union{Nothing, Base.CoreLogging.AbstractLogger}
  _module::Module
  group::Symbol
  std_level::Base.CoreLogging.LogLevel
  level::Base.CoreLogging.LogLevel
Body::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
1 ──       (cache = Optimisers.IdDict())
│    %2  = (:cache,)::Core.Const((:cache,))
│    %3  = Core.apply_type(Core.NamedTuple, %2)::Core.Const(NamedTuple{(:cache,)})
...
ToucheSir commented 11 months ago

Looks like the inference path _return_type uses might not be able to work through the recursion? I wonder if we could use a trick like https://github.com/FluxML/Functors.jl/pull/61 to prevent it from bailing.

Vilin97 commented 10 months ago

In the meantime, would it make sense to add a sentence like "This function is type-unstable." to the docstring of setup? If I had seen such a sentence in the docstring, it would have saved me the trouble of discovering it for myself.

mcabbott commented 5 months ago

would it make sense to add a sentence like "This function is type-unstable." to the docstring of setup?

Yes, probably.

Also, to emphasise that the way to deal with this is a function barrier: you run setup exactly once and pass its result along. If you are running it in a tight loop, you are probably doing it wrong.
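
For reference, a sketch of that function-barrier pattern (the loop and data below are illustrative, not from the issue): setup runs exactly once, and the hot loop lives inside train!, whose arguments already have concrete types, so the instability never reaches the loop body.

using Flux

model = Chain(Dense(2 => 100, softsign), Dense(100 => 2))
opt_state = Flux.setup(Adam(0.1), model)   # type-unstable, but called only once

function train!(model, opt_state, data)    # function barrier: concrete types from here on
    for (x, y) in data
        grads = Flux.gradient(m -> sum(abs2, m(x) .- y), model)
        Flux.update!(opt_state, model, grads[1])
    end
end

data = [(randn(Float32, 2, 32), randn(Float32, 2, 32)) for _ in 1:10]
train!(model, opt_state, data)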