FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Constructor adjoints ignored when applied via convert? #908

Open doddgray opened 3 years ago

doddgray commented 3 years ago

Why does this fail

julia> function f3(x)::SMatrix{2,2,Float64,4}
           [ x x^2; √x sin(x) ]
       end
julia> Zygote.gradient(x->sum(f3(x)),3.3)  #  ERROR: Need an adjoint for constructor SMatrix{2, 2, Float64, 4}. Gradient is of type FillArrays.Fill{...

while the following alternatives work?

julia> f2(x) = sum(SMatrix{2,2,Float64,4}(x,x^2,√x,sin(x)))
julia> Zygote.gradient(f2,3.3)  # (6.887761171372725,)
julia> f4(x) = SMatrix{2,2,Float64,4}([ x x^2; √x sin(x) ])
julia> Zygote.gradient(x->sum(f4(x)),3.3)   #  (6.887761171372725,)
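
As a sanity check on the value the working variants report, the gradient can be reproduced without any AD package. This is a minimal sketch; the names `total`, `dtotal` and `central_diff` are made up for illustration and do not come from Zygote or StaticArrays.

```julia
# Sum of the four matrix entries from the issue: x, x^2, √x, sin(x).
total(x) = x + x^2 + √x + sin(x)

# Derivative of `total`, taken term by term by hand.
dtotal(x) = 1 + 2x + 1 / (2 * √x) + cos(x)

# Central finite difference as an independent numerical check.
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

reported = 6.887761171372725  # value printed by Zygote in the issue

@assert isapprox(dtotal(3.3), reported; atol = 1e-6)
@assert isapprox(central_diff(total, 3.3), reported; atol = 1e-5)
```

Both checks agree with the number above, so f2 and f4 return the correct gradient and only the f3 path is at fault.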

I see the same behavior whether the SMatrix constructor adjoints are defined with ChainRulesCore.rrule or with Zygote.@adjoint, that is, with either

ChainRulesCore.rrule(T::Type{<:SMatrix}, x::AbstractMatrix) = ( T(x), dv -> (NO_FIELDS, dv) )
ChainRulesCore.rrule(T::Type{<:SMatrix}, xs::Number...) = ( T(xs...), dv -> (NO_FIELDS, dv...) )

or

@Zygote.adjoint (T::Type{<:SMatrix})(xs::Number...) = T(xs...), dv -> (nothing, dv...)
@Zygote.adjoint (T::Type{<:SMatrix})(x::AbstractMatrix) = T(x), dv -> (nothing, dv)

It seems that a function definition with a return-type annotation triggers Zygote's "Need an adjoint for constructor" error even when the constructor adjoint or rrule is defined, but I could be misunderstanding something more basic.

Per a helpful suggestion from @oxinabox on the Julia/#autodiff Slack channel, I also tried writing a rule with all of the type parameters specified:

@Zygote.adjoint (T::Type{SMatrix{2,2,Float64,4}})(x::AbstractMatrix) = T(x), dv -> (nothing, dv)

but this didn't seem to have any effect. Is there some reason convert would be missing custom rules? What's a programmer to do?
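
For context on why convert enters the picture at all: Julia applies `convert` to the returned value of a function that has a return-type annotation, so the annotated f3 goes through `convert` rather than calling the constructor directly. The sketch below shows this with a toy type; `Wrapped`, `convert_calls` and `annotated` are hypothetical names used only for illustration, and no AD package is involved.

```julia
# Toy type with its own convert method, so calls to convert can be counted.
struct Wrapped
    val::Float64
end

const convert_calls = Ref(0)

function Base.convert(::Type{Wrapped}, x::Real)
    convert_calls[] += 1          # record that convert was invoked
    return Wrapped(float(x))
end

# The return-type annotation makes Julia convert the returned value to Wrapped.
annotated(x)::Wrapped = x

result = annotated(2)
@assert result isa Wrapped
@assert result.val == 2.0
@assert convert_calls[] == 1      # the annotation went through convert

# Optional: print the lowered code to see the inserted convert call.
println(code_lowered(annotated, (Int,)))
```

So a rule attached only to the constructor is consulted only if the `convert` method for that type eventually calls the constructor with a matching argument signature.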

mcabbott commented 3 years ago

Shorter way to get the same error:

julia> gradient(x -> sum(convert(SMatrix{2,2,Float64,4}, x)), rand(2,2))
ERROR: Need an adjoint for constructor SMatrix{2, 2, Float64, 4}. Gradient is of type FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}

julia> Zygote.@adjoint convert(T::Type{<:SMatrix}, x::AbstractMatrix) = T(x), dv -> (nothing, dv)

julia> gradient(x -> sum(convert(SMatrix{2,2,Float64,4}, x)), rand(2,2))
(2×2 Fill{Float64}: entries equal to 1.0,)

julia> Zygote.gradient(x->sum(f3(x)),3.3)  
(6.887761171372725,)

I'm not entirely sure whether defining an adjoint for convert is really necessary, though.

doddgray commented 3 years ago

Thanks for taking a look at this and for the example convert adjoint rule. The stack traces for the error that the constructor rules fix and for the one they do not fix reference different lines of StaticArrays/src/convert.jl, so it seems related to static sizing?

Using the same f2(x) and f3(x) defs as above:


julia> f2(x) = sum(SMatrix{2,2,Float64,4}(x,x^2,√x,sin(x)))
julia> gradient(f2,3.3)   # before SMatrix constructor rule definition 
ERROR: Need an adjoint for constructor SMatrix{2, 2, Float64, 4}. Gradient is of type FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{SMatrix{2, 2, Float64, 4}, Nothing, false})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:314
  [3] (::Zygote.var"#1729#back#202"{Zygote.Jnew{SMatrix{2, 2, Float64, 4}, Nothing, false}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ~/.julia/packages/StaticArrays/w5a7P/src/SArray.jl:23 [inlined]
  [5] (::typeof(∂(SMatrix{2, 2, Float64, 4})))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/StaticArrays/w5a7P/src/convert.jl:4 [inlined]
  [7] Pullback
    @ ./REPL[6]:1 [inlined]
  [8] (::typeof(∂(f2)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [9] (::Zygote.var"#43#44"{typeof(∂(f2))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [10] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [11] top-level scope
    @ REPL[10]:1

julia> ChainRulesCore.rrule(T::Type{<:SMatrix}, x::AbstractMatrix) = ( T(x), dv -> (NO_FIELDS, dv) )
julia> ChainRulesCore.rrule(T::Type{<:SMatrix}, xs::Number...) = ( T(xs...), dv -> (NO_FIELDS, dv...) )
julia> Zygote.refresh()
julia> gradient(f2,3.3)
(6.887761171372725,)
julia> function f3(x)::SMatrix{2,2,Float64,4}
           [ x x^2; √x sin(x) ]
       end
julia> Zygote.gradient(x->sum(f3(x)),3.3)
ERROR: Need an adjoint for constructor SMatrix{2, 2, Float64, 4}. Gradient is of type FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{SMatrix{2, 2, Float64, 4}, Nothing, false})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:314
  [3] (::Zygote.var"#1729#back#202"{Zygote.Jnew{SMatrix{2, 2, Float64, 4}, Nothing, false}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ~/.julia/packages/StaticArrays/w5a7P/src/SArray.jl:23 [inlined]
  [5] (::typeof(∂(SMatrix{2, 2, Float64, 4})))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ./compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/StaticArrays/w5a7P/src/convert.jl:36 [inlined]
  [7] (::typeof(∂(_convert)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ./compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/StaticArrays/w5a7P/src/convert.jl:33 [inlined]
  [9] Pullback
    @ ./REPL[13]:1 [inlined]
 [10] (::typeof(∂(#7)))(Δ::Float64)
    @ Zygote ./compiler/interface2.jl:0
 [11] (::Zygote.var"#43#44"{typeof(∂(#7))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [12] gradient(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [13] top-level scope
    @ REPL[13]:1

julia> gradient(x -> sum(convert(SMatrix{2,2,Float64,4}, x)), rand(2,2))
# ... same error as with f3(x) above

The former case (handled by rrule) references

https://github.com/JuliaArrays/StaticArrays.jl/blob/8b90b9c2a452557b00a61a94dfa638bc300f8c9c/src/convert.jl#L4

whereas the latter case references

https://github.com/JuliaArrays/StaticArrays.jl/blob/8b90b9c2a452557b00a61a94dfa638bc300f8c9c/src/convert.jl#L36

Maybe this would work if the result of unroll_tuple(a,l) were splatted into the SA constructor? I'll have to check.
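
One possible reading of the two traces, sketched with a toy type rather than the real StaticArrays code: rules written for `(x::AbstractMatrix)` and `(xs::Number...)` cannot match a call that passes a single unsplatted tuple. This assumes, as the question above suggests, that the conversion path hands the constructor a tuple. `ToyMatrix` and `has_rule` are hypothetical names, and `has_rule` only mimics method matching, not Zygote's actual rule lookup.

```julia
# Toy stand-in for a static 2x2 matrix; NOT StaticArrays' SMatrix.
struct ToyMatrix
    data::NTuple{4,Float64}
end

# Hypothetical helper mirroring the two rule signatures from the issue,
# plus a catch-all that reports "no rule matched".
has_rule(::Type{ToyMatrix}, x::AbstractMatrix) = true
has_rule(::Type{ToyMatrix}, xs::Number...) = true
has_rule(::Type{ToyMatrix}, args...) = false

vals = (1.0, 2.0, 3.0, 4.0)

@assert has_rule(ToyMatrix, vals...)                       # splatted numbers match
@assert has_rule(ToyMatrix, reshape(collect(vals), 2, 2))  # a Matrix matches
@assert !has_rule(ToyMatrix, vals)                         # a bare tuple does not
```

If that reading is right, either splatting the tuple or adding a rule for the tuple-argument constructor call would let the existing approach apply; neither has been verified against StaticArrays here.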