FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Problems with variable indirect use #946

Open RainerHeintzmann opened 3 years ago

RainerHeintzmann commented 3 years ago

We are trying to write an rrule for a custom array type so that Zygote can differentiate through it, but we are stuck on an error about a missing adjoint for a constructor. This may well be a user error, but it could also be a problem in Zygote. Any help is appreciated!

using ChainRulesCore

struct Example{T,N,F} <: AbstractArray{T,N} where F
    sz::NTuple{N, Int}
    f::F
end

function Base.getindex(a::Example{T,N,F}, idx::Vararg{B,N}) where {T,N,F,B}
    a.f(idx)
end
Base.size(e::Example) = e.sz

function ChainRulesCore.rrule(::typeof(Example{Float64,2,F}),
    sz::NTuple{N, Int},
    gen::F,
    ) where {F,N}
    function IFA_pullback(ΔΩ)
        @show outer = ΔΩ  
        @show inner = Example{Float64,2,typeof(gen)}(sz, gen)
        ∂gen = outer .* inner # wrap in @thunk()
        @show  ∂gen
        return (NO_FIELDS, NO_FIELDS, ∂gen) # why do we need four here?
    end
    Ω = Example{Float64,2,typeof(gen)}(sz, gen)
    return (Ω, IFA_pullback)
end

This rule generally seems to work, but what we need is the ability to differentiate with respect to a variable used in the innermost function. The following code uses these definitions and triggers the error:

c(a) = begin
    g(idx)= idx[1]*idx[2]*a
    sum(Example{Float64,2,typeof(g)}((3,3),g))
end

using Zygote
gradient(c, 2)

The error looks like this:

julia> include("Scratch_05_GradientTest_Chain_.jl")
outer = ΔΩ = 3×3 Fill{Int64}: entries equal to 1
inner = Example{Float64, 2, typeof(gen)}(sz, gen) = [2 4 6; 4 8 12; 6 12 18]
∂gen = [2.0 4.0 6.0; 4.0 8.0 12.0; 6.0 12.0 18.0]
ERROR: Need an adjoint for constructor var"#g#6"{Int64}. Gradient is of type Matrix{Float64}
Stacktrace:
 [1] error(s::String)  
   @ Base .\error.jl:33
 [2] (::Zygote.Jnew{var"#g#6"{Int64}, Nothing, false})(Δ::Matrix{Float64})
   @ Zygote ~\.julia\packages\Zygote\RxTZu\src\lib\lib.jl:314
 [3] (::Zygote.var"#1720#back#194"{Zygote.Jnew{var"#g#6"{Int64}, Nothing, false}})(Δ::Matrix{Float64})
   @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [4] Pullback
   @ ~\Documents\Programming\Julia\Development\Scratch_05_GradientTest_Chain_.jl:29 [inlined]
 [5] (::Zygote.Pullback{Tuple{typeof(c), Int64}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.var"#1566#back#125"{typeof(identity)}, Zygote.var"#1720#back#194"{Zygote.Jnew{var"#g#6"{Int64}, Nothing, false}}, Zygote.ZBack{var"#IFA_pullback#5"{Tuple{Int64, Int64}, var"#g#6"{Int64}}}, Zygote.var"#2646#back#601"{Zygote.var"#597#599"{Example{Float64, 2, var"#g#6"{Int64}}}}}})(Δ::Int64)
   @ Zygote ~\.julia\packages\Zygote\RxTZu\src\compiler\interface2.jl:0
 [6] (::Zygote.var"#41#42"{Zygote.Pullback{Tuple{typeof(c), Int64}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.var"#1566#back#125"{typeof(identity)}, Zygote.var"#1720#back#194"{Zygote.Jnew{var"#g#6"{Int64}, Nothing, false}}, Zygote.ZBack{var"#IFA_pullback#5"{Tuple{Int64, Int64}, var"#g#6"{Int64}}}, Zygote.var"#2646#back#601"{Zygote.var"#597#599"{Example{Float64, 2, var"#g#6"{Int64}}}}}}})(Δ::Int64)
   @ Zygote ~\.julia\packages\Zygote\RxTZu\src\compiler\interface.jl:41
 [7] gradient(f::Function, args::Int64)
   @ Zygote ~\.julia\packages\Zygote\RxTZu\src\compiler\interface.jl:59
 [8] top-level scope
   @ ~\Documents\Programming\Julia\Development\Scratch_05_GradientTest_Chain_.jl:34

Something similar happens if you place the variable a directly after sum( in the function c.

DhairyaLGandhi commented 3 years ago

For Zygote, it would be better to use Composite. I would try to use @adjoint here.

RainerHeintzmann commented 3 years ago

Thanks for the great hint. Can you point me to an example or to documentation on Composite and canonicalize()? The documentation of these functions was not enough for me to figure out how to define them correctly. Presumably one needs to define a struct Composite and then call canonicalize with an example of such a struct?

mzgubic commented 3 years ago

You can have a look at the ChainRules documentation, and for examples see the ChainRules package.

MikeInnes commented 3 years ago

We've discussed this over email, but I thought the short version might as well be recorded here. The reason for the error is that a matrix, ∂gen, is used as the gradient of a closure g. When it comes to unpacking ∂gen to get the gradient of a, Zygote doesn't know what to do. In this example you'd want to use Zygote._pullback inside the rule to get the gradient of g.

devmotion commented 3 years ago

Another problem here seems to be that you define the rrule for ::typeof(Example{...}) instead of ::Type{Example{...}} or ::Type{<:Example{...}}: https://juliadiff.org/ChainRulesCore.jl/previews/PR331/writing_good_rules.html#Use-Type{T},-not-typeof(T),-to-define-rules-for-constructors (it's not mentioned in the official documentation yet but only part of a PR).
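
A minimal sketch of why this matters, using plain Julia dispatch and no ChainRules at all. The type Point and the functions via_typeof and via_type are hypothetical names introduced only for illustration. Since typeof of a concrete type is DataType, a method written with ::typeof(...) ends up matching every DataType rather than only the intended constructor:

```julia
# Hypothetical type, used only to illustrate dispatch on constructors.
struct Point{T}
    x::T
end

# typeof(Point{Float64}) is DataType, so this method accepts ANY DataType.
via_typeof(::typeof(Point{Float64}), x) = "matched"

# Type{Point{Float64}} matches only the type Point{Float64} itself.
via_type(::Type{Point{Float64}}, x) = "matched"

println(typeof(Point{Float64}) === DataType)      # the argument type is just DataType
println(via_typeof(Int, 1))                       # also matches Int
println(applicable(via_type, Int, 1))             # no method for Int
println(applicable(via_type, Point{Float64}, 1))  # method exists for Point{Float64}
```

A rule defined with ::Type{...} is therefore the one that is selected only for calls to that specific constructor.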

RainerHeintzmann commented 3 years ago

Thanks for picking this up! Yet the error message remains the same...

RainerHeintzmann commented 3 years ago

This is my attempt at using the suggestion of @MikeInnes:

function ChainRulesCore.rrule(::Type{Example{Float64,2,F}}, 
    sz::NTuple{N, Int},
    gen::F,
    ) where {F,N}
    val_grad(x) = Zygote._pullback(gen, x)[2](1.0) 
    gradgen(x) = val_grad(x)[1][:a] 
    function IFA_pullback(ΔΩ)    
        inner = Example{Float64,2,typeof(gradgen)}(sz, gradgen) 
        ∂gen = ΔΩ .*  inner 
        @show ∂gen
        return (NO_FIELDS,NO_FIELDS,∂gen) 
    end
    Ω = Example{Float64,2,typeof(gen)}(sz, gen)
    return (Ω, IFA_pullback)
end

As you can see from the @show ∂gen output, this code gets pretty far, yet the result still leaves Zygote stuck:

julia> gradient(c, 2.0)
∂gen = [4.0 8.0 12.0; 4.0 8.0 12.0; 4.0 8.0 12.0]
ERROR: Need an adjoint for constructor var"#g#14"{Float64}. Gradient is of type Matrix{Float64}
Stacktrace:
 [1] error(s::String)
   @ Base .\error.jl:33
 [2] (::Zygote.Jnew{var"#g#14"{Float64}, Nothing, false})(Δ::Matrix{Float64})
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:314
 [3] (::Zygote.var"#1723#back#196"{Zygote.Jnew{var"#g#14"{Float64}, Nothing, false}})(Δ::Matrix{Float64})
   @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [4] Pullback
   @ ~\Documents\Programming\Julia\Development\TestingZygote.jl:63 [inlined]
 [5] (::typeof(∂(c)))(Δ::Float64)
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [6] (::Zygote.var"#41#42"{typeof(∂(c))})(Δ::Float64)
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:41
 [7] gradient(f::Function, args::Float64)
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:59
 [8] top-level scope

MikeInnes commented 3 years ago

Again though, you're giving a matrix ∂gen as the gradient for the closure g, so it's the same issue as before. The gradient (for this specific g) should be a named tuple of the form (a = da::Real,); that's something the pullback rule you've defined has to get right.

I suspect the right gradient here would be (a = sum(∂gen),), but that would only work for this specific closure, since others might have more than one capture or call it something other than a. So :a shouldn't appear in the code.

Instead, you want to do something like broadcast the pullback of g over ΔΩ. That gets you a matrix of named tuples, which you can sum with Zygote.accum_sum (which is like sum but supports named tuples).
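
One way to read this suggestion, as a rough sketch rather than a finished rule: take the pullback of the closure at every index, seed it with the matching entry of ΔΩ, keep the part that belongs to the closure itself, and accumulate the resulting named tuples. The helper make_g and the all-ones ΔΩ are assumptions made for illustration. Zygote._pullback is an internal function, so its behaviour may differ between Zygote versions, and reduce(Zygote.accum, ...) is used here in place of Zygote.accum_sum.

```julia
using Zygote

# Hypothetical helper: returns a closure that captures `a`, like g in the issue.
make_g(a) = idx -> idx[1] * idx[2] * a

g  = make_g(2.0)
ΔΩ = ones(3, 3)   # assumed incoming cotangent, one entry per array element

# For each index, the pullback returns (∂g, ∂idx); [1] keeps ∂g,
# the gradient with respect to the closure's captured fields.
∂gs = [Zygote._pullback(g, Tuple(I))[2](ΔΩ[I])[1] for I in CartesianIndices(ΔΩ)]

# Accumulate the per-element named tuples field by field.
∂g = reduce(Zygote.accum, ∂gs)
println(∂g)
```

Nothing in this sketch refers to the captured variable by name, so the same pattern should carry over to closures with other or additional captures.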

RainerHeintzmann commented 3 years ago

Thanks, @MikeInnes, for this hint. It took me ages to understand that it is not the returned tuple that needs to be a named tuple, but its third element. As far as I can see, there is no need to involve sum or Zygote.accum_sum, but it's useful to know that they exist. Here is an implementation that should hopefully also work for slightly more general cases:

using ChainRulesCore
using Zygote

struct Example{T,N,F} <: AbstractArray{T,N} where F
    sz::NTuple{N, Int}
    f::F
end

function Base.getindex(a::Example{T,N,F}, idx::Vararg{B,N}) where {T,N,F,B}
    a.f(idx)
end
Base.size(e::Example) = e.sz

function ChainRulesCore.rrule(::Type{Example{Float64,2,F}}, 
    sz::NTuple{N, Int},
    gen::F,
    ) where {F,N}
    val_grad(idx) = Zygote._pullback(gen, idx)[2](1.0) # 1.0 is only the seed
    mySymbols = keys(val_grad(sz)[1])
    gradgen(idx) = val_grad(idx)[1] 
    function IFA_pullback(ΔΩ)   
        Fcts = ((idx)-> val_grad(idx)[1][aSymbol] for aSymbol in mySymbols)
        TupleVals = (ΔΩ .* Example{Float64,2,typeof(Fun)}(sz, Fun) for Fun in Fcts)
        ∂gen = NamedTuple{mySymbols}(TupleVals)
        return (NO_FIELDS, NO_FIELDS, ∂gen) 
    end
    Ω = Example{Float64,2,typeof(gen)}(sz, gen)
    return (Ω, IFA_pullback)
end

c(a) = begin
    g(idx)= idx[1] + idx[2] *a*a
    myarr = Example{Float64,2,typeof(g)}((3,3),g)  # 3,3 refers to size
    sum(myarr)  
end

The output looks like this:

julia> gradient(c, 2.0)
([4.0 8.0 12.0; 4.0 8.0 12.0; 4.0 8.0 12.0],)

Phew. That took longer than planned ;-)

RainerHeintzmann commented 3 years ago

Does anyone know if Zygote._pullback(gen, idx)[2](1.0) can be avoided or replaced with a function in ChainRulesCore? It would be nice to avoid the dependence of our package on Zygote.

devmotion commented 3 years ago

Only if you know the differential explicitly or if there exists an rrule that you can call. Currently, ChainRules does not support calling back into the AD system (such as Zygote): https://github.com/JuliaDiff/ChainRulesCore.jl/issues/68
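
To illustrate the first case, here is a sketch in plain Julia of what "knowing the differential explicitly" could look like for the generator used above. The two-argument g, the hand-derived dg_da, and the all-ones ΔΩ are assumptions for illustration only. This is not a ChainRulesCore rule, just the arithmetic such a rule would perform without calling into an AD system:

```julia
# Hypothetical generator with `a` passed explicitly instead of captured.
g(idx, a) = idx[1] + idx[2] * a * a

# Hand-derived partial derivative of g with respect to a.
dg_da(idx, a) = 2 * idx[2] * a

a  = 2.0
ΔΩ = ones(3, 3)   # assumed incoming cotangent

# Contract the cotangent with the per-element derivative to get a scalar ∂a.
∂a = sum(ΔΩ[I] * dg_da(Tuple(I), a) for I in CartesianIndices(ΔΩ))
println(∂a)   # prints 72.0
```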

mzgubic commented 3 years ago

The plan is to allow this by JuliaCon, so watch the issue @devmotion posted if you would find this useful.