FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

bug with reshape, splatting, and fill #866

Open CarloLucibello opened 3 years ago

CarloLucibello commented 3 years ago

While working on https://github.com/FluxML/NNlib.jl/pull/260 I hit a bug on Zygote master that I managed to reduce to the following:

julia> f(x) = reshape(x, fill(2, 2)...)
f (generic function with 1 method)

julia> gradient(x->sum(f(x)), rand(4))
ERROR: MethodError: no method matching +(::Nothing, ::Nothing)
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:538
  +(::ChainRulesCore.One, ::Any) at /home/carlo/.julia/packages/ChainRulesCore/cpHLu/src/differential_arithmetic.jl:94
  +(::ChainRulesCore.Zero, ::Any) at /home/carlo/.julia/packages/ChainRulesCore/cpHLu/src/differential_arithmetic.jl:63
  ...
Stacktrace:
 [1] sum(::Tuple{Nothing,Nothing}) at ./tuple.jl:396
 [2] (::Zygote.var"#392#394"{Tuple{Int64}})(::Tuple{Nothing,Nothing}) at /home/carlo/.julia/packages/Zygote/ggM8Z/src/lib/array.jl:72
 [3] (::Zygote.var"#2317#back#396"{Zygote.var"#392#394"{Tuple{Int64}}})(::Tuple{Nothing,Nothing}) at /home/carlo/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] f at ./REPL[39]:1 [inlined]
 [5] (::typeof(∂(f)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /home/carlo/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
 [6] #145 at ./REPL[40]:1 [inlined]
 [7] (::Zygote.var"#41#42"{typeof(∂(#145))})(::Float64) at /home/carlo/.julia/packages/Zygote/ggM8Z/src/compiler/interface.jl:40
 [8] gradient(::Function, ::Array{Float64,1}) at /home/carlo/.julia/packages/Zygote/ggM8Z/src/compiler/interface.jl:49
 [9] top-level scope at REPL[40]:1
 [10] run_repl(::REPL.AbstractREPL, ::Any) at /build/julia/src/julia-1.5.3/usr/share/julia/stdlib/v1.5/REPL/src/REPL.jl:288

Note that the equivalent version with a literal vector is handled correctly:

f(x) = reshape(x, [2, 2]...)

I could not reduce it further to determine which ingredient is the problematic one.
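A sketch of how the ingredients could be isolated, not part of the original report: each case below changes only how the `reshape` dimensions are produced (literal tuple, literal vector, `fill`, `repeat`), and the loop records which ones Zygote can differentiate instead of stopping at the first error. The case labels and the `results` dictionary are illustrative names, and which cases pass depends on the installed Zygote version.

```julia
using Zygote

# Illustrative diagnostic: each case varies one ingredient of the reproducer
# (how the dims are built, and whether they come from a tuple or a vector).
cases = [
    "tuple splat"          => x -> reshape(x, (2, 2)...),
    "vector literal splat" => x -> reshape(x, [2, 2]...),
    "fill splat"           => x -> reshape(x, fill(2, 2)...),
    "repeat splat"         => x -> reshape(x, repeat([2], 2)...),
]

# Record whether `gradient` succeeds for each case.
results = Dict{String,Bool}()
for (name, f) in cases
    try
        gradient(x -> sum(f(x)), rand(4))
        results[name] = true
    catch err
        results[name] = false
        println(name, " failed with ", typeof(err))
    end
end
println(results)
```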

DhairyaLGandhi commented 3 years ago

Likely comes from not using the fill

CarloLucibello commented 3 years ago

A different error with repeat:

julia> f(x) = reshape(x, repeat([2], 2)...)
f (generic function with 1 method)

julia> gradient(x->sum(f(x)), rand(4))
ERROR: MethodError: no method matching reshape(::Tuple{Nothing,Nothing}, ::Int64, ::Colon)
Closest candidates are:
  reshape(::FillArrays.AbstractFill, ::Union{Colon, Int64}...) at /home/carlo/.julia/packages/FillArrays/gPRiS/src/FillArrays.jl:211
  reshape(::AbstractArray, ::Union{Colon, Int64}...) at reshapedarray.jl:117
  reshape(::FillArrays.AbstractFill, ::Union{Colon, Integer}...) at /home/carlo/.julia/packages/FillArrays/gPRiS/src/FillArrays.jl:212
  ...
Stacktrace:
 [1] (::Zygote.var"#491#492"{Array{Int64,1}})(::Tuple{Nothing,Nothing}) at /home/carlo/.julia/dev/Zygote/src/lib/array.jl:146
 [2] (::Zygote.var"#2515#back#493"{Zygote.var"#491#492"{Array{Int64,1}}})(::Tuple{Nothing,Nothing}) at /home/carlo/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [3] f at ./REPL[8]:1 [inlined]
 [4] (::typeof(∂(f)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /home/carlo/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [5] #7 at ./REPL[9]:1 [inlined]
 [6] (::typeof(∂(#7)))(::Float64) at /home/carlo/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [7] (::Zygote.var"#41#42"{typeof(∂(#7))})(::Float64) at /home/carlo/.julia/dev/Zygote/src/compiler/interface.jl:40
 [8] gradient(::Function, ::Array{Float64,1}) at /home/carlo/.julia/dev/Zygote/src/compiler/interface.jl:49
 [9] top-level scope at REPL[9]:1
 [10] run_repl(::REPL.AbstractREPL, ::Any) at /build/julia/src/julia-1.5.3/usr/share/julia/stdlib/v1.5/REPL/src/REPL.jl:288

CarloLucibello commented 3 years ago

Likely comes from not using the fill

Do you know of any workaround?

DhairyaLGandhi commented 3 years ago

I would check against https://github.com/FluxML/Zygote.jl/pull/846 to see whether we hit the map adjoint.

The solution might be to retain some adjoints in Zygote rather than ChainRules for robustness. I am fairly sure we could handle these cases elegantly before.

Can you add these tests in a PR?

CarloLucibello commented 3 years ago

Unfortunately, #846 doesn't help (I tried both fill and repeat).

Can you add these tests in a PR?

A PR doing what? I wouldn't know how to fix this.

DhairyaLGandhi commented 3 years ago

One that adds the (currently failing) tests. I don't want this to be missed.

CarloLucibello commented 1 year ago

The example in the OP works now:

julia> f(x) = reshape(x, fill(2, 2)...)
f (generic function with 1 method)

julia> gradient(x->sum(f(x)), rand(4))
(Fill(1.0, 4),)

The one in https://github.com/FluxML/Zygote.jl/issues/866#issuecomment-752969193 is still failing.
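One possible workaround for the repeat reproducer, not proposed in the thread and offered only as a sketch: since the dimensions are integers and need no gradient, they can be computed inside `ChainRulesCore.ignore_derivatives`, so Zygote never runs the `repeat` adjoint that fails in the stack trace above. This assumes a ChainRulesCore release that provides `ignore_derivatives` (and that ChainRulesCore is in the active environment) and a Zygote version that honours its rule; `f_workaround` is a hypothetical name.

```julia
using Zygote
using ChainRulesCore: ignore_derivatives  # assumes a release that exports this function

# Build the shape inside `ignore_derivatives`, so Zygote treats it as a constant
# and the gradient flowing back into the splatted dims is discarded instead of
# being passed to the adjoint of `repeat`.
function f_workaround(x)
    dims = ignore_derivatives(() -> repeat([2], 2))
    return reshape(x, dims...)
end

x = rand(4)
g = gradient(x -> sum(f_workaround(x)), x)[1]
```

Because the reshape dimensions never depend on `x`, dropping their (always `nothing`) gradient should not change the result for `x`.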