Open mzgubic opened 3 years ago
Dug a bit further into this.
The immediate error is solved by adding a method for
grad_mut(x::AnalyticWeights) = Ref{Any}(nt_nothing(x))
which previously dispatched to the AbstractArray method.
After adding this the error is
julia> gradient(values -> AnalyticWeights(values).sum, [1.0, 2.0, 3.0])
ERROR: MethodError: no method matching (::Zygote.var"#AnalyticWeights_pullback#117")(::Base.RefValue{Any})
Closest candidates are:
(::Zygote.var"#AnalyticWeights_pullback#117")(::AbstractArray) at /Users/mzgubic/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/lib/lib.jl:12
(::Zygote.var"#AnalyticWeights_pullback#117")(::ChainRulesCore.Tangent) at /Users/mzgubic/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/lib/lib.jl:13
(::Zygote.var"#AnalyticWeights_pullback#117")(::ChainRulesCore.AbstractThunk) at /Users/mzgubic/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/lib/lib.jl:14
Stacktrace:
[1] (::Zygote.ZBack{Zygote.var"#AnalyticWeights_pullback#117"})(dy::Base.RefValue{Any})
@ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/chainrules.jl:141
[2] Pullback
@ ./REPL[4]:1 [inlined]
[3] (::typeof(∂(#5)))(Δ::Float64)
@ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
[4] (::Zygote.var"#50#51"{typeof(∂(#5))})(Δ::Float64)
@ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface.jl:41
[5] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface.jl:76
[6] top-level scope
@ REPL[4]:1
which means it looks like the Ref managed to propagate to the chain rule.
It seemed related to https://github.com/FluxML/Zygote.jl/issues/685; maybe the problem is that AnalyticWeights is mutable? OK, let me check by defining my own custom type:
julia> mutable struct MyWeights
           values
           sum
       end
julia> MyWeights(values) = MyWeights(values, sum(values))
But that seems to work fine:
julia> gradient(values -> MyWeights(values).sum, [1.0, 2.0, 3.0])
(Fill(1.0, 3),)
I think we actually need to accept Ref in the chainrules conversion functions.
Zygote seems to insert them sometimes.
I think it is a vestigial organ left-over from when Zygote supported mutation.
What exactly is the purpose of grad_mut? Can you please explain (or add a comment in the source code)?
I have a very similar issue with a custom struct that is a subtype of AbstractArray.
I think it is a vestigial organ left-over from when Zygote supported mutation.
I am not sure anyone really understands it anymore
It is unfortunately a requirement for diffing through code which uses mutable structs (note, not mutable arrays, which are unsupported). It consists of a more primitive version of the more robust and documented mechanism in https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626, along with a cache to store accumulated gradients of mutable structs. I would love to get rid of it, but a lot of code in the wild relies on it working. Which is unfortunate, because as this issue demonstrates it has been a very leaky and bug-prone abstraction.
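To make the "cache to store accumulated gradients of mutable structs" idea concrete, here is a minimal sketch in plain Julia. It is not Zygote's implementation: `Point`, `GRAD_CACHE`, `grad_slot!` and `accumulate_field!` are hypothetical names made up for illustration. The only point is that every mutable object gets one shared `Ref` slot, keyed by identity, into which field gradients are accumulated.

```julia
# Hypothetical names throughout; a sketch of the idea, not Zygote's source.
mutable struct Point
    x::Float64
    y::Float64
end

# One shared gradient slot per mutable object, keyed by object identity.
const GRAD_CACHE = IdDict{Any,Any}()

# Return the Ref holding the accumulated field gradients for `obj`,
# creating a named tuple of `nothing`s on first use.
function grad_slot!(obj)
    get!(GRAD_CACHE, obj) do
        names = fieldnames(typeof(obj))
        Ref{Any}(NamedTuple{names}(map(_ -> nothing, names)))
    end
end

# Add `Δ` to the gradient stored for field `name` of `obj`.
function accumulate_field!(obj, name::Symbol, Δ)
    slot = grad_slot!(obj)
    old = getfield(slot[], name)
    new = old === nothing ? Δ : old + Δ
    slot[] = merge(slot[], NamedTuple{(name,)}((new,)))
    return slot
end

p = Point(1.0, 2.0)
accumulate_field!(p, :x, 1.0)
accumulate_field!(p, :x, 0.5)
result = grad_slot!(p)[]   # named tuple with the accumulated x gradient
```

Because the slot is read with `slot[]`, anything that hands back a container other than a `Ref` (or a one-element array) in its place will fail at that read.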
Ahh, I see. Thanks for the replies!
However, the function grad_mut is still part of the setindex! adjoint method here. In my case (mutable struct, subtype of AbstractArray), g = grad_mut returns an empty array, which is the source of the error, because there is nothing for g[] to return. So I think my fix would be to add a dispatch for grad_mut for my custom struct, but I really have no idea what I should return to get it working...
This is only the case if we subtype AbstractArray; see this little MWE with two structs, one a subtype of AbstractArray, the other not:
using Zygote

mutable struct STR1
    a::AbstractArray{Real, 1}
    b::AbstractArray{Real, 1}
end
mutable struct STR2 <: AbstractArray{Real, 1}
    a::AbstractArray{Real, 1}
    b::AbstractArray{Real, 1}
end
# some methods so that AbstractArrays are displayed correctly
Base.size(s::STR2) = (length(s.a)+length(s.b),)
Base.getindex(s::STR2, ind::CartesianIndex{A}) where {A} = Base.getindex(s, ind.I[1])
Base.getindex(s::STR2, ind) = (ind <= length(s.a) ? s.a[ind] : s.b[ind-length(s.a)])
str1 = STR1([1.0], [2.0])
str2 = STR2([1.0], [2.0])
Zygote.jacobian(x -> str1.a .- x, 1.0) # works nice!
Zygote.jacobian(x -> str2.a .- x, 1.0) # throws error, access 0-element Vector
Resulting in:
ERROR: BoundsError: attempt to access 0-element Vector{Any} at index []
Stacktrace:
[1] throw_boundserror(A::Vector{Any}, I::Tuple{})
@ Base .\abstractarray.jl:744
[2] checkbounds
@ .\abstractarray.jl:709 [inlined]
[3] _getindex
@ .\abstractarray.jl:1328 [inlined]
[4] getindex(::Vector{Any})
@ Base .\abstractarray.jl:1296
[5] (::Zygote.var"#back#302"{:a, Zygote.Context{false}, STR2, Float64})(Δ::Float64)
@ Zygote C:\Users\...\.julia\packages\Zygote\4SSHS\src\lib\lib.jl:237
[6] #2184#back
@ C:\Users\...\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:71 [inlined]
Because of literal_getfield at this line. It's the very same as for setindex!: the result of grad_mut is empty, but not treated as such.
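The "at index []" in the BoundsError comes from indexing with no indices. A small plain-Julia illustration, where `g_ref` and `g_vec` are hypothetical stand-ins for the two kinds of value that could end up in the gradient slot:

```julia
# Hypothetical stand-ins for what the gradient slot could hold.
g_ref = Ref{Any}((a = nothing, b = nothing))  # a Ref: `g_ref[]` dereferences it
g_vec = Any[]                                 # an empty vector: `g_vec[]` has nothing to return

from_ref = g_ref[]   # the stored named tuple

# Zero-index getindex on an empty Vector{Any} throws, as in the stack trace.
err = try
    g_vec[]
    nothing
catch e
    e
end
```

Here `err` ends up holding a `BoundsError`, the same kind of error as in the stack trace, while the `Ref` version returns its contents.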
Ah yes, https://github.com/FluxML/Zygote.jl/blob/cf7f7d08705d2787fa31bcf45bcca5447fd9a9a7/src/lib/base.jl#L3 is likely way too broad of a dispatch. I'm not sure if we can change it without breaking anything, but you could test to see if limiting it to Vector helps with your use case.
I tried the following line (so I don't destroy anything):
grad_mut(s::STR2) = invoke(grad_mut, Tuple{Any}, s)
Indeed the error is gone, looks good! Thank you very much!
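For readers unfamiliar with `invoke`, here is a sketch of why that one-liner changes the outcome. It uses hypothetical stand-ins only (`slot` plays the role grad_mut plays in this thread, and `Wrapped` plays the role of STR2) and assumes, as described above, that a broad array method returns an empty vector while the generic method returns a `Ref`.

```julia
# Hypothetical stand-ins; `slot` is not a Zygote function.
slot(x) = Ref{Any}(nothing)        # generic fallback: a Ref, readable with []
slot(x::AbstractVector) = Any[]    # broad array method: an empty vector

mutable struct Wrapped <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Int) = w.data[i]

w = Wrapped([1.0, 2.0])

# Before the override, the more specific AbstractVector method is chosen.
before = slot(w)

# Forward just this type to the generic method, bypassing the array method.
slot(w::Wrapped) = invoke(slot, Tuple{Any}, w)

after = slot(w)
```

After the override, `after[]` can be read without error, while every other `AbstractVector` still goes through the array method, which is why this is less invasive than changing the broad dispatch itself.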
I've encountered this in the wild, but here is a minimal breaking example:
(It's possible to get around it by using
w -> sum(w)
(by adding an rrule for a constructor, see below), but the original example is
gradient(std, rand(3), AnalyticWeights([1.0, 2.0, 3.0]))
which calls the weights under the hood.)