theogf opened 2 years ago
The solution from https://github.com/theogf/KLDivergences.jl/blob/main/src/horrible_ad_workaround.jl does not work.
I was thinking the easiest (and maybe cheapest!) option would be to directly write the `rrule` for `kldivergence`.
This seems to be yet another AD issue with PDMats. Maybe about time to add CR to that repo.
Regardless, it might still be useful, and possibly more efficient, to add a CR definition for `kl_divergence` directly.
As usual, I have the forward rule, but I don't have the brain capacity to derive the reverse one.
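For reference (not quoted from the thread, but the standard textbook closed form, which matches the term-by-term reconstruction further down), the KL divergence between two multivariate normals $p = \mathcal{N}(\mu_p, \Sigma_p)$ and $q = \mathcal{N}(\mu_q, \Sigma_q)$ in $k$ dimensions is:

```latex
D_{\mathrm{KL}}(p \,\|\, q)
  = \tfrac{1}{2}\Bigl[\operatorname{tr}\bigl(\Sigma_q^{-1}\Sigma_p\bigr)
  + (\mu_q - \mu_p)^{\top}\Sigma_q^{-1}(\mu_q - \mu_p)
  - k + \ln\det\Sigma_q - \ln\det\Sigma_p\Bigr]
```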
> This seems to be yet another AD issue with PDMats. Maybe about time to add CR to that repo.

I thought there was some friction about this but I cannot find the discussion.
EDIT: Ah no! That's in Distances.jl.
> As usual, I have the forward rule, but I don't have the brain capacity to derive the reverse one.

If you share the forward rule I can see if I manage to sort out the rrule :)
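Not the forward rule referred to above, but for what it's worth: these are the standard matrix-calculus gradients of the Gaussian closed form $D_{\mathrm{KL}} = \tfrac{1}{2}\bigl[\operatorname{tr}(\Sigma_q^{-1}\Sigma_p) + \delta^{\top}\Sigma_q^{-1}\delta - k + \ln\det\Sigma_q - \ln\det\Sigma_p\bigr]$ with $\delta = \mu_q - \mu_p$, which an `rrule` would essentially scale by the incoming cotangent (worth double-checking before use):

```latex
\frac{\partial D_{\mathrm{KL}}}{\partial \mu_p} = -\Sigma_q^{-1}\delta, \qquad
\frac{\partial D_{\mathrm{KL}}}{\partial \mu_q} = \Sigma_q^{-1}\delta, \qquad
\frac{\partial D_{\mathrm{KL}}}{\partial \Sigma_p} = \tfrac{1}{2}\bigl(\Sigma_q^{-1} - \Sigma_p^{-1}\bigr),
```
```latex
\frac{\partial D_{\mathrm{KL}}}{\partial \Sigma_q}
  = \tfrac{1}{2}\bigl(\Sigma_q^{-1} - \Sigma_q^{-1}\Sigma_p\Sigma_q^{-1}
  - \Sigma_q^{-1}\delta\delta^{\top}\Sigma_q^{-1}\bigr)
```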
This is an example of one `rrule` (for some function utilised in `kldivergence`) returning a `Diagonal` matrix (a "natural" tangent), and another digging into the `PDMat` and returning a `NamedTuple` (the "structural" tangent). You can deduce that this kind of thing is going on from the call to `accum` with a `NamedTuple` and a `Diagonal` -- generally speaking, `accum` is roughly equivalent to `+`, so if it's not obvious how to add two types, `accum` probably won't work for them without manual intervention.
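A minimal, self-contained sketch of why that fails (using a toy `accum`, not Zygote's actual implementation, which has more methods):

```julia
using LinearAlgebra

# Toy stand-in for Zygote's tangent accumulation: adding two tangents
# is roughly `+`.
accum(x, y) = x + y

# Two "natural" tangents of the same shape add fine:
accum(Diagonal([1.0, 2.0]), Diagonal([3.0, 4.0]))  # Diagonal([4.0, 6.0])

# But a "natural" Diagonal plus a "structural" NamedTuple (as produced when a
# rule digs into the PDMat fields) has no `+` method, hence the error:
# accum(Diagonal([1.0, 2.0]), (mat = [1.0 0.0; 0.0 1.0],))  # MethodError
```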
There are two flavours of fix for this kind of problem:

1. avoid the `accum` error by ensuring that both tangents are converted to a common type before hitting the call to `accum`, and
2. implement `accum` for the two types in question.

In my view the former is the way to go, and it is what the projection mechanism in CR lets you do.
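To illustrate that projection mechanism (this uses ChainRulesCore's existing `ProjectTo` for `Diagonal`; the missing piece for this issue would be an analogous method for `PDMat`):

```julia
using ChainRulesCore
using LinearAlgebra

D = Diagonal([1.0, 2.0])
project = ProjectTo(D)  # remembers the structure of the primal

# A mismatched dense tangent is projected back onto the Diagonal structure,
# so downstream accumulation only ever sees one tangent type:
dD = project([1.0 5.0; 5.0 2.0])
dD isa Diagonal  # true
```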
If I had to guess, I would say that `_cov(q) \ _cov(p)` is giving the natural tangent, and `logdetcov(q)` / `logdetcov(p)` the structural, but you would have to check.

The fix is probably to get PDMats onto CR, as @devmotion suggests, and presumably to implement the projection mechanism for it.
I tried the terms in `kldivergence` one by one:
```julia
using Distributions
using ForwardDiff
using LinearAlgebra  # for `tr`
using Zygote

function kernel(x)
    return [1.0 x; x 1.0]
end
d(x) = MvNormal(kernel(x))

kldivergence(d(0.1), d(0.0))

f(x) = kldivergence(d(x), d(0.0))
ForwardDiff.gradient(x -> f(only(x)), [0.1]) # works
Zygote.gradient(x -> f(only(x)), [0.1])      # ERROR

g(x) = logdetcov(d(x))
Zygote.gradient(x -> g(only(x)), [0.1])      # works

g(x) = sqmahal(d(x), zeros(2))
Zygote.gradient(x -> g(only(x)), [0.1])      # works

g(x) = length(d(x))
Zygote.gradient(x -> g(only(x)), [0.1])      # works

g(x) = tr(cov(d(0.0)) \ cov(d(x)))
Zygote.gradient(x -> g(only(x)), [0.1])      # ERROR

g(x) = (tr(cov(d(0.0)) \ cov(d(x))) + sqmahal(d(0.0), mean(d(x))) - length(d(x)) + logdetcov(d(0.0)) - logdetcov(d(x))) / 2
Zygote.gradient(x -> g(only(x)), [0.1])      # ERROR
```
Interestingly, the last two errors are different from the first:
```
ERROR: Need an adjoint for constructor PDMats.PDMat{Float64, Matrix{Float64}}. Gradient is of type Diagonal{Float64, Vector{Float64}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{PDMats.PDMat{Float64, Matrix{Float64}}, Nothing, false})(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:324
  [3] (::Zygote.var"#1784#back#228"{Zygote.Jnew{PDMats.PDMat{Float64, Matrix{Float64}}, Nothing, false}})(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ~/.julia/packages/PDMats/mudzk/src/pdmat.jl:9 [inlined]
  [5] (::typeof(∂(PDMats.PDMat{Float64, Matrix{Float64}})))(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/PDMats/mudzk/src/pdmat.jl:16 [inlined]
  [7] (::typeof(∂(PDMats.PDMat)))(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/PDMats/mudzk/src/pdmat.jl:19 [inlined]
  [9] (::typeof(∂(PDMats.PDMat)))(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/Distributions/O4ZJg/src/multivariate/mvnormal.jl:201 [inlined]
 [11] (::typeof(∂(MvNormal)))(Δ::NamedTuple{(:μ, :Σ), Tuple{Nothing, Diagonal{Float64, Vector{Float64}}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/Distributions/O4ZJg/src/multivariate/mvnormal.jl:218 [inlined]
 [13] (::typeof(∂(MvNormal)))(Δ::NamedTuple{(:μ, :Σ), Tuple{Nothing, Diagonal{Float64, Vector{Float64}}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/Documents/projects/n/ai-timeseries-prototypes/autodiff/kldiff.jl:9 [inlined]
 [15] (::typeof(∂(d)))(Δ::NamedTuple{(:μ, :Σ), Tuple{Nothing, Diagonal{Float64, Vector{Float64}}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./REPL[40]:1 [inlined]
 [17] (::typeof(∂(g)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./REPL[41]:1 [inlined]
 [19] (::typeof(∂(#33)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [20] (::Zygote.var"#56#57"{typeof(∂(#33))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41
 [21] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
 [22] top-level scope
    @ REPL[41]:1
```
What's going on here?
All of them are PDMats issues, it seems.
I was expecting the `+` issue to show up when I assembled the terms manually. Instead, the last two errors seem to fail at an earlier stage.
I opened an issue https://github.com/JuliaStats/PDMats.jl/issues/159
Despite https://github.com/JuliaDiff/ChainRules.jl/pull/613 this seems to still be broken :(
The PDMats issues (https://github.com/JuliaStats/PDMats.jl/issues/159) are not fixed yet.
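In the meantime, a possible stop-gap is to evaluate the closed form against plain matrices, so Zygote never differentiates through a `PDMat` constructor. The `kl_mvnormal` helper below is hypothetical (zero means, as in the example above), not part of any package:

```julia
using LinearAlgebra
using Zygote

kernel(x) = [1.0 x; x 1.0]

# Hypothetical helper: closed-form KL for two zero-mean Gaussians, written
# against plain matrices so no PDMat constructor appears on the tape.
kl_mvnormal(Σp, Σq) = (tr(Σq \ Σp) - size(Σp, 1) + logdet(Σq) - logdet(Σp)) / 2

# This gradient goes through, unlike the PDMat-based version above:
Zygote.gradient(x -> kl_mvnormal(kernel(only(x)), kernel(0.0)), [0.1])
```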
Here is a simplified view of the problem from @simsurace:
With the following error: