FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Error with `Zygote.gradient` for `foldl, sum` #1279

Open vpuri3 opened 2 years ago

vpuri3 commented 2 years ago

MWE

using Zygote, LinearAlgebra

N = 4
u0 = rand(N)
ps = rand(N)

mats = (rand(N,N), rand(N,N),) # (A, B,)
nums = (rand(), rand(),)       # (α, β,)

loss_m = function(p)
    v = Diagonal(p) * u0
    v = Zygote.hook(Δ -> (println("Δv: ", typeof(Δ)); Δ), v)

    w = foldl((acc, op) -> op * acc, mats; init=v) # w = B * A * v
    w = Zygote.hook(Δ -> (println("Δw: ", Δ); Δ), w)

    l = sum(w)
    l = Zygote.hook(Δ -> (println("Δl: ", Δ); Δ), l)
end

println("fwd"); @time loss_m(ps) |> display
println("bwd"); @time Zygote.gradient(loss_m, ps) |> display # INCORRECT - should not vanish

loss_n = function(p)
    v = Diagonal(p) * u0
    v = Zygote.hook(Δ -> (println("Δv: ", typeof(Δ)); Δ), v)

    w = sum(a -> convert(Number, a), nums; init=zero(eltype(nums))) * v # w = (α + β) * v
    w = Zygote.hook(Δ -> (println("Δw: ", Δ); Δ), w)

    l = sum(w)
    l = Zygote.hook(Δ -> (println("Δl: ", Δ); Δ), l)
end

println("fwd"); @time loss_n(ps) |> display
println("bwd"); @time Zygote.gradient(loss_n, ps) |> display # ERRORS
julia> include("examples/ad/zy.jl")
fwd
4.339451806053281
  0.021413 seconds (44.18 k allocations: 2.637 MiB, 99.38% compilation time)
bwd
Δl: 1.0
Δw: Fill(1.0, 4)
Δv: Nothing
(nothing,)
  0.139943 seconds (444.37 k allocations: 23.545 MiB, 99.45% compilation time)
fwd
1.5660193401267022
  0.355185 seconds (1.11 M allocations: 65.174 MiB, 99.67% compilation time)
bwd
ERROR: LoadError: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:872
  iterate(::Union{LinRange, StepRangeLen}, ::Integer) at range.jl:872
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at dict.jl:712
  ...
Stacktrace:
  [1] indexed_iterate(I::Nothing, i::Int64)
    @ Base ./tuple.jl:91
  [2] chain_rrule_kw
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:229 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0 [inlined]
  [4] _pullback(::Zygote.Context, ::Base.var"#sum##kw", ::NamedTuple{(:init,), Tuple{Float64}}, ::typeof(sum), ::var"#49#54", ::Tuple{Float64, Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:9
  [5] _pullback
    @ ~/.julia/dev/PDEInterfaces/examples/ad/zy.jl:29 [inlined]
  [6] _pullback(ctx::Zygote.Context, f::var"#47#52", args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
  [7] _pullback(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:34
  [8] pullback(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:40
  [9] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:75
 [10] top-level scope
    @ ./timing.jl:242
 [11] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [12] top-level scope
    @ REPL[2]:1
 [13] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
in expression starting at /home/vedantpu/.julia/dev/PDEInterfaces/examples/ad/zy.jl:38

Ref: https://github.com/SciML/SciMLOperators.jl/pull/94
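For reference, the gradient that Zygote.gradient(loss_m, ps) ought to return can be worked out by hand. Without the hooks, l = 1ᵀ B A diag(u0) p, so ∇ₚl = u0 .* (Aᵀ Bᵀ 1), which is generically non-zero rather than nothing. The sketch below uses only the standard library and checks that formula against central finite differences on fresh random data. grad_exact and grad_fd are illustrative helper names, not part of Zygote.

```julia
using LinearAlgebra

# Same setup as loss_m in the MWE (fresh random data, hooks removed)
N = 4
u0 = rand(N)
ps = rand(N)
A, B = rand(N, N), rand(N, N)

# Forward pass: l = sum(B * A * Diagonal(p) * u0)
loss(p) = sum(B * (A * (Diagonal(p) * u0)))

# Hand-derived gradient: l = 1ᵀ B A diag(u0) p  =>  ∇ₚ l = u0 .* (Aᵀ Bᵀ 1)
grad_exact(p) = u0 .* (A' * (B' * ones(N)))

# Central finite differences as an independent check
function grad_fd(f, p; h = 1e-6)
    g = similar(p)
    for i in eachindex(p)
        e = zeros(length(p))
        e[i] = h
        g[i] = (f(p .+ e) - f(p .- e)) / (2h)
    end
    return g
end

println(grad_exact(ps))
println(grad_fd(loss, ps))
```

Both lines should print the same non-zero vector, in contrast to the (nothing,) that Zygote returns above.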

vpuri3 commented 2 years ago

The sum case works when I remove the init keyword argument, but I'm still curious why it doesn't work otherwise.
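A minimal sketch of that workaround, assuming the same setup as loss_n (w_init, w_noinit, loss_noinit and fd_grad are illustrative names). Dropping init leaves the forward value unchanged, and since l = (α + β) Σᵢ pᵢ u0ᵢ, the expected gradient is (α + β) .* u0. The snippet checks both facts in plain Julia; it does not call Zygote.

```julia
using LinearAlgebra

# Same setup as loss_n in the MWE (fresh random data)
N = 4
u0 = rand(N)
nums = (rand(), rand())  # (α, β)

# Original form: sum over the tuple with an init keyword (the form that errors under Zygote)
w_init(p) = sum(a -> convert(Number, a), nums; init = zero(eltype(nums))) * (Diagonal(p) * u0)

# Workaround: the same reduction without init
w_noinit(p) = sum(nums) * (Diagonal(p) * u0)
loss_noinit(p) = sum(w_noinit(p))

# l = (α + β) * Σᵢ pᵢ u0ᵢ, so the gradient with respect to p is (α + β) .* u0
grad_expected = sum(nums) .* u0

# Central finite differences of loss_noinit as an independent check
p = rand(N)
h = 1e-6
fd_grad = [(loss_noinit(p .+ h .* ((1:N) .== i)) - loss_noinit(p .- h .* ((1:N) .== i))) / (2h) for i in 1:N]

println(w_init(p) ≈ w_noinit(p))
println(isapprox(fd_grad, grad_expected; rtol = 1e-5))
```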

mcabbott commented 2 years ago

foldl not tracking the init keyword is https://github.com/JuliaDiff/ChainRules.jl/issues/567. You could try the changes in https://github.com/JuliaDiff/ChainRules.jl/pull/569.
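One workaround sketch (not taken from the thread) is to keep the starting vector out of the init keyword altogether: prepend it to the tuple, or replace foldl with an explicit loop. The snippet below only checks that both forms reproduce the original forward value B * A * v; whether a particular Zygote/ChainRules version then returns the correct gradient has not been verified here. apply_all is a hypothetical helper name.

```julia
using LinearAlgebra

N = 4
v = rand(N)
A, B = rand(N, N), rand(N, N)
mats = (A, B)

# Original form: v enters only through the init keyword,
# which is the argument ChainRules issue #567 reports as untracked.
w_kw = foldl((acc, op) -> op * acc, mats; init = v)

# Sketch 1: put v at the front of the collection so no init is needed.
# foldl then computes op(op(v, A), B) = B * (A * v).
w_front = foldl((acc, op) -> op * acc, (v, mats...))

# Sketch 2: an explicit, mutation-free loop instead of foldl
function apply_all(mats, v)
    acc = v
    for M in mats
        acc = M * acc
    end
    return acc
end
w_loop = apply_all(mats, v)

println(w_kw ≈ B * A * v, " ", w_front ≈ w_kw, " ", w_loop ≈ w_kw)
```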

sum not supporting init is also bad. Could you open an issue on ChainRules.jl?

julia> ChainRules.rrule(sum, [1,2,3]; init=4)
ERROR: MethodError: no method matching rrule(::typeof(sum), ::Vector{Int64}; init::Int64)

Closest candidates are:
  rrule(::typeof(sum), ::AbstractArray; dims) got unsupported keyword argument "init"
   @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/mapreduce.jl:28
  rrule(::typeof(sum), ::Any, ::AbstractArray{Bool}; sum_pullback) got unsupported keyword argument "init"
   @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:82
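To make the missing rule concrete: for y = sum(x; init = init), i.e. y = init + Σᵢ xᵢ, the pullback has to return the incoming cotangent ȳ for every xᵢ and also ȳ for init. The hypothetical sum_init_with_pullback below hand-writes that forward/pullback pair in plain Julia, mirroring the shape of an rrule without depending on ChainRulesCore. A real rule would additionally return NoTangent() for the sum function itself, and the sum(f, x; init) form that failed in the MWE would need its own method; neither is implemented here.

```julia
# Hypothetical forward/pullback pair for y = sum(x; init = init), written by hand.
# It mirrors the shape of a ChainRules rrule but does not use ChainRulesCore.
function sum_init_with_pullback(x::AbstractArray; init)
    y = sum(x; init = init)
    # y = init + sum_i x[i], so dy/dx[i] = 1 and dy/dinit = 1
    sum_init_pullback(ybar) = (fill(ybar, size(x)), ybar)  # (xbar, initbar)
    return y, sum_init_pullback
end

y, back = sum_init_with_pullback([1, 2, 3]; init = 4)
xbar, initbar = back(1.0)
println(y)        # 10
println(xbar)     # [1.0, 1.0, 1.0]
println(initbar)  # 1.0
```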