FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

getproperty MethodError with StructArrays #602

Closed cossio closed 2 years ago

cossio commented 4 years ago

Update:

Fixed in https://github.com/cossio/ZygoteStructArrays.jl.

But leaving this open, since in principle Zygote should just work and differentiate this without needing a rule definition, right?

Original issue:

using Zygote, StructArrays
struct A
    x::Float64
end
function f(X)
    S = StructArray{A}((X,))
    sum(S.x)
end
gradient(f, randn(2))

Produces the following error:

ERROR: MethodError: no method matching getproperty(::NamedTuple{(:x,),Tuple{Array{Float64,1}}}, ::Int64)

Stacktrace:

 [1] adjoint at /home/cossio/.julia/packages/Zygote/1aQlT/src/lib/lib.jl:204 [inlined]
 [2] _pullback at /home/cossio/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [3] _pullback(::Zygote.Context, ::typeof(Zygote.literal_getindex), ::NamedTuple{(:x,),Tuple{Array{Float64,1}}}, ::Val{1}) at /home/cossio/.julia/packages/Zygote/1aQlT/src/lib/lib.jl:224
 [4] StructArray at /home/cossio/.julia/packages/StructArrays/2PoXh/src/structarray.jl:32 [inlined]
 [5] _pullback(::Zygote.Context, ::Type{StructArray{A,N,C,I} where I where C<:Union{Tuple, NamedTuple} where N}, ::Tuple{Array{Float64,1}}) at /home/cossio/.julia/packages/Zygote/1aQlT/src/compiler/interface2.jl:0
 [6] f at ./REPL[3]:2 [inlined]
 [7] _pullback(::Zygote.Context, ::typeof(f), ::Array{Float64,1}) at /home/cossio/.julia/packages/Zygote/1aQlT/src/compiler/interface2.jl:0
 [8] _pullback(::Function, ::Array{Float64,1}) at /home/cossio/.julia/packages/Zygote/1aQlT/src/compiler/interface.jl:29
 [9] pullback(::Function, ::Array{Float64,1}) at /home/cossio/.julia/packages/Zygote/1aQlT/src/compiler/interface.jl:35
 [10] gradient(::Function, ::Array{Float64,1}) at /home/cossio/.julia/packages/Zygote/1aQlT/src/compiler/interface.jl:44
 [11] top-level scope at REPL[4]:1

Edit: Updated stacktrace to Zygote v0.4.17.

AzamatB commented 4 years ago

Looks like your example is not reproducible due to typos. Can you edit it?

cossio commented 4 years ago

@AzamatB Sorry. I corrected the example.

CarloLucibello commented 4 years ago

One can make a little progress by defining the missing getproperty method, but then a new issue arises:

julia> Base.getproperty(x::NamedTuple, i::Int) = getfield(x, i)

julia> using Zygote, StructArrays

julia> struct A
           x::Float64
       end

julia> function f(X)
           S = StructArray{A}((X,))
           sum(S.x)
       end
f (generic function with 1 method)

julia> gradient(f, randn(2))
ERROR: ArgumentError: type does not have a definite number of fields
Stacktrace:
 [1] fieldcount(::Any) at ./reflection.jl:705
 [2] fieldnames(::DataType) at ./reflection.jl:172
 [3] #s54#179(::Any, ::Any) at /home/carlo/.julia/packages/Zygote/4tJp5/src/lib/lib.jl:187
 [4] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:526
 [5] grad_mut(::Type{T} where T) at /home/carlo/.julia/packages/Zygote/4tJp5/src/lib/lib.jl:218
 [6] grad_mut(::Zygote.Context, ::Type{T} where T) at /home/carlo/.julia/packages/Zygote/4tJp5/src/lib/lib.jl:225
 [7] (::Zygote.var"#back#183"{:parameters,Zygote.Context,DataType,Core.SimpleVector})(::Tuple{Nothing}) at /home/carlo/.julia/packages/Zygote/4tJp5/src/lib/lib.jl:198
 [8] (::Zygote.var"#355#back#184"{Zygote.var"#back#183"{:parameters,Zygote.Context,DataType,Core.SimpleVector}})(::Tuple{Nothing}) at /home/carlo/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [9] tuple_type_tail at ./essentials.jl:223 [inlined]
 [10] (::typeof(∂(tuple_type_tail)))(::Nothing) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [11] index_type at /home/carlo/.julia/packages/StructArrays/2PoXh/src/structarray.jl:24 [inlined]
 [12] (::typeof(∂(index_type)))(::Nothing) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [13] index_type at /home/carlo/.julia/packages/StructArrays/2PoXh/src/structarray.jl:21 [inlined]
 [14] (::typeof(∂(index_type)))(::Nothing) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [15] StructArray at /home/carlo/.julia/packages/StructArrays/2PoXh/src/structarray.jl:17 [inlined]
 [16] (::typeof(∂(StructArray{A,1,NamedTuple{(:x,),Tuple{Array{Float64,1}}},I} where I)))(::NamedTuple{(:fieldarrays, :x),Tuple{Nothing,FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}}}) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [17] StructArray at /home/carlo/.julia/packages/StructArrays/2PoXh/src/structarray.jl:33 [inlined]
 [18] (::typeof(∂(StructArray{A,N,C,I} where I where C<:Union{Tuple, NamedTuple} where N)))(::NamedTuple{(:fieldarrays, :x),Tuple{Nothing,FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}}}) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [19] f at ./REPL[12]:2 [inlined]
 [20] (::typeof(∂(f)))(::Float64) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#36#37"{typeof(∂(f))})(::Float64) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:36
 [22] gradient(::Function, ::Array{Float64,1}) at /home/carlo/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:45
 [23] top-level scope at REPL[13]:1
 [24] eval(::Module, ::Any) at ./boot.jl:331
 [25] eval_user_input(::Any, ::REPL.REPLBackend) at /home/carlo/julia/julia-1.4.0/share/julia/stdlib/v1.4/REPL/src/REPL.jl:86
 [26] run_backend(::REPL.REPLBackend) at /home/carlo/.julia/packages/Revise/C272c/src/Revise.jl:1075
 [27] top-level scope at none:0

CarloLucibello commented 4 years ago

Probably we need an adjoint for the StructArray constructor

cossio commented 4 years ago

I think so. The issue seems to be that, since StructArray <: AbstractArray, Zygote applies its differentiation rules for arrays. Instead, we want to treat a StructArray as a struct with array fields.
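To illustrate the "struct with array fields" view, here is a minimal sketch. The names `Cols` and `g` are hypothetical and not part of StructArrays; `Cols` is deliberately not a subtype of AbstractArray, so Zygote should differentiate it through its generic struct handling rather than through array rules.

```julia
using Zygote

# Hypothetical stand-in for a StructArray with two columns.
# It is NOT <: AbstractArray, so no array-specific rules apply to it.
struct Cols{X,Y}
    x::X
    y::Y
end

# Same shape as the example in this issue: build the container, read its fields.
function g(X, Y)
    c = Cols(X, Y)
    return sum(c.x) + sum(c.y)
end

gx, gy = gradient(g, randn(2), randn(2))
```

If the struct view is what we want, `gx` and `gy` here should each be approximately a vector of ones, which is the result the StructArray version is expected to give as well.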

cossio commented 4 years ago

CarloLucibello wrote: "Probably we need an adjoint for the StructArray constructor"

I gave this a try:

@adjoint function (::Type{T})(t::Tuple) where {T<:StructArray}
    back(Δ) = (nothing,Base.tail(values(Δ))...)
    return T(t), back
end

But now with this example:

struct A
    x::Float64; y::Float64
end
function f(X,Y)
    S = StructArray{A}((X,Y))
    return sum(S.x) + sum(S.y)
end
@show gradient(f, randn(2), randn(2))

the returned gradient is (nothing, nothing), whereas it should be a vector of ones for each argument.

Moreover, for some reason the back function defined above gets called with Δ = (fieldarrays = nothing, y = [1.0, 1.0]), so that the gradient in x is already lost!
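To make the lost entry concrete, the following plain-Julia sketch (no Zygote involved) takes the Δ value reported above and applies the same `Base.tail(values(Δ))` expression used in `back`. The variable names `vals` and `rest` are only for illustration.

```julia
# The cotangent that `back` was reported to receive.
Δ = (fieldarrays = nothing, y = [1.0, 1.0])

# There is no entry for the field x at all.
has_x = haskey(Δ, :x)

# values(Δ) is a plain tuple of the entries, in order.
vals = values(Δ)

# Base.tail drops the first entry (fieldarrays) and keeps the rest.
rest = Base.tail(vals)
```

Only the `y` contribution survives in `rest`, so whatever `back` returns cannot contain a gradient for `X`: the information was already missing from Δ before `back` ran.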

cossio commented 4 years ago

Finally I got something working:

Zygote.@adjoint function (::Type{T})(t::Tuple) where {T<:StructArray}
    back(Δ::NamedTuple) = (values(Δ),)
    return T(t), back
end

Zygote.@adjoint function Zygote.literal_getproperty(sa::StructArray, ::Val{key}) where {key}
    key::Symbol # only support this
    result = getproperty(sa, key)
    function back(Δ)
        z = (; (p => zero(getproperty(sa, p)) for p in propertynames(sa))...)
        return (merge(z, (; key => Δ)), nothing)
    end
    return result, back
end

struct A
    x::Float64; y::Float64
end
function f(X,Y)
    S = StructArray{A}((X,Y))
    return sum(S.x) + sum(S.y)
end
@show gradient(f, randn(2), randn(2))

Should this be in Zygote?
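For readers following the `literal_getproperty` rule, this is a plain-Julia sketch of what its `back` closure builds. A NamedTuple called `cols` stands in for the StructArray (an assumption made for illustration; the real rule calls `propertynames` and `getproperty` on the StructArray itself).

```julia
# Stand-in for the columns of a StructArray with fields x and y.
cols = (x = [1.0, 2.0], y = [3.0, 4.0])

# Suppose S.y was read and the incoming cotangent for it is Δ.
key = :y
Δ = [1.0, 1.0]

# Start from a zero cotangent for every column ...
z = (; (p => zero(getproperty(cols, p)) for p in propertynames(cols))...)

# ... then overwrite the entry of the column that was actually read.
full = merge(z, (; key => Δ))
```

Because every field gets an entry, the cotangents from `S.x` and `S.y` have the same keys and can be accumulated, and the constructor rule can then return `values(Δ)` as the gradient with respect to the input tuple.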

cossio commented 4 years ago

Fixed in https://github.com/cossio/ZygoteStructArrays.jl.

But leaving this open, since in principle Zygote should just work and differentiate this without needing a rule definition, right?

cossio commented 2 years ago

Update: https://github.com/cossio/ZygoteStructArrays.jl doesn't work with recent versions of StructArrays, due to breaking changes.

I think it would be nice to come up with a solution based on ChainRules.

cossio commented 2 years ago

Now I am getting this error:

julia> using Zygote, StructArrays

julia> struct A
           x::Float64
       end

julia> function f(X)
           S = StructArray{A}((X,))
           sum(S.x)
       end
f (generic function with 1 method)

julia> gradient(f, randn(2))
ERROR: ArgumentError: type does not have a definite number of fields
Stacktrace:
  [1] fieldcount(t::Any)
    @ Base ./reflection.jl:764
  [2] fieldnames(t::DataType)
    @ Base ./reflection.jl:185
  [3] #s72#217
    @ ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:220 [inlined]
  [4] var"#s72#217"(::Any, x::Any)
    @ Zygote ./none:0
  [5] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
  [6] grad_mut(x::Type)
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:262
  [7] grad_mut(cx::Zygote.Context, x::Type)
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:269
  [8] (::Zygote.var"#back#222"{:parameters, Zygote.Context, DataType, Core.SimpleVector})(Δ::Tuple{Nothing})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:233
  [9] (::Zygote.var"#1765#back#223"{Zygote.var"#back#222"{:parameters, Zygote.Context, DataType, Core.SimpleVector}})(Δ::Tuple{Nothing})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [10] Pullback
    @ ./Base.jl:37 [inlined]
 [11] (::typeof(∂(getproperty)))(Δ::Tuple{Nothing})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [12] Pullback
    @ ./tuple.jl:308 [inlined]
 [13] (::typeof(∂(tuple_type_tail)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/StructArrays/bekT9/src/structarray.jl:32 [inlined]
 [15] (::typeof(∂(index_type)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/StructArrays/bekT9/src/structarray.jl:29 [inlined]
 [17] (::typeof(∂(index_type)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/StructArrays/bekT9/src/structarray.jl:24 [inlined]
 [19] (::typeof(∂(StructVector{A, NamedTuple{(:x,), Tuple{Vector{Float64}}}})))(Δ::NamedTuple{(:components,), Tuple{NamedTuple{(:x,), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}}}})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [20] Pullback
    @ ~/.julia/packages/StructArrays/bekT9/src/structarray.jl:94 [inlined]
 [21] (::typeof(∂(StructArray{A})))(Δ::NamedTuple{(:components,), Tuple{NamedTuple{(:x,), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}}}})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [22] Pullback
    @ ./REPL[3]:2 [inlined]
 [23] (::typeof(∂(f)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [24] (::Zygote.var"#57#58"{typeof(∂(f))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:41
 [25] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:76
 [26] top-level scope
    @ REPL[4]:1

cossio commented 2 years ago

Both examples here are now working on Julia 1.8 with:

(@v1.8) pkg> st
Status `~/.julia/environments/v1.8/Project.toml`
  [09ab397b] StructArrays v0.6.12
  [e88e6eb3] Zygote v0.6.47

Not sure what the fix was. Closing.

cossio commented 2 years ago

Any suggestions on how this case can be tested during CI? I'm not sure whether we want to add StructArrays as a test dependency.

ToucheSir commented 2 years ago

If you can bisect the last Zygote version that wasn't working, creating a stripped-down type that raises the same error could work. However, I suspect the fix may have come from changes on the StructArrays side. If that's true, then testing for AD compat ought to happen there rather than here, if it isn't done already.
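Wherever the test ends up living, a regression test for the two examples in this thread could look roughly like the sketch below. It assumes StructArrays is available in the test environment, and it renames the two-field struct to `B` (a hypothetical name) so that both examples can coexist in one file.

```julia
using Test, Zygote, StructArrays

# One-field example from the original report.
struct A
    x::Float64
end

function f(X)
    S = StructArray{A}((X,))
    return sum(S.x)
end

# Two-field example from later in the thread (struct renamed to B here).
struct B
    x::Float64
    y::Float64
end

function f2(X, Y)
    S = StructArray{B}((X, Y))
    return sum(S.x) + sum(S.y)
end

@testset "StructArray construction and getproperty (#602)" begin
    gx, = gradient(f, randn(2))
    @test gx ≈ ones(2)

    g2x, g2y = gradient(f2, randn(2), randn(2))
    @test g2x ≈ ones(2)
    @test g2y ≈ ones(2)
end
```

The comparisons use `≈` rather than `==` so that the test does not depend on whether the gradient comes back as a `Vector` or as a lazy array such as `FillArrays.Fill`.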