FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

Restructure is not type stable but could be made stable? #177

Closed Red-Portal closed 2 months ago

Red-Portal commented 2 months ago

Hi, it seems that calling restructure is not type stable by default. This is currently causing issues with Enzyme.jl (see this issue). Here is an MWE to illustrate the point:

using Cthulhu, LinearAlgebra, Optimisers, Functors

struct Model{A,B}
    a::A
    b::B
end

Functors.@functor Model

m = Model(randn(10), LowerTriangular(Matrix(I, 10, 10)))

params, re = Optimisers.destructure(m)

@code_warntype re(params)

This returns:

MethodInstance for (::Optimisers.Restructure{Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}, @NamedTuple{a::Int64, b::Tuple{}}})(::Vector{Float64})
  from (re::Optimisers.Restructure)(flat::AbstractVector) @ Optimisers ~/.julia/packages/Optimisers/yDIWk/src/destructure.jl:59
Arguments
  re::Optimisers.Restructure{Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}, @NamedTuple{a::Int64, b::Tuple{}}}
  flat::Vector{Float64}
Body::Model
1 ─ %1 = Base.getproperty(re, :model)::Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}
│   %2 = Base.getproperty(re, :offsets)::@NamedTuple{a::Int64, b::Tuple{}}
│   %3 = Base.getproperty(re, :length)::Int64
│   %4 = Optimisers._rebuild(%1, %2, flat, %3)::Model
└──      return %4

where the return type Model is not concrete: its type parameters are not inferred.

This can be solved in a brute-force manner by changing the call definition for re::Restructure (defined here) to:

   (re::Restructure)(flat::AbstractVector)::typeof(re.model) = _rebuild(re.model, re.offsets, flat, re.length)

where we are informing the compiler that the return type will be the same as that of re.model. I think this is safe to assume, and it immediately resolves the instability. Any thoughts on this?
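
As a quick check (a sketch on top of the MWE above, not part of Optimisers itself), Test.@inferred errors while inference only sees the abstract Model, and should pass once the return type is pinned to the concrete model type:

using Test

# Errors with the stock definition (inferred type is the abstract `Model`);
# should pass once the return type is annotated as typeof(re.model).
@inferred re(params)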

ToucheSir commented 2 months ago

Can you elaborate on why you want re(params) to be type stable? If it's to ensure that subsequent code is type stable, then re(params)::typeof(m) in user code might work better. If it's to ensure that Restructure is internally type stable, a type assertion won't be enough.
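
For example, the user-side assertion could look roughly like this (the objective function is only a placeholder for illustration):

# Placeholder objective; the point is the type assertion at the call site.
function objective(flat, re, m)
    model = re(flat)::typeof(m)   # concrete from here on, even if re itself is unstable
    return sum(abs2, model.a)
end

objective(params, re, m)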

Red-Portal commented 2 months ago

Hi @ToucheSir ! The first reason is that Enzyme fails due to the instability. The second reason is that I find it unusual for this to be type unstable. Restructure is fully aware of the type we expect, so I am actually surprised that the current implementation is not type stable out of the box.

ToucheSir commented 2 months ago

Restructure is fully aware of the type we expect

Not quite. You are allowed to pass an array of a different eltype to Restructure, which might give you a different model return type.
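
For concreteness, a hedged check of this point using the MWE above, assuming (as in the current destructure) that the rebuilt arrays take their eltype from the flat vector:

# Reconstructing from a Float32 flat vector changes the parameter arrays' eltype,
# so typeof(re(params32)) need not equal typeof(m).
params32 = Float32.(params)
m32 = re(params32)
eltype(m32.a)   # Float32 rather than Float64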

I am actually surprised that the current implementation is not type stable out of the box.

...but even if you weren't, the type instability would remain. The reason is that Functors uses an untyped IdDict internally to keep track of shared/aliased parameters: https://github.com/FluxML/Functors.jl/blob/2eddcb74f9589e61b847362c2a91d91bd90ef628/src/walks.jl#L170-L201

The fix here would be to tell fmap not to cache by passing cache = nothing. This isn't done by default because we can't assume people are using models with zero parameter sharing/tying, but perhaps the models you're working with have none. Disabling the cache would be necessary but may not be sufficient to remove type instabilities, as the Julia compiler really hates nested recursive calls like the ones Functors.jl uses.
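
Roughly, that looks like the following (whether this alone makes the whole walk inferable is not guaranteed):

using Functors

# cache = nothing skips the IdDict used to track shared/aliased arrays;
# only safe when the model has no tied parameters.
m2 = fmap(float, m; cache = nothing)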

P.S. since you're working with Enzyme and not Zygote, you may be interested in adapting the mutating (de|Re)structure implementation in https://github.com/FluxML/Optimisers.jl/pull/165 for your use case.

Red-Portal commented 2 months ago

Thanks for the tips. Let me try explicitly setting the return type of the reconstruction call as you suggested.