FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

`destructure` doesn't work on Dictionaries #154

Status: Open. Opened by mcabbott 1 year ago.

mcabbott commented 1 year ago

`destructure` uses `map`, I think because it was written before support for `Dict` was added elsewhere, hence this fails:

julia> d = Dict(
           :a => Dict(
               :b => Dict(
                   :c => 1,
                   :d => 2,
               ),
               :e => 3,
           ), 
           :f => 4,
       )
Dict{Symbol, Any} with 2 entries:
  :a => Dict{Symbol, Any}(:b=>Dict(:d=>2, :c=>1), :e=>3)
  :f => 4

julia> destructure(d)
ERROR: map is not defined on dictionaries
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] map(f::Function, ::Dict{Symbol, Any})
    @ Base ./abstractarray.jl:3303
  [3] (::Optimisers._TrainableStructWalk)(recurse::Function, x::Dict{Symbol, Any})
    @ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/destructure.jl:81
  [4] (::Functors.ExcludeWalk{…})(::Function, ::Dict{…})
    @ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:106
  [5] (::Functors.CachedWalk{…})(::Functors.var"#recurse#19"{…}, ::Dict{…})
    @ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:146 [inlined]
  [6] execute(::Functors.CachedWalk{Functors.ExcludeWalk{…}, Functors.NoKeyword}, ::Dict{Symbol, Any})
    @ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:38
  [7] fmap(::Function, ::Dict{…}; exclude::Function, walk::Optimisers._TrainableStructWalk, cache::IdDict{…}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/rlD70/src/maps.jl:11
  [8] _flatten(x::Dict{Symbol, Any})
    @ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/destructure.jl:69 [inlined]
  [9] destructure(x::Dict{Symbol, Any})
    @ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/destructure.jl:30
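
The error itself comes from Base: `map` has a method for dictionaries that only throws, which is frame [2] above. A minimal plain-Julia sketch of the difference follows. `mapvalues_sketch` is a hypothetical stand-in for a Dict-aware helper; it is not Optimisers.jl's own code.

```julia
# Base.map has an AbstractDict method that only throws an ErrorException,
# which is what destructure hit in the stacktrace above.
d = Dict(:a => [1.0], :b => [2.0])

map_failed = try
    map(identity, d)
    false
catch err
    err isa ErrorException   # "map is not defined on dictionaries"
end

# Hypothetical helper (NOT the Optimisers.jl API): apply f to every value
# while keeping the keys, and return a new Dict instead of throwing.
mapvalues_sketch(f, dict::AbstractDict) = Dict(k => f(v) for (k, v) in dict)

scaled = mapvalues_sketch(v -> 2 .* v, d)
```

Mapping over the values while carrying the keys along is what lets the result stay a `Dict`.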

CarloLucibello commented 5 months ago

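With #174, and in particular its use of `mapvalue` instead of `map`, the situation has improved, although it is not fixed yet: `destructure(d)` now succeeds, but `re(ps)` rebuilds a `Vector` of `Pair`s rather than a `Dict`.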

julia> d = Dict(
           :a => Dict(
               :b => Dict(
                   :c => [1.],
                   :d => [2.],
               ),
               :e => 3.,
           ),
           :f => [4.],
       )
Dict{Symbol, Any} with 2 entries:
  :a => Dict{Symbol, Any}(:b=>Dict(:d=>[2.0], :c=>[1.0]), :e=>3.0)
  :f => [4.0]

julia> ps, re = destructure(d)
([2.0, 1.0, 4.0], Restructure(Dict, ..., 3))

julia> re(ps)
2-element Vector{Pair{Symbol}}:
 :a => Pair{Symbol}[:b => [:d => [2.0], :c => [1.0]], :e => 3.0]
 :f => [4.0]
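
For comparison, here is a plain-Julia sketch of the round trip one would expect: flatten the parameters, then rebuild a `Dict` rather than the `Vector{Pair{Symbol}}` shown above. The names `flatten_sketch` and `rebuild_sketch` are hypothetical, and this is not how Optimisers.jl implements `destructure` or `Restructure`. It assumes only numeric arrays count as parameters, which matches the output above, where `:e => 3.0` is absent from `ps`.

```julia
# Hypothetical sketch, NOT the Optimisers.jl implementation.

# Walk the structure in key-iteration order, appending every numeric array.
# Non-array leaves (like the scalar 3.0) are skipped.
function flatten_sketch(x, out::Vector{Float64} = Float64[])
    if x isa AbstractDict
        for (_, v) in x
            flatten_sketch(v, out)
        end
    elseif x isa AbstractArray{<:Number}
        append!(out, vec(x))
    end
    return out
end

# Rebuild using `template` for the shape; `offset` counts how much of `flat`
# has been consumed. The walk order matches flatten_sketch, so slices line up.
function rebuild_sketch(template, flat::AbstractVector, offset::Ref{Int} = Ref(0))
    if template isa AbstractDict
        rebuilt = Dict{keytype(template), Any}()   # keep the Dict container
        for (k, v) in template
            rebuilt[k] = rebuild_sketch(v, flat, offset)
        end
        return rebuilt
    elseif template isa AbstractArray{<:Number}
        n = length(template)
        chunk = flat[offset[]+1:offset[]+n]
        offset[] += n
        return reshape(chunk, size(template))
    else
        return template   # non-array leaves are copied back unchanged
    end
end

d = Dict(
    :a => Dict(:b => Dict(:c => [1.0], :d => [2.0]), :e => 3.0),
    :f => [4.0],
)

ps = flatten_sketch(d)
d2 = rebuild_sketch(d, 2 .* ps)   # every array doubled, nesting preserved
```

Because the rebuild walks the original `Dict` as a template, it can return a `Dict` at every level and keep the non-array leaf `:e`. Returning a `Vector` of `Pair`s, as `re(ps)` does above, loses that container type.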