Closed cossio closed 2 years ago
Looks like your example is not reproducible due to typos. Can you edit it?
@AzamatB Sorry. I corrected the example.
One can make a little progress by defining the getproperty method, but then a new issue arises:
julia> Base.getproperty(x::NamedTuple, i::Int) = getfield(x, i)
julia> using Zygote, StructArrays
julia> struct A
           x::Float64
       end
julia> function f(X)
           S = StructArray{A}((X,))
           sum(S.x)
       end
f (generic function with 1 method)
julia> gradient(f, randn(2))
ERROR: ArgumentError: type does not have a definite number of fields
Stacktrace:
[1] fieldcount(::Any) at ./reflection.jl:705
[2] fieldnames(::DataType) at ./reflection.jl:172
[3] #s54#179(::Any, ::Any) at /home/carlo/.julia/packages/Zygote/4tJp5/src/lib/lib.jl:187
[4] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:526
[5] grad_mut(::Type{T} where T) at /home/carlo/.julia/packages/Zygote/4tJp5/src/lib/lib.jl:218
[6] grad_mut(::Zygote.Context, ::Type{T} where T) at /home/carlo/.julia/packages/Zygote/4tJp5/src/lib/lib.jl:225
[7] (::Zygote.var"#back#183"{:parameters,Zygote.Context,DataType,Core.SimpleVector})(::Tuple{Nothing}) at /home/carlo/.julia/packages/Zygote/4tJp5/src/lib/lib.jl:198
[8] (::Zygote.var"#355#back#184"{Zygote.var"#back#183"{:parameters,Zygote.Context,DataType,Core.SimpleVector}})(::Tuple{Nothing}) at /home/carlo/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[9] tuple_type_tail at ./essentials.jl:223 [inlined]
[10] (::typeof(∂(tuple_type_tail)))(::Nothing) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
[11] index_type at /home/carlo/.julia/packages/StructArrays/2PoXh/src/structarray.jl:24 [inlined]
[12] (::typeof(∂(index_type)))(::Nothing) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
[13] index_type at /home/carlo/.julia/packages/StructArrays/2PoXh/src/structarray.jl:21 [inlined]
[14] (::typeof(∂(index_type)))(::Nothing) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
[15] StructArray at /home/carlo/.julia/packages/StructArrays/2PoXh/src/structarray.jl:17 [inlined]
[16] (::typeof(∂(StructArray{A,1,NamedTuple{(:x,),Tuple{Array{Float64,1}}},I} where I)))(::NamedTuple{(:fieldarrays, :x),Tuple{Nothing,FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}}}) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
[17] StructArray at /home/carlo/.julia/packages/StructArrays/2PoXh/src/structarray.jl:33 [inlined]
[18] (::typeof(∂(StructArray{A,N,C,I} where I where C<:Union{Tuple, NamedTuple} where N)))(::NamedTuple{(:fieldarrays, :x),Tuple{Nothing,FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}}}) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
[19] f at ./REPL[12]:2 [inlined]
[20] (::typeof(∂(f)))(::Float64) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
[21] (::Zygote.var"#36#37"{typeof(∂(f))})(::Float64) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:36
[22] gradient(::Function, ::Array{Float64,1}) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:45
[23] top-level scope at REPL[13]:1
[24] eval(::Module, ::Any) at ./boot.jl:331
[25] eval_user_input(::Any, ::REPL.REPLBackend) at /home/carlo/julia/julia-1.4.0/share/julia/stdlib/v1.4/REPL/src/REPL.jl:86
[26] run_backend(::REPL.REPLBackend) at /home/carlo/.julia/packages/Revise/C272c/src/Revise.jl:1075
[27] top-level scope at none:0
Probably we need an adjoint for the StructArray constructor
I think so. The issue seems to be that, since StructArray <: AbstractArray, Zygote is using the differentiation rules for an array. Instead, we want to treat StructArray as a struct with array fields.
> Probably we need an adjoint for the StructArray constructor
I gave this a try:
@adjoint function (::Type{T})(t::Tuple) where {T<:StructArray}
    back(Δ) = (nothing, Base.tail(values(Δ))...)
    return T(t), back
end
But now with this example:
struct A
    x::Float64; y::Float64
end
function f(X,Y)
    S = StructArray{A}((X,Y))
    return sum(S.x) + sum(S.y)
end
@show gradient(f, randn(2), randn(2))
the returned gradient is (nothing, nothing). It should be all ones.
Moreover, for some reason the back function defined above gets called with Δ = (fieldarrays = nothing, y = [1.0, 1.0]), so that the gradient in x is already lost!
Finally I got something working:
Zygote.@adjoint function (::Type{T})(t::Tuple) where {T<:StructArray}
    back(Δ::NamedTuple) = (values(Δ),)
    return T(t), back
end
Zygote.@adjoint function Zygote.literal_getproperty(sa::StructArray, ::Val{key}) where {key}
    key::Symbol # only support this
    result = getproperty(sa, key)
    function back(Δ)
        z = (; (p => zero(getproperty(sa, p)) for p in propertynames(sa))...)
        return (merge(z, (; key => Δ)), nothing)
    end
    return result, back
end
struct A
    x::Float64; y::Float64
end
function f(X,Y)
    S = StructArray{A}((X,Y))
    return sum(S.x) + sum(S.y)
end
@show gradient(f, randn(2), randn(2))
Should this be in Zygote?
Fixed in https://github.com/cossio/ZygoteStructArrays.jl.
But leaving this open, since in principle Zygote should just work and differentiate this without needing a rule definition, right?
Update: https://github.com/cossio/ZygoteStructArrays.jl doesn't work with recent versions of StructArrays, due to breaking changes.
I think it would be nice to come up with a solution based on ChainRules.
Now I am getting this error:
julia> using Zygote, StructArrays
julia> struct A
           x::Float64
       end
julia> function f(X)
           S = StructArray{A}((X,))
           sum(S.x)
       end
f (generic function with 1 method)
julia> gradient(f, randn(2))
ERROR: ArgumentError: type does not have a definite number of fields
Stacktrace:
[1] fieldcount(t::Any)
@ Base ./reflection.jl:764
[2] fieldnames(t::DataType)
@ Base ./reflection.jl:185
[3] #s72#217
@ ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:220 [inlined]
[4] var"#s72#217"(::Any, x::Any)
@ Zygote ./none:0
[5] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core ./boot.jl:580
[6] grad_mut(x::Type)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:262
[7] grad_mut(cx::Zygote.Context, x::Type)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:269
[8] (::Zygote.var"#back#222"{:parameters, Zygote.Context, DataType, Core.SimpleVector})(Δ::Tuple{Nothing})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:233
[9] (::Zygote.var"#1765#back#223"{Zygote.var"#back#222"{:parameters, Zygote.Context, DataType, Core.SimpleVector}})(Δ::Tuple{Nothing})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[10] Pullback
@ ./Base.jl:37 [inlined]
[11] (::typeof(∂(getproperty)))(Δ::Tuple{Nothing})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[12] Pullback
@ ./tuple.jl:308 [inlined]
[13] (::typeof(∂(tuple_type_tail)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[14] Pullback
@ ~/.julia/packages/StructArrays/bekT9/src/structarray.jl:32 [inlined]
[15] (::typeof(∂(index_type)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[16] Pullback
@ ~/.julia/packages/StructArrays/bekT9/src/structarray.jl:29 [inlined]
[17] (::typeof(∂(index_type)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[18] Pullback
@ ~/.julia/packages/StructArrays/bekT9/src/structarray.jl:24 [inlined]
[19] (::typeof(∂(StructVector{A, NamedTuple{(:x,), Tuple{Vector{Float64}}}})))(Δ::NamedTuple{(:components,), Tuple{NamedTuple{(:x,), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}}}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[20] Pullback
@ ~/.julia/packages/StructArrays/bekT9/src/structarray.jl:94 [inlined]
[21] (::typeof(∂(StructArray{A})))(Δ::NamedTuple{(:components,), Tuple{NamedTuple{(:x,), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}}}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[22] Pullback
@ ./REPL[3]:2 [inlined]
[23] (::typeof(∂(f)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[24] (::Zygote.var"#57#58"{typeof(∂(f))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:41
[25] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:76
[26] top-level scope
@ REPL[4]:1
Both examples here are now working on Julia 1.8 with:
(@v1.8) pkg> st
Status `~/.julia/environments/v1.8/Project.toml`
[09ab397b] StructArrays v0.6.12
[e88e6eb3] Zygote v0.6.47
Not sure what the fix was. Closing.
Any suggestions on how this case can be tested during CI? I'm not sure whether we want to add StructArrays as a test dependency.
If you can bisect to the last Zygote version that wasn't working, creating a stripped-down type that raises the same error could work. However, I suspect the fix may have come from changes on the StructArrays side. If that's true, then testing for AD compatibility ought to happen there rather than here, if it isn't done already.
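For reference, a minimal sketch of what such a regression test could look like, assuming StructArrays is acceptable as a test-only dependency (which this thread leaves open). It just wraps the two examples from this issue; the names OneField, TwoFields, f1 and f2 are made up for illustration, and the expected all-ones gradients are taken from the discussion above. The examples were reported working with StructArrays v0.6.12 and Zygote v0.6.47; other versions are untested here.

```julia
using Test
using Zygote, StructArrays

# Illustrative types (hypothetical names); structs must be defined at top level,
# not inside a @testset block.
struct OneField
    x::Float64
end

struct TwoFields
    x::Float64
    y::Float64
end

# First example from this issue: a single-field StructArray.
f1(X) = sum(StructArray{OneField}((X,)).x)

# Second example from this issue: two fields, both contributing to the result.
function f2(X, Y)
    S = StructArray{TwoFields}((X, Y))
    return sum(S.x) + sum(S.y)
end

@testset "gradients through StructArray construction" begin
    X, Y = randn(2), randn(2)
    # Each input enters only through a plain sum, so its gradient should be all ones.
    @test gradient(f1, X)[1] ≈ ones(2)
    gx, gy = gradient(f2, X, Y)
    @test gx ≈ ones(2)
    @test gy ≈ ones(2)
end
```

If the fix did come from the StructArrays side, the same file could live in that package's test suite instead, with Zygote as the test-only dependency.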
Update:
Fixed in https://github.com/cossio/ZygoteStructArrays.jl.
But leaving this open, since in principle Zygote should just work and differentiate this without needing a rule definition, right?
Original issue:
Produces the following error:
Stacktrace:
Edit: Updated stacktrace to Zygote v0.4.17.