FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/
Other
1.48k stars 213 forks source link

error with map over NamedTuple #1072

Open · CarloLucibello opened this issue 3 years ago

CarloLucibello commented 3 years ago
julia> x = (; a=1, b=2)
(a = 1, b = 2)

julia> map(sqrt, x)
(a = 1.0, b = 1.4142135623730951)

julia> gradient(x ->  map(sqrt, x).a, x)
ERROR: MethodError: no method matching lastindex(::Nothing)
Closest candidates are:
  lastindex(::Any, ::Any) at abstractarray.jl:348
  lastindex(::Union{DataStructures.SortedDict, DataStructures.SortedMultiDict, DataStructures.SortedSet}) at /home/carlo/.julia/packages/DataStructures/nBjdy/src/tokens2.jl:19
  lastindex(::Union{ArrayInterface.BidiagonalIndex, ArrayInterface.TridiagonalIndex, ArrayInterface.BandedBlockBandedMatrixIndex, ArrayInterface.BandedMatrixIndex, ArrayInterface.BlockBandedMatrixIndex}) at /home/carlo/.julia/packages/ArrayInterface/61qJ7/src/array_index.jl:208
  ...
Stacktrace:
  [1] last(a::Nothing)
    @ Base ./abstractarray.jl:437
  [2] (::Zygote.var"#568#574")(::Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/array.jl:211
  [3] map
    @ ./tuple.jl:233 [inlined]
  [4] (::Zygote.var"#567#573"{typeof(sqrt), Tuple{Tuple{Int64, Int64}}, Tuple{Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}})(Δ::Tuple{Float64, Nothing})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/array.jl:211
  [5] (::Zygote.var"#2616#back#577"{Zygote.var"#567#573"{typeof(sqrt), Tuple{Tuple{Int64, Int64}}, Tuple{Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}}})(Δ::Tuple{Float64, Nothing})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [6] (::Zygote.var"#213#214"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2616#back#577"{Zygote.var"#567#573"{typeof(sqrt), Tuple{Tuple{Int64, Int64}}, Tuple{Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}}}})(Δ::Tuple{Float64, Nothing})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:203
  [7] (::Zygote.var"#1754#back#215"{Zygote.var"#213#214"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2616#back#577"{Zygote.var"#567#573"{typeof(sqrt), Tuple{Tuple{Int64, Int64}}, Tuple{Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}}}}})(Δ::Tuple{Float64, Nothing})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [8] Pullback
    @ ./namedtuple.jl:197 [inlined]
  [9] (::typeof(∂(map)))(Δ::NamedTuple{(:a, :b), Tuple{Float64, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[60]:1 [inlined]
 [11] (::Zygote.var"#50#51"{typeof(∂(#53))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
 [12] gradient(f::Function, args::NamedTuple{(:a, :b), Tuple{Int64, Int64}})
    @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:76
 [13] top-level scope
    @ REPL[60]:1
 [14] top-level scope
    @ ~/.julia/packages/CUDA/9T5Sq/src/initialization.jl:66
mcabbott commented 2 years ago

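This now seems to work, so maybe it should become a test. The result is what we expect. The derivative of `sqrt` at `a = 1` is `1/(2√1) = 0.5`. `b` gets `nothing` because it does not contribute to `.a`: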

julia> x = (; a=1, b=2)
(a = 1, b = 2)

julia> map(sqrt, x)
(a = 1.0, b = 1.4142135623730951)

julia> gradient(x ->  map(sqrt, x).a, x)
((a = 0.5, b = nothing),)

(@v1.9) pkg> st Zygote ChainRules
Status `~/.julia/environments/v1.9/Project.toml`
⌃ [082447d4] ChainRules v1.39.0
  [e88e6eb3] Zygote v0.6.41