JuliaDiff / ChainRules.jl

Forward- and reverse-mode automatic differentiation primitives for Julia Base + StdLibs

ambiguous rrule for sum of AbstractArray{Bool} #765

Closed: nomadbl closed this issue 8 months ago

nomadbl commented 10 months ago

This might be considered an edge case because of the Boolean type, but I'm leaving it here as an FYI in case it turns out to matter. This used to work under Julia 1.9.0 (the gradient is nothing), but under Julia 1.10 it throws an ambiguous-method MethodError.

MWE:

julia> using Flux, StatsBase
       loss(x) = mean(Flux.onecold(Flux.onehotbatch(x, 1:10)) .== rand(1:10, 10))
       x = rand(1:10, 10)
       Flux.gradient(loss, x)
ERROR: MethodError: kwcall(::@NamedTuple{dims::Colon}, ::typeof(ChainRulesCore.rrule), ::typeof(sum), ::BitVector) is ambiguous.

Candidates:
  kwcall(::NamedTuple, ::typeof(ChainRulesCore.rrule), ::typeof(sum), x::AbstractArray)
    @ ChainRules ~/.julia/packages/ChainRules/pEOSw/src/rulesets/Base/mapreduce.jl:28
  kwcall(kwargs, ::typeof(ChainRulesCore.rrule), ::typeof(sum), var"517"::AbstractArray{Bool})
    @ ChainRules none:0

Possible fix, define
  kwcall(::NamedTuple, ::typeof(ChainRulesCore.rrule), ::typeof(sum), ::AbstractArray{Bool})

Stacktrace:
  [1] rrule(::typeof(mean), x::BitVector; dims::Function)
    @ ChainRules ~/.julia/packages/ChainRules/pEOSw/src/rulesets/Statistics/statistics.jl:12
  [2] rrule(::typeof(mean), x::BitVector)
    @ ChainRules ~/.julia/packages/ChainRules/pEOSw/src/rulesets/Statistics/statistics.jl:11
  [3] rrule(::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::BitVector)
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/rules.jl:134
  [4] chain_rrule
    @ ~/.julia/packages/Zygote/WOy6z/src/compiler/chainrules.jl:223 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0 [inlined]
  [6] _pullback(ctx::Zygote.Context{false}, f::typeof(mean), args::BitVector)
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:81
  [7] loss
    @ ./REPL[6]:1 [inlined]
  [8] _pullback(ctx::Zygote.Context{false}, f::typeof(loss), args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
  [9] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:44
 [10] pullback(f::Function, args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:42
 [11] gradient(::Function, ::Vector{Int64}, ::Vararg{Vector{Int64}})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:96
 [12] top-level scope
    @ REPL[12]:1
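The ambiguity reported above can be reproduced without ChainRules. The following is a minimal sketch, assuming Julia 1.9 or later, where a keyword call such as f(x; dims=:) lowers to Core.kwcall(kwargs, f, x). The function h is hypothetical and only mimics the two candidate signatures from the error message: one has a typed NamedTuple keyword argument and a broad AbstractArray, the other has an untyped kwargs argument and the narrower AbstractArray{Bool}. Since neither method is more specific than the other, a call on a BitVector is ambiguous. Defining the method Julia suggests under "Possible fix" covers the intersection and resolves it; this is not necessarily how ChainRules itself was patched.

```julia
# Hypothetical stand-in for rrule(::typeof(sum), x::AbstractArray; dims=:).
# Declaring a keyword argument makes Julia define
# Core.kwcall(::NamedTuple, ::typeof(h), ::AbstractArray).
h(x::AbstractArray; dims = :) = :general

# Hypothetical stand-in for the macro-generated method: untyped `kwargs`
# first argument, but a narrower array type.
Core.kwcall(kwargs, ::typeof(h), x::AbstractArray{Bool}) = :bool

# BitVector <: AbstractArray{Bool}, so both methods match, and neither is
# more specific than the other: dispatch throws an ambiguity MethodError.
err = try
    h(trues(3); dims = :)
catch e
    e
end
println(err isa MethodError)  # true

# The fix suggested by Julia: a method covering exactly the intersection
# (typed NamedTuple *and* AbstractArray{Bool}), more specific than both.
Core.kwcall(kwargs::NamedTuple, ::typeof(h), x::AbstractArray{Bool}) = :bool

h(trues(3); dims = :)   # now returns :bool
h([1, 2]; dims = 1)     # non-Bool arrays still return :general
```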

System info

julia> versioninfo()
Julia Version 1.10.0
Commit c67ed11612* (2023-12-26 21:52 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i5-1035G1 CPU @ 1.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, icelake-client)
  Threads: 1 on 8 virtual cores
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64

(@v1.10) pkg> status Flux
Status `~/.julia/environments/v1.10/Project.toml`
  [587475ba] Flux v0.14.7

Side note: this Julia 1.10 is my own build, which differs from the official 1.10 release by two lines in LinearAlgebra; the change has nothing to do with Flux, sum, or mean.

oxinabox commented 10 months ago

I suspect this applies more broadly to anything declared with @non_differentiable, which would point to a change needed in that macro in ChainRulesCore.

nomadbl commented 8 months ago

The MWE is now working as expected. Thanks for the fix!