Open RainerHeintzmann opened 3 years ago
For Zygote, it would be better to use Composite. I would try not to use @adjoint here.
Thanks for the great hint. Can you point me to an example or documentation on Composite and canonicalize()? The documentation of these functions did not allow me to figure out how to define them correctly. Presumably one needs to define a struct Composite and then call canonicalize with an example of such a struct?
You can have a look at the ChainRules documentation, and for examples see the ChainRules package.
We've discussed this over email but I thought the short version might as well be recorded here. The reason for the error is that a matrix, ∂gen, is used as the gradient of a closure g. When it comes to unpacking ∂gen to get the gradient of a, Zygote doesn't know what to do. In this example you'd want to use Zygote._pullback inside the rule to get the gradient of g.
Another problem here seems to be that you define the rrule for ::typeof(Example{...}) instead of ::Type{Example{...}} or ::Type{<:Example{...}}: https://juliadiff.org/ChainRulesCore.jl/previews/PR331/writing_good_rules.html#Use-Type{T},-not-typeof(T),-to-define-rules-for-constructors (it's not mentioned in the official documentation yet but only part of a PR).
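The difference between the two signatures can be seen in plain Julia, without ChainRulesCore. This is only a minimal sketch: the struct Point and the functions via_typeof and via_type are hypothetical names chosen for illustration and are not part of the code in this issue.

```julia
# Hypothetical parametric struct, standing in for Example.
struct Point{T}
    x::T
end

# typeof(Point{Float64}) is DataType, so this method is really defined
# on ::DataType and matches every concrete type, not just Point{Float64}.
via_typeof(::typeof(Point{Float64})) = "matched via typeof"

# Type{T} is the singleton type whose only instance is T itself,
# so these methods match only the intended type objects.
via_type(::Type{Point{Float64}}) = "matched via Type"
via_type(::Type{<:Point}) = "matched via Type{<:Point}"

println(typeof(Point{Float64}))      # DataType
println(via_typeof(Int))             # unintended match
println(via_type(Point{Float64}))    # matched via Type
println(via_type(Point{Int}))        # matched via Type{<:Point}
println(applicable(via_type, Int))   # false
```

A rule attached to ::typeof(T) would therefore be attached to DataType as a whole rather than to the constructor of one specific type.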
Thanks for picking this up! Yet the error message remains the same...
This is my attempt to use the suggestion of @MikeInnes:
    function ChainRulesCore.rrule(::Type{Example{Float64,2,F}},
                                  sz::NTuple{N, Int},
                                  gen::F,
                                  ) where {F,N}
        val_grad(x) = Zygote._pullback(gen, x)[2](1.0)
        gradgen(x) = val_grad(x)[1][:a]
        function IFA_pullback(ΔΩ)
            inner = Example{Float64,2,typeof(gradgen)}(sz, gradgen)
            ∂gen = ΔΩ .* inner
            @show ∂gen
            return (NO_FIELDS,NO_FIELDS,∂gen)
        end
        Ω = Example{Float64,2,typeof(gen)}(sz, gen)
        return (Ω, IFA_pullback)
    end
As you can see from the @show ∂gen, this code seems to get pretty far, yet somehow the output still leaves Zygote stuck:
    julia> gradient(c, 2.0)
    ∂gen = [4.0 8.0 12.0; 4.0 8.0 12.0; 4.0 8.0 12.0]
    ERROR: Need an adjoint for constructor var"#g#14"{Float64}. Gradient is of type Matrix{Float64}
    Stacktrace:
      [1] error(s::String)
        @ Base .\error.jl:33
      [2] (::Zygote.Jnew{var"#g#14"{Float64}, Nothing, false})(Δ::Matrix{Float64})
        @ Zygote ~\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:314
      [3] (::Zygote.var"#1723#back#196"{Zygote.Jnew{var"#g#14"{Float64}, Nothing, false}})(Δ::Matrix{Float64})
        @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
      [4] Pullback
        @ ~\Documents\Programming\Julia\Development\TestingZygote.jl:63 [inlined]
      [5] (::typeof(∂(c)))(Δ::Float64)
        @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
      [6] (::Zygote.var"#41#42"{typeof(∂(c))})(Δ::Float64)
        @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:41
      [7] gradient(f::Function, args::Float64)
        @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:59
      [8] top-level scope
Again though, you're giving a matrix ∂gen as the gradient for the closure g, so it's the same issue as before. The gradient (for this specific g) should be a named tuple of the form (a = da::Real,); that's something the pullback rule you've defined has to get right.
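Why the gradient of g has the shape (a = ...,) can be seen from how Julia represents closures: a closure is a callable struct whose fields are the captured variables. The following is a minimal sketch; make_g is a hypothetical helper that mirrors the g defined inside c, and the value 8.0 is just a placeholder gradient.

```julia
# Hypothetical helper that builds the same kind of closure as g inside c(a).
function make_g(a)
    g(idx) = idx[1] + idx[2] * a * a
    return g
end

g = make_g(2.0)

# The captured variable shows up as a field of the closure type.
println(fieldnames(typeof(g)))   # (:a,)
println(g.a)                     # 2.0

# A structural gradient for g therefore has one entry per captured field.
# The field names are read from the type instead of hard-coding :a.
template = NamedTuple{fieldnames(typeof(g))}((8.0,))
println(template)                # (a = 8.0,)
```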
I suspect the right gradient here would be (a = sum(∂gen),), but that would only work for this specific closure, since others might have more than one capture or call it something other than a. So :a shouldn't appear in the code.
Instead, you want to do something like broadcast the pullback of g over ΔΩ. That gets you a matrix of named tuples, which you can sum with Zygote.accum_sum (which is like sum but supports named tuples).
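The idea of summing a matrix of named tuples can be sketched in plain Julia. This sketch assumes a hand-written per-element gradient in place of the pullback of g; ∂g and add_nt are hypothetical helpers and not Zygote API, and add_nt only stands in for the field-wise summation that Zygote.accum_sum is described as providing.

```julia
a = 2.0

# Hand-derived gradient of g(idx) = idx[1] + idx[2]*a*a with respect to
# the captured a, returned in the structural (named tuple) form.
∂g(idx) = (a = 2a * idx[2],)

# Incoming cotangent of the 3×3 array; sum(myarr) gives all ones.
ΔΩ = ones(3, 3)

# One named tuple per array element, scaled by the matching entry of ΔΩ.
grads = [map(v -> ΔΩ[i, j] * v, ∂g((i, j))) for i in 1:3, j in 1:3]

# Field-wise addition of two named tuples with identical field names.
add_nt(x::NamedTuple, y::NamedTuple) = map(+, x, y)

total = reduce(add_nt, grads)
println(total)   # (a = 72.0,)
```

Because add_nt works on whatever fields the named tuples carry, nothing in the reduction depends on the capture being called a.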
Thanks @MikeInnes, for this hint. It took me ages to understand that it is not the returned Tuple that needs to be a named tuple, but the third of its elements. As far as I can see, there is no need to involve sum or Zygote.accum_sum, but it's useful to know that they exist. Here is now an implementation, which should hopefully also work for slightly more general cases:
    using ChainRulesCore
    using Zygote

    struct Example{T,N,F} <: AbstractArray{T,N} where F
        sz::NTuple{N, Int}
        f::F
    end

    function Base.getindex(a::Example{T,N,F}, idx::Vararg{B,N}) where {T,N,F,B}
        a.f(idx)
    end

    Base.size(e::Example) = e.sz

    function ChainRulesCore.rrule(::Type{Example{Float64,2,F}},
                                  sz::NTuple{N, Int},
                                  gen::F,
                                  ) where {F,N}
        val_grad(idx) = Zygote._pullback(gen, idx)[2](1.0) # 1.0 is only the seed
        mySymbols = keys(val_grad(sz)[1])
        gradgen(idx) = val_grad(idx)[1]
        function IFA_pullback(ΔΩ)
            Fcts = ((idx)-> val_grad(idx)[1][aSymbol] for aSymbol in mySymbols)
            TupleVals = (ΔΩ .* Example{Float64,2,typeof(Fun)}(sz, Fun) for Fun in Fcts)
            ∂gen = NamedTuple{mySymbols}(TupleVals)
            return (NO_FIELDS, NO_FIELDS, ∂gen)
        end
        Ω = Example{Float64,2,typeof(gen)}(sz, gen)
        return (Ω, IFA_pullback)
    end

    c(a) = begin
        g(idx)= idx[1] + idx[2] *a*a
        myarr = Example{Float64,2,typeof(g)}((3,3),g) # 3,3 refers to size
        sum(myarr)
    end
The output looks like this:
    julia> gradient(c, 2.0)
    ([4.0 8.0 12.0; 4.0 8.0 12.0; 4.0 8.0 12.0],)
Phew. That took longer than planned ;-)
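As a cross-check that does not depend on Zygote or on the rrule, the derivative of c with respect to a can be estimated by a central finite difference. This is only a sketch: c_plain is a hypothetical stand-in that evaluates the same generator g over the same 3×3 index range without the Example type, and reported is the matrix copied from the output above.

```julia
# Plain-Julia stand-in for c(a): same generator, same 3×3 size, no Example type.
c_plain(a) = sum(i + j * a * a for i in 1:3, j in 1:3)

# Central finite difference at a = 2.0 (c_plain is quadratic in a,
# so the estimate is exact up to floating-point rounding).
h = 1e-4
fd = (c_plain(2.0 + h) - c_plain(2.0 - h)) / (2h)

# Matrix reported by gradient(c, 2.0) above.
reported = [4.0 8.0 12.0; 4.0 8.0 12.0; 4.0 8.0 12.0]

println(fd)              # approximately 72.0
println(sum(reported))   # 72.0
```

The scalar derivative agrees with the sum of the entries of the reported matrix, which is consistent with the earlier suggestion that the per-element contributions still need to be accumulated to obtain a single value for a.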
Does anyone know if Zygote._pullback(gen, idx)[2](1.0) can be avoided or replaced with a function in ChainRulesCore? It would be nice to avoid the dependence of our package on Zygote.
Only if you know the differential explicitly or if there exists an rrule that you can call. Currently, ChainRules does not support calling back into the AD system (such as Zygote): https://github.com/JuliaDiff/ChainRulesCore.jl/issues/68
The plan is to allow this by JuliaCon, so watch the issue @devmotion posted in case you find this useful.
We are trying to write an rrule for a custom array class, such that Zygote can differentiate through it, but are stuck due to an error about a missing adjoint for a constructor. This may well be a user error, but it could also be a problem of Zygote. Any help is appreciated!
This code seems to generally work fine, but the point is the required ability to differentiate with respect to a variable used in the innermost function. The code using these definitions, which then causes the error:
The error looks like this:
Something similar happens if you place the variable a right behind sum( in function c.