FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/
Other
1.48k stars, 213 forks

Error from gradient of `vcat(x...)` - appeared in v0.6.45 #1417

Open danielalcalde opened 1 year ago

danielalcalde commented 1 year ago

In https://github.com/GTorlai/PastaQ.jl/issues/300#issuecomment-1525720876 a bug was reported that I traced to a problem in the differentiation of vcat. I created a minimal example to reproduce the error:

using Zygote
function loss(theta)
    x1 = vcat([theta], 5)
    x2 = vcat(x1...)
    return x2[1]
end
println(gradient(loss, 1))

ERROR: LoadError: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(::Tuple{Float64})
Closest candidates are:
  (::ChainRulesCore.ProjectTo{AbstractArray})(::Union{LinearAlgebra.Adjoint{T, var"#s886"}, LinearAlgebra.Transpose{T, var"#s886"}} where {T, var"#s886"<:(AbstractVector)}) at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:247
  (::ChainRulesCore.ProjectTo{AbstractArray})(::AbstractArray{<:ChainRulesCore.AbstractZero}) at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:244
  (::ChainRulesCore.ProjectTo{AbstractArray})(::AbstractArray{S, M}) where {S, M} at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:219
  ...
Stacktrace:
  [1] (::ChainRules.var"#1413#1419"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}})()
    @ ChainRules ~/.julia/packages/ChainRules/aKxNz/src/rulesets/Base/array.jl:310
  [2] unthunk
    @ ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:204 [inlined]
  [3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1413#1419"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}}}, ChainRules.var"#1412#1418"{Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:237
  [4] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:110 [inlined]
  [5] map
    @ ./tuple.jl:223 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:111 [inlined]
  [7] (::Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}})(dy::Tuple{Float64, Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:211
  [8] Pullback
    @ ~/workprojects/education/julia/pastaq/break.jl:3 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(loss), Int64}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}, Tuple{}}, Val{1}}}}}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}}, Zygote.ZBack{ChainRules.var"#vect_pullback#1369"{1, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}}}}})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(loss), Int64}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}, Tuple{}}, Val{1}}}}}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}}, Zygote.ZBack{ChainRules.var"#vect_pullback#1369"{1, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}}}}}})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [11] gradient(::Function, ::Int64, ::Vararg{Int64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97

This code worked in Zygote@0.6.44 but fails in every version from Zygote@0.6.45 through Zygote@0.6.60.

ToucheSir commented 1 year ago

0.6.45 is when we switched over to ChainRules for the cat functions: https://github.com/FluxML/Zygote.jl/pull/1277. TBD whether ChainRules projection isn't being flexible enough or if Zygote is passing invalid inputs to it.

mcabbott commented 1 year ago

I think this is https://github.com/FluxML/Zygote.jl/issues/599: `x1...` makes a Tuple, but the gradient of x1 ought to be an array. It's been worked around in some cases (e.g. with _project) but not all.
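
To illustrate the splat point without involving Zygote or ChainRules, here is a minimal plain-Julia sketch. The helper `vcat_with_pullback` is hypothetical and written only for this illustration; it is not how either package implements its rule. It shows that a rule which only sees splatted scalars can hand back one value per argument, so the caller ends up with a Tuple where the original variable was an array:

```julia
# Hypothetical hand-written pullback for `vcat` applied to scalars only.
# It mimics what a rule sees after a splat: one value per argument.
function vcat_with_pullback(xs::Number...)
    y = vcat(xs...)
    # One cotangent per argument, so the result is a Tuple, not an array.
    back(dy) = ntuple(i -> dy[i], length(xs))
    return y, back
end

x = [2.0, 3.0]
y, back = vcat_with_pullback(x...)    # the splat turns x into two scalar arguments

dxs = back([1.0, 0.0])                # cotangent for selecting y[1]
println(dxs)                          # a Tuple with one entry per splatted argument

# Only the caller knows that the arguments came from the array `x`,
# so restoring the array shape has to happen outside the rule.
dx = reshape(collect(dxs), size(x))
println(dx)
```

The last step plays the role that a projection back onto the original array would play; where exactly that conversion should live in Zygote is the open question in this thread.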

theabhirath commented 1 year ago

I am running into this issue while trying to implement DenseNet. Since vcat is one of the only non-mutating ways to append elements to arrays, this is a blocker for that. Is there a workaround or a fix for this? I confirmed that it was working on 0.6.44, but the error appears on versions higher than that.

mcabbott commented 1 year ago

Can you simplify the example, or make other ones? Perhaps the splat isn't the right diagnosis, as things like this seem fine:

julia> gradient([2, 3.0]) do x
         vcat(x...)[1]
       end
([1.0, 0.0],)

julia> gradient([2, 3.0]) do x
         vcat(x..., 4)[1]
       end
([1.0, 0.0],)

No, I think those are getting fixed up afterwards: gradient applies a final _project to its answer. Calling pullback directly avoids that step, and here the splat clearly makes the tuple:

julia> pullback([2, 3.0]) do x
         vcat(x...)[1]
       end[2](1.0)
((1.0, 0.0),)

The rrule involved cannot fix this, because it sees and returns individual arguments:

julia> using ChainRules, ChainRulesCore

julia> rrule(vcat, 3/4, 4/5)[2]([6.6, 7.7])
(NoTangent(), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)))

julia> unthunk.(ans)
(NoTangent(), 6.6, 7.7)

theabhirath commented 1 year ago

Differentiating through this is what causes the error for me:

function (m::DenseBlock)(x)
    input = [x]
    for layer in m.layers
        x = layer(input)
        input = vcat(input, [x])
    end
    return cat_channels(input...)
end

This is the only place vcat is used in my code. The layers are mostly simple Chains with Convs and BatchNorms, in case that is useful information. It does seem to suggest that the splat is not the only issue.

ToucheSir commented 1 year ago

I think we should be implementing DenseNet differently anyhow (toss up a PR if you want some ideas there), so this shouldn't block Metalhead at least.