Open danielalcalde opened 1 year ago
0.6.45 is when we switched over to ChainRules for the cat functions: https://github.com/FluxML/Zygote.jl/pull/1277. TBD whether ChainRules projection isn't being flexible enough or if Zygote is passing invalid inputs to it.
I think this is https://github.com/FluxML/Zygote.jl/issues/599: `x1...` makes a Tuple, but the gradient of `x1` ought to be an array. It's been worked around in some cases (e.g. with `_project`) but not all.
I am running into this issue while trying to implement DenseNet. Since `vcat` is one of the only non-mutating ways to append elements to arrays, this is a blocker for that. Is there a workaround or a fix for this? I confirmed that it was working on 0.6.44, but the error appears on later versions.
Can you simplify the example, or make other ones? Perhaps the splat isn't the right diagnosis, as things like this seem fine:
julia> gradient([2, 3.0]) do x
         vcat(x...)[1]
       end
([1.0, 0.0],)

julia> gradient([2, 3.0]) do x
         vcat(x..., 4)[1]
       end
([1.0, 0.0],)
No, I think those are getting fixed... `pullback` avoids a final `_project` on the answer of `gradient`; here the splat clearly makes the tuple:

julia> pullback([2, 3.0]) do x
         vcat(x...)[1]
       end[2](1.0)
((1.0, 0.0),)
The `rrule` involved cannot fix this; it sees and returns individual arguments:
julia> using ChainRules, ChainRulesCore
julia> rrule(vcat, 3/4, 4/5)[2]([6.6, 7.7])
(NoTangent(), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)))
julia> unthunk.(ans)
(NoTangent(), 6.6, 7.7)
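To make the point above concrete, here is a dependency-free sketch of what a pullback for `vcat` has to do: cut the incoming cotangent into one piece per positional argument. `split_vcat_cotangent` is a hypothetical helper written for illustration, not the ChainRules implementation, and it only handles scalars and vectors.

```julia
# Hypothetical helper, NOT the actual ChainRules rule: split the cotangent `dy`
# of `vcat(xs...)` into one piece per positional argument.
function split_vcat_cotangent(dy::AbstractVector, xs...)
    pieces = Any[]
    offset = 0
    for x in xs
        n = length(x)                                    # a Number has length 1
        if x isa Number
            push!(pieces, dy[offset + 1])                # scalar argument gets a scalar
        else
            push!(pieces, dy[offset + 1:offset + n])     # vector argument gets a slice
        end
        offset += n
    end
    return Tuple(pieces)
end

# Same call shape as the rrule example: two scalar arguments.
split_vcat_cotangent([6.6, 7.7], 3/4, 4/5)

# With a splat, the rule only ever sees the scalars, so the result is a Tuple
# of scalars. Turning it back into an array has to happen somewhere else.
pieces = split_vcat_cotangent([1.0, 0.0], [2, 3.0]...)
collect(pieces)
```

The rule has no way of knowing that its scalar arguments came from splatting an array, which is why the conversion back to an array would have to happen outside it.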
Differentiating through this is what causes the error for me:
function (m::DenseBlock)(x)
    input = [x]
    for layer in m.layers
        x = layer(input)
        input = vcat(input, [x])
    end
    return cat_channels(input...)
end
This is the only place `vcat` is used in my code. The `layer`s are mostly simple `Chain`s with `Conv`s and `BatchNorm`s, in case that is useful information. It does seem to suggest that the splat is not the only issue.
I think we should be implementing DenseNet differently anyhow (toss up a PR if you want some ideas there), so this shouldn't block Metalhead at least.
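The thread does not say what the alternative implementation would look like. One possible shape, shown here only as a sketch, is to concatenate each layer's output onto the running feature map directly, so that no `Vector` of arrays is built and nothing is splatted at the end. `ConcatBlock` is a hypothetical name, the layers are plain functions rather than Flux layers, and the channel dimension is assumed to be 3 (WHCN layout). Whether this differentiates cleanly on the affected Zygote versions has not been checked here.

```julia
# Hypothetical stand-in for a DenseNet block; not Metalhead's implementation.
struct ConcatBlock{T}
    layers::T
end

function (m::ConcatBlock)(x)
    for layer in m.layers
        # Each layer sees everything produced so far, already concatenated
        # along the (assumed) channel dimension 3.
        x = cat(x, layer(x); dims = 3)
    end
    return x
end

# Two toy "layers": the first keeps the channel count, the second emits one channel.
block = ConcatBlock((x -> 2 .* x, x -> x[:, :, 1:1, :] .+ 1))
y = block(ones(4, 4, 2, 1))
size(y)
```

Note that each layer here receives a single concatenated array, whereas the `DenseBlock` above passes the layer a list of arrays, so the layers themselves would need to be written accordingly.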
In https://github.com/GTorlai/PastaQ.jl/issues/300#issuecomment-1525720876 a bug was reported that I traced to a problem in the differentiation of `vcat`. I created a minimal example to reproduce the error:
This code worked in Zygote@0.6.44 but fails from Zygote@0.6.45 up to Zygote@0.6.60.
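Since 0.6.44 is reported above as the last working release, a stopgap is to hold Zygote at that version in the affected project. A minimal sketch, assuming the rest of the environment is compatible with that release:

```julia
using Pkg

# Install the last release the report describes as working, then pin it
# so that later `Pkg.update()` calls do not move it forward.
Pkg.add(name = "Zygote", version = "0.6.44")
Pkg.pin("Zygote")
```

`Pkg.free("Zygote")` removes the pin again once a fixed release is available.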