FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

MethodError: no method matching +(::Nothing, ::Array{Float32,2}) #454

Closed AzamatB closed 4 years ago

AzamatB commented 4 years ago

I'm trying to implement a pyramidal BLSTM in Flux with the Zygote backend and am getting an error I cannot make sense of. Here is an MWE:

using Flux
using Flux: flip, @functor

struct BLSTM{L}
   forward  :: L
   backward :: L
end

@functor BLSTM

(m::BLSTM)(xs) = vcat.(m.forward.(xs), flip(m.backward, xs))

restack(xs) = vcat.(xs[1:2:end], xs[2:2:end])

m = Chain(BLSTM(LSTM(3, 5), LSTM(3, 5)), restack)

xs = [rand(Float32, 3,7) for i ∈ 1:4]
θ = params(m)

gradient(θ) do
   sum(sum(m(xs)))
end

where the call to gradient at the end throws:

ERROR: MethodError: no method matching +(::Nothing, ::Array{Float32,2})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:529
  +(::Array, ::Array...) at arraymath.jl:44
  +(::SparseArrays.SparseMatrixCSC, ::Array) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/SparseArrays/src/sparsematrix.jl:1646
  ...
Stacktrace:
 [1] _broadcast_getindex_evalf at ./broadcast.jl:630 [inlined]
 [2] _broadcast_getindex at ./broadcast.jl:603 [inlined]
 [3] getindex at ./broadcast.jl:563 [inlined]
 [4] macro expansion at ./broadcast.jl:909 [inlined]
 [5] macro expansion at ./simdloop.jl:77 [inlined]
 [6] copyto! at ./broadcast.jl:908 [inlined]
 [7] copyto! at ./broadcast.jl:863 [inlined]
 [8] materialize! at ./broadcast.jl:822 [inlined]
 [9] (::Zygote.var"#996#998"{Array{Array{Float32,2},1},Tuple{StepRange{Int64,Int64}}})(::Array{Array{Float32,2},1}) at /Users/Azamat/.julia/packages/Zygote/oPQFy/src/lib/array.jl:38
 [10] (::Zygote.var"#2634#back#992"{Zygote.var"#996#998"{Array{Array{Float32,2},1},Tuple{StepRange{Int64,Int64}}}})(::Array{Array{Float32,2},1}) at /Users/Azamat/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [11] (::typeof(∂(restack)))(::FillArrays.Fill{FillArrays.Fill{Float32,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1,Tuple{Base.OneTo{Int64}}}) at ./untitled-f6a8d27d318bb16bffc7958f7a72f051:14
 [12] applychain at /Users/Azamat/.julia/packages/Flux/oX9Pi/src/layers/basic.jl:30 [inlined]
 [13] (::typeof(∂(applychain)))(::FillArrays.Fill{FillArrays.Fill{Float32,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1,Tuple{Base.OneTo{Int64}}}) at /Users/Azamat/.julia/packages/Zygote/oPQFy/src/compiler/interface2.jl:0
 [14] Chain at /Users/Azamat/.julia/packages/Flux/oX9Pi/src/layers/basic.jl:32 [inlined]
 [15] (::typeof(∂(λ)))(::FillArrays.Fill{FillArrays.Fill{Float32,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1,Tuple{Base.OneTo{Int64}}}) at /Users/Azamat/.julia/packages/Zygote/oPQFy/src/compiler/interface2.jl:0
 [16] #15 at ./untitled-f6a8d27d318bb16bffc7958f7a72f051:22 [inlined]
 [17] (::typeof(∂(#15)))(::Float32) at /Users/Azamat/.julia/packages/Zygote/oPQFy/src/compiler/interface2.jl:0
 [18] (::Zygote.var"#38#39"{Params,Zygote.Context,typeof(∂(#15))})(::Float32) at /Users/Azamat/.julia/packages/Zygote/oPQFy/src/compiler/interface.jl:101
 [19] gradient(::Function, ::Params) at /Users/Azamat/.julia/packages/Zygote/oPQFy/src/compiler/interface.jl:47
 [20] top-level scope at untitled-f6a8d27d318bb16bffc7958f7a72f051:21

Any advice on what the problem is here and how to fix it? I'm happy to prepare a PR fixing it.

AzamatB commented 4 years ago

This looks related to broadcasting, since replacing

restack(xs) = vcat.(xs[1:2:end], xs[2:2:end])

with

restack(xs) = [vcat(xs[i-1], xs[i]) for i ∈ 2:2:lastindex(xs)]

works around this issue

maartenvd commented 4 years ago

Exact same problem, same workaround:

# pepsline = peps[i,:]  # doesn't work

T = typeof(peps[1,1])
pepsline = Zygote.Buffer(T[])
for t = 1:size(peps,2)
    push!(pepsline, peps[i,t])
end

AzamatB commented 4 years ago

Here is the MRE:

julia> using Zygote

julia> vs = [rand(2,3) for _ ∈ 1:4];

julia> gradient(xs -> sum(sum(vcat.(xs[1:2:end], xs[2:2:end]))), vs)
ERROR: MethodError: no method matching +(::Nothing, ::Array{Float64,2})

Can someone advise what the problem is here?

AzamatB commented 4 years ago

Further reduced to:

julia> using Zygote

julia> gradient(xs -> sum(sum(vcat.(xs[:], xs))), [rand(2), rand(2)])
ERROR: MethodError: no method matching +(::Nothing, ::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}})

oxinabox commented 4 years ago

This looks like a bug in the custom adjoint for sum. I haven't looked at the code; this is just my reasoning.

This would not happen if Zygote were using ChainRules.jl's differential types, since it would use Zero() rather than nothing, and Zero has + defined on it.

But right now that custom adjoint needs to be written so that it knows some arguments might be nothing (which represents a strong zero).
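
As a rough illustration of what "knowing that some arguments might be nothing" could look like, here is a minimal sketch in plain Julia. The names accum_nothing and scatter_add! are hypothetical helpers made up for this example, not Zygote's actual code; the sketch only assumes that untouched gradient slots are represented by nothing.

```julia
# Hypothetical helper (not Zygote's actual code): add two gradient
# contributions, treating `nothing` as a strong zero.
accum_nothing(::Nothing, ::Nothing) = nothing
accum_nothing(::Nothing, y) = y
accum_nothing(x, ::Nothing) = x
accum_nothing(x, y) = x + y

# Hypothetical helper: accumulate the contributions `Δ` into the slots
# `inds` of a gradient buffer whose untouched slots are still `nothing`.
function scatter_add!(buffer::Vector, inds, Δ)
    for (i, d) in zip(inds, Δ)
        buffer[i] = accum_nothing(buffer[i], d)
    end
    return buffer
end

# A buffer of four slots, all initially "no gradient".
buffer = Vector{Any}(nothing, 4)

# Two rounds of contributions to slots 1 and 3 only.
scatter_add!(buffer, 1:2:4, [ones(2), ones(2)])
scatter_add!(buffer, 1:2:4, [ones(2), ones(2)])
```

Slots 2 and 4 are never written and stay nothing, while slots 1 and 3 accumulate normally. Routing the addition through a dedicated function avoids calling + on a nothing value without adding methods to Base.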

AzamatB commented 4 years ago

This issue is affecting me gravely. Any suggestions on how to fix it?

oxinabox commented 4 years ago

Type-pirate the following methods:

Base.:+(::Nothing, x) = x
Base.:+(x, ::Nothing) = x
Base.:+(::Nothing, ::Nothing) = nothing

ianfiske commented 4 years ago

I've been hitting this also, and @oxinabox's type piracy hack is working for now. Thanks.

AzamatB commented 4 years ago

Seems fixed on master, so will close.