FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Why is Flux.destructure type unstable? #2405

Closed irisallevi closed 6 months ago

irisallevi commented 6 months ago

I was building a simple model and at some point I needed to "unroll" it to get all the parameters in an array.

So I tried with Flux.destructure. I got some type instability, so I checked the documentation and tried the example provided there:

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
@code_warntype Flux.destructure(model)

But this gives a type instability as well!

Flux.destructure(model) = (Float32[0.27410066, 0.6508191, 0.0, 0.16767712, 0.0], Restructure(Chain, ..., 5))
MethodInstance for Optimisers.destructure(::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from destructure(x) @ Optimisers
Arguments
  #self#::Core.Const(Optimisers.destructure)
  x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  @_3::Int64
  len::Int64
  off::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}
  flat::AbstractVector
Body::Tuple{AbstractVector, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})}
1 ─ %1  = Optimisers._flatten(x)::Tuple{AbstractVector, NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Int64}
│   %2  = Base.indexed_iterate(%1, 1)::Core.PartialStruct(Tuple{AbstractVector, Int64}, Any[AbstractVector, Core.Const(2)])
│         (flat = Core.getfield(%2, 1))
│         (@_3 = Core.getfield(%2, 2))
│   %5  = Base.indexed_iterate(%1, 2, @_3::Core.Const(2))::Core.PartialStruct(Tuple{NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Int64}, Any[NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Core.Const(3)])
│         (off = Core.getfield(%5, 1))
│         (@_3 = Core.getfield(%5, 2))
│   %8  = Base.indexed_iterate(%1, 3, @_3::Core.Const(3))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(4)])
│         (len = Core.getfield(%8, 1))
│   %10 = flat::AbstractVector
│   %11 = Optimisers.Restructure(x, off, len)::Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})
│   %12 = Core.tuple(%10, %11)::Tuple{AbstractVector, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})}
└──       return %12

What am I missing?

mcabbott commented 6 months ago

Flux uses Functors.jl for all kinds of recursive walks, and when there are mutable objects, it keeps a cache of their objectids to look for duplicates. This means that it branches on the values of objectid, and this is usually type-unstable. You can see it for instance in @code_warntype f64(model).
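
[Editor's note] A minimal sketch of that caching behaviour, using only Functors.jl (not from the thread): fmap walks the structure recursively, and mutable leaves are cached by objectid, so two appearances of the same array map to one new array.

```julia
using Functors

a = [1.0, 2.0]
nt = (x = a, y = a, z = [3.0])

# fmap walks the structure; the two appearances of `a` are deduplicated
# via an objectid-based cache, so sharing is preserved in the output:
out = fmap(v -> 2 .* v, nt)
out.x === out.y   # true: sharing preserved
out.x === out.z   # false: distinct inputs stay distinct
```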

The reason it does this is to allow for shared parameters. The same array may appear multiple times. To be honest this is a giant pain, maybe it's a feature not worth preserving? Forbidding it would simplify many things. It is turned off for e.g. models with SMatrix parameters (which have no identity beyond their value).
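
[Editor's note] A small weight-tying sketch of what "shared parameters" means here, assuming a Chain that reuses one layer object (not from the thread):

```julia
using Flux, Optimisers

# The same Dense layer object appears twice in the model (weight tying):
shared = Dense(2 => 2)
model = Chain(shared, shared)

# destructure notices the shared arrays and stores them only once:
flat, re = Optimisers.destructure(model)
length(flat)  # 6 (4 weights + 2 biases), not 12

# Reconstruction restores the sharing:
m2 = re(flat)
m2[1].weight === m2[2].weight  # true: the tied array is one object again
```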

Type stability is super-important deep inside tight loops, and hence drummed into us when learning Julia, but often doesn't matter at all for larger objects. E.g. removing type parameters from Flux layers often has no impact on performance, as there are enough function barriers between there and operations which take all the time.
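
[Editor's note] A hypothetical sketch of the function-barrier point: the unstable destructure call happens once, outside the hot code, and inside the function everything is concretely typed.

```julia
using Flux

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
flat, re = Flux.destructure(model)  # type-unstable, but called only once

# Function barrier: once `loss` is called, `flat` and `re` have concrete
# types, so the work inside compiles to type-stable code.
loss(flat, x) = sum(abs2, re(flat)(x))

x = rand(Float32, 2, 8)
loss(flat, x)
```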

Having said all that, I'm not 100% sure this isn't an XY problem. Is your question actually why it is unstable, or are you really implying that you believe this is the cause of a performance problem?

mcabbott commented 6 months ago

Same question on Discourse here; please link things so as not to waste time on duplicates.

Here's a quick example to show some design differences between ComponentArrays and Optimisers.destructure:

julia> using ComponentArrays, Optimisers

julia> arr = [1.0, 2.0];

julia> v, re = Optimisers.destructure((one=arr, two=[3f0], three=arr))  # this notices & preserves x1 === x3
([1.0, 2.0, 3.0], Restructure(NamedTuple, ..., 3))

julia> v .= 99;  # this does not mutate arr, v is a copy

julia> nt = re([10, 20, 30.0])
(one = [10.0, 20.0], two = Float32[30.0], three = [10.0, 20.0])

julia> nt.one === nt.three  # identity is restored
true

julia> ca = ComponentArray(one=arr, two=[3f0], three=arr)  # this ignores the identity
ComponentVector{Float64}(one = [1.0, 2.0], two = [3.0], three = [1.0, 2.0])

julia> getfield(ca, :data)
5-element Vector{Float64}:
 1.0
 2.0
 3.0
 1.0
 2.0

julia> ca.two  # type has been promoted on construction
1-element view(::Vector{Float64}, 3:3) with eltype Float64:
 3.0

julia> ca.three .= 99;  # structured form is a view of flat form

julia> ca
ComponentVector{Float64}(one = [1.0, 2.0], two = [3.0], three = [99.0, 99.0])

irisallevi commented 6 months ago

Thank you @mcabbott, and sorry for not linking. These component arrays seem very nice, especially as you can easily access them and (apparently) mutate them in place.

> Having said all that, I'm not 100% sure this isn't an XY problem. Is your question actually why it is unstable, or are you really implying that you believe this is the cause of a performance problem?

Both, actually. Since the code is quite simple for now, I'd like to keep as much of it as I can under control, so I'd like to understand what is going on.