Closed: zuhengxu closed 1 year ago
The issue is not the struct but the (type unstable?) dictionary lookup. See the backtrace for more information.
In any case, duplicate of https://github.com/EnzymeAD/Enzyme.jl/issues/416
Even if we remove the caching by overloading some of the Optimisers.jl functionality, it still fails, though for seemingly different reasons:
julia> using Enzyme, Functors, Optimisers
julia> function Optimisers._flatten(x)
Optimisers.isnumeric(x) && return vcat(Optimisers._vec(x)), 0, length(x) # trivial case
arrays = AbstractVector[]
len = Ref(0)
off = fmap(x; exclude = Optimisers.isnumeric, walk = Optimisers._TrainableStructWalk(), cache = nothing) do y
push!(arrays, Optimisers._vec(y))
o = len[]
len[] = o + length(y)
o
end
isempty(arrays) && return Bool[], off, 0
reduce(vcat, arrays), off, len[]
end
julia> function Optimisers._rebuild(x, off, flat::AbstractVector, len = length(flat); walk = Optimisers._Trainable_biwalk(), kw...)
len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
fmap(x, off; exclude = Optimisers.isnumeric, walk, cache = nothing, kw...) do y, o
Optimisers._getat(y, o, flat)
end
end
julia> struct MyShift{T}
a::T
end
julia> Functors.@functor MyShift
julia> (s::MyShift)(x) = x .+ s.a
julia> s = MyShift(ones(2))
MyShift{Vector{Float64}}([1.0, 1.0])
julia> x = randn(2)
2-element Vector{Float64}:
-0.19679963521399926
1.295242450296637
julia> # `destructure` collects all the trainable parameters in a vector, and returns this along with a function to re-build a similar structure from the vector
ps, restructure = Optimisers.destructure(s)
([1.0, 1.0], Restructure(MyShift, ..., 2))
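julia> # Quick illustration (not from the original session): the returned `restructure`
julia> # rebuilds the struct from a flat vector, e.g. here it reproduces `s`:
julia> restructure(ps)
MyShift{Vector{Float64}}([1.0, 1.0])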
julia> func(ps_) = sum(abs2, restructure(ps_)(ones(2)))
func (generic function with 1 method)
julia> θ = randn(2)
2-element Vector{Float64}:
0.55035203569321
-0.49582453100445745
julia> ∇θ = zeros(2)
2-element Vector{Float64}:
0.0
0.0
julia> func(θ)
2.657784338114955
julia> Enzyme.API.runtimeActivity!(true)
julia> Enzyme.autodiff(Enzyme.ReverseWithPrimal, func, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
warning: didn't implement memmove, using memcpy as fallback which can result in errors
((nothing,), 2.657784338114955)
What do you mean it still fails? The code seems to run to completion. Is the gradient wrong?
@torfjelde what do you mean by failed, it looks like it ran?
Edit: lol jinx on response race condition @vchuravy
Haha sorry, yeah the gradient is incorrect. Should be ones(2), no?
In reverse mode, Duplicated means the derivative is updated in place. Did you check the value of ∇θ after the run?
The 2.67 number is the result of the function func (since you requested ReverseWithPrimal)
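For reference, a minimal sketch of that pattern on a toy function (not the code from this issue): the gradient lands in the shadow vector passed to Duplicated.
julia> g = zeros(2);
julia> Enzyme.autodiff(Enzyme.Reverse, x -> sum(abs2, x), Enzyme.Active, Enzyme.Duplicated([3.0, 4.0], g));
julia> g   # gradient of sum(abs2, x) is 2x, accumulated into the shadow in place
2-element Vector{Float64}:
 6.0
 8.0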
Yep, for a separate run:
julia> θ = randn(2)
2-element Vector{Float64}:
1.3248424332978874
-0.4331324303513132
... # running the gradient function
julia> ∇θ
2-element Vector{Float64}:
4.649684866595775
1.1337351392973736
This is correct, because the gradient should be [2(θ[1]+1), 2(θ[2]+1)]
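A quick sanity check against the θ printed above (the expression is just the analytic gradient, so its output should match ∇θ):
julia> 2 .* (θ .+ 1)
2-element Vector{Float64}:
 4.649684866595775
 1.1337351392973736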
Aaah perfect; sorry, didn't think too much about what was going on but just gave a quick attempt at a fix. Thank you so much for the help @wsmoses and @vchuravy
Example:
error message:
Looks like Enzyme might have trouble dealing with closures? I'm not certain but this could be weakly related to #700.
System info: