JuliaGaussianProcesses / ApproximateGPs.jl

Approximations for Gaussian processes: sparse variational inducing point approximations, Laplace approximation, ...
https://juliagaussianprocesses.github.io/ApproximateGPs.jl/dev

`_prior_kl` not differentiable for `Centered` with `Zygote` #129

Open theogf opened 2 years ago

theogf commented 2 years ago

Here is a simplified view of the problem from @simsurace:

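```julia
using Distributions
using Zygote

# KL divergence between two zero-mean bivariate normals whose 2×2
# covariances are built from (variance, covariance) parameter pairs.
function DKL(par1, par2)
    K1 = [par1[1] par1[2]; par1[2] par1[1]]
    K2 = [par2[1] par2[2]; par2[2] par2[1]]
    return kldivergence(
        MvNormal(K1),
        MvNormal(K2)
    )
end

# Differentiating with respect to the second covariance fails (see below).
Zygote.gradient(par2 -> DKL([1., 0.1], par2), [1., 0.1])
```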

With the following error:

```
ERROR: MethodError: no method matching +(::NamedTuple{(:dim, :mat, :chol), Tuple{Nothing, Nothing, NamedTuple{(:factors, :uplo, :info), Tuple{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Nothing, Nothing}}}}, ::Matrix{Float64})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at ~/julia-1.7.1/share/julia/base/operators.jl:655
  +(::FillArrays.Zeros{T, N}, ::AbstractArray{V, N}) where {T, V, N} at ~/.julia/packages/FillArrays/5Arin/src/fillalgebra.jl:228
  +(::Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_arithmetic.jl:146
  ...
Stacktrace:
  [1] accum(x::NamedTuple{(:dim, :mat, :chol), Tuple{Nothing, Nothing, NamedTuple{(:factors, :uplo, :info), Tuple{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Nothing, Nothing}}}}, y::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:17
  [2] macro expansion
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:27 [inlined]
  [3] accum(x::NamedTuple{(:μ, :Σ), Tuple{Vector{Float64}, NamedTuple{(:dim, :mat, :chol), Tuple{Nothing, Nothing, NamedTuple{(:factors, :uplo, :info), Tuple{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Nothing, Nothing}}}}}}, y::NamedTuple{(:μ, :Σ), Tuple{Nothing, Matrix{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:27
  [4] accum(x::NamedTuple{(:μ, :Σ), Tuple{Nothing, NamedTuple{(:dim, :mat, :chol), Tuple{Nothing, Nothing, NamedTuple{(:factors, :uplo, :info), Tuple{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Nothing, Nothing}}}}}}, y::NamedTuple{(:μ, :Σ), Tuple{Vector{Float64}, Nothing}}, zs::NamedTuple{(:μ, :Σ), Tuple{Nothing, Matrix{Float64}}}) (repeats 2 times)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:22
  [5] Pullback
    @ ~/.julia/packages/Distributions/O4ZJg/src/multivariate/mvnormal.jl:110 [inlined]
  [6] (::typeof(∂(kldivergence)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [7] Pullback
    @ ./REPL[16]:4 [inlined]
  [8] (::typeof(∂(DKL)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [9] Pullback
    @ ./REPL[18]:1 [inlined]
 [10] (::typeof(∂(#5)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#56#57"{typeof(∂(#5))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41
 [12] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
 [13] top-level scope
    @ REPL[18]:1
```

theogf commented 2 years ago

The workaround in https://github.com/theogf/KLDivergences.jl/blob/main/src/horrible_ad_workaround.jl does not work.

theogf commented 2 years ago

I was thinking that the easiest (and maybe cheapest!) option would be to write the `rrule` for `kldivergence` directly.

devmotion commented 2 years ago

This seems to be yet another AD issue with PDMats. Maybe about time to add CR to that repo.

devmotion commented 2 years ago

Regardless, it might still be useful, and possibly more efficient, to add a CR definition for `kldivergence` directly.

theogf commented 2 years ago

As usual, I have the forward rule, but I don't have the brain capacity to derive the reverse one.
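For reference, here is a sketch of the reverse-mode gradients of the closed-form divergence $D = \mathrm{KL}(p\,\|\,q)$ between $p = \mathcal{N}(\mu_p, \Sigma_p)$ and $q = \mathcal{N}(\mu_q, \Sigma_q)$, assuming the covariances are treated as plain dense symmetric matrices (no `PDMat` structure). With an incoming cotangent $\bar D$ and $\delta = \mu_p - \mu_q$:

```math
\begin{aligned}
\bar\mu_p &= \bar D\,\Sigma_q^{-1}\delta, \qquad \bar\mu_q = -\bar D\,\Sigma_q^{-1}\delta,\\
\bar\Sigma_p &= \frac{\bar D}{2}\left(\Sigma_q^{-1} - \Sigma_p^{-1}\right),\\
\bar\Sigma_q &= \frac{\bar D}{2}\left(\Sigma_q^{-1} - \Sigma_q^{-1}\Sigma_p\Sigma_q^{-1} - \Sigma_q^{-1}\delta\delta^\top\Sigma_q^{-1}\right).
\end{aligned}
```

These follow from $d(\Sigma^{-1}) = -\Sigma^{-1}\,d\Sigma\,\Sigma^{-1}$ and $d\log\det\Sigma = \operatorname{tr}(\Sigma^{-1}\,d\Sigma)$. Because they are gradients with respect to dense matrices, a rule taking `PDMat` arguments would still have to map them onto whatever tangent type `PDMat` expects.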

theogf commented 2 years ago

> This seems to be yet another AD issue with PDMats. Maybe about time to add CR to that repo.

I thought there was some pushback against this, but I cannot find the discussion.

EDIT: Ah no! That's in Distances.jl

st-- commented 2 years ago

> As usual, I have the forward rule, but I don't have the brain capacity to derive the reverse one.

If you share the forward rule, I can see whether I manage to sort out the `rrule` :)
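A minimal sketch of what such a reverse rule could compute, written against plain dense covariance matrices rather than `PDMat`, and without ChainRulesCore, so it only illustrates the math. The name `gauss_kl_with_pullback` is hypothetical: it returns `(value, pullback)` in the same spirit as an `rrule`, and the pullback implements the standard closed-form gradients of the Gaussian KL divergence.

```julia
using LinearAlgebra

# Hypothetical helper (not a Distributions.jl or ChainRules API):
# KL(p ‖ q) for p = N(μp, Σp), q = N(μq, Σq) with dense symmetric
# positive-definite covariances, returned with an rrule-style pullback.
function gauss_kl_with_pullback(μp, Σp, μq, Σq)
    k = length(μp)
    Cp = cholesky(Symmetric(Σp))
    Cq = cholesky(Symmetric(Σq))
    δ = μp - μq
    α = Cq \ δ                                   # Σq⁻¹ δ
    value = (tr(Cq \ Σp) + dot(δ, α) - k + logdet(Cq) - logdet(Cp)) / 2

    # Maps the cotangent Δ of the scalar KL to cotangents of all four inputs.
    function pullback(Δ)
        Σq_inv = inv(Cq)
        Σp_inv = inv(Cp)
        ∂μp = Δ .* α
        ∂μq = -∂μp
        ∂Σp = (Δ / 2) .* (Σq_inv - Σp_inv)
        ∂Σq = (Δ / 2) .* (Σq_inv - Σq_inv * Σp * Σq_inv - α * α')
        return ∂μp, ∂Σp, ∂μq, ∂Σq
    end
    return value, pullback
end

# Usage: when p == q the KL divergence is at its minimum (zero),
# so all returned gradients are zero up to rounding.
Σ = [1.0 0.1; 0.1 1.0]
value, back = gauss_kl_with_pullback(zeros(2), Σ, zeros(2), Σ)
∂μp, ∂Σp, ∂μq, ∂Σq = back(1.0)
```

The same reasoning applies to the original reproduction, where both parameter vectors are `[1., 0.1]`: the two distributions coincide, so the expected gradient there is zero. A finite-difference comparison against `value` is a cheap way to check a hand-written pullback like this one.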

willtebbutt commented 2 years ago

This is an example of one rrule (for some function utilised in kldivergence) returning a plain matrix (a "natural" tangent), and another digging into the PDMat and returning a NamedTuple (the "structural" tangent). You can deduce that this kind of thing is going on from the call to accum with a NamedTuple and a Matrix. Generally speaking, accum is roughly equivalent to +, so if it's not obvious how to add two types, accum probably won't work for them without manual intervention.
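To see why `accum` gives up, it helps to rebuild the two tangent shapes from the first stack trace by hand, without any AD. This is only an illustration: the field names mirror the `NamedTuple` printed in the trace, while the numeric values are made up.

```julia
using LinearAlgebra

# Hand-built stand-ins for the two tangents in the stack trace (no AD involved).
# "Structural" tangent: a NamedTuple mirroring the fields of a PDMat.
structural = (dim = nothing, mat = nothing,
              chol = (factors = Diagonal([0.5, 0.5]), uplo = nothing, info = nothing))
# "Natural" tangent: an ordinary dense matrix.
natural = [0.5 0.0; 0.0 0.5]

# In the trace, accum ends up calling + on exactly this pair of types,
# and Base defines no method for adding a NamedTuple to a Matrix.
result = try
    structural + natural
catch err
    err
end
println(result isa MethodError)  # true
```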

There are two flavours of fix for this kind of problem:

  1. prevent the accum error by ensuring that both tangents are converted to a common type before hitting the call to accum, and
  2. implement accum for the two types in question.

In my view the former is the way to go, and it is what the projection mechanism in CR lets you do.
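In ChainRulesCore, the projection mechanism is exposed as `ProjectTo`: `ProjectTo(x)` returns a callable that maps a tangent onto the canonical tangent type for the primal `x`. A minimal sketch with a `Diagonal` primal, chosen only because ChainRulesCore already defines a projection for it; the point of this thread is that `PDMat` would need an analogous one.

```julia
using ChainRulesCore, LinearAlgebra

# Build a projector for a Diagonal primal value.
x = Diagonal([1.0, 2.0])
project = ProjectTo(x)

# A dense tangent, e.g. as returned by some generic rrule.
dx = [1.0 2.0; 3.0 4.0]
projected = project(dx)          # keeps only the diagonal, as a Diagonal

println(projected isa Diagonal)  # true
println(diag(projected))         # [1.0, 4.0]
```

Projecting each tangent to the primal's canonical type before they meet is what makes accumulation well-defined, rather than teaching `+` about every pair of tangent representations.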

If I had to guess, I would say that `_cov(q) \ _cov(p)` is producing the natural tangent, and `logdetcov(q)` / `logdetcov(p)` the structural one, but you would have to check.

The fix is probably to get PDMats onto CR, as @devmotion suggests, and presumably to implement the projection mechanism for it.

simsurace commented 2 years ago

I tried the terms in kldivergence one by one:

```julia
using Distributions
using ForwardDiff
using LinearAlgebra  # for tr
using Zygote

# 2×2 covariance with unit variances and covariance x
function kernel(x)
    return [1. x; x 1.]
end

d(x) = MvNormal(kernel(x))

kldivergence(d(0.1), d(0.0))

# Full KL divergence
f(x) = kldivergence(d(x), d(0.0))
ForwardDiff.gradient(x->f(only(x)), [.1]) # works
Zygote.gradient(x->f(only(x)), [.1]) # ERROR

# Individual terms of kldivergence(d(x), d(0.0))
g(x) = logdetcov(d(x))
Zygote.gradient(x->g(only(x)), [.1]) # works

g(x) = sqmahal(d(x), zeros(2))
Zygote.gradient(x->g(only(x)), [.1]) # works

g(x) = length(d(x))
Zygote.gradient(x->g(only(x)), [.1]) # works

g(x) = tr(cov(d(0.0)) \ cov(d(x)))
Zygote.gradient(x->g(only(x)), [.1]) # ERROR

# All terms assembled by hand
g(x) = (tr(cov(d(0.0)) \ cov(d(x))) + sqmahal(d(0.0), mean(d(x))) - length(d(x)) + logdetcov(d(0.0)) - logdetcov(d(x))) / 2
Zygote.gradient(x->g(only(x)), [.1]) # ERROR
```
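For reference, the terms tested here are those of the closed-form KL divergence between $p = \mathcal{N}(\mu_p, \Sigma_p)$ (here `d(x)`) and $q = \mathcal{N}(\mu_q, \Sigma_q)$ (here `d(0.0)`), with $k$ the dimension:

```math
\mathrm{KL}(p\,\|\,q) = \frac{1}{2}\left[\operatorname{tr}\!\left(\Sigma_q^{-1}\Sigma_p\right) + (\mu_p - \mu_q)^\top \Sigma_q^{-1} (\mu_p - \mu_q) - k + \log\det\Sigma_q - \log\det\Sigma_p\right]
```

In the snippet, `tr(cov(d(0.0)) \ cov(d(x)))` is the trace term, `sqmahal(d(0.0), mean(d(x)))` the Mahalanobis term, `length(d(x))` is $k$, and `logdetcov(d(0.0)) - logdetcov(d(x))` the log-determinant difference. Of these, only the trace term fails on its own.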

Interestingly, the last two errors are different from the first one:

Full test output:

```
ERROR: Need an adjoint for constructor PDMats.PDMat{Float64, Matrix{Float64}}. Gradient is of type Diagonal{Float64, Vector{Float64}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{PDMats.PDMat{Float64, Matrix{Float64}}, Nothing, false})(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:324
  [3] (::Zygote.var"#1784#back#228"{Zygote.Jnew{PDMats.PDMat{Float64, Matrix{Float64}}, Nothing, false}})(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ~/.julia/packages/PDMats/mudzk/src/pdmat.jl:9 [inlined]
  [5] (::typeof(∂(PDMats.PDMat{Float64, Matrix{Float64}})))(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/PDMats/mudzk/src/pdmat.jl:16 [inlined]
  [7] (::typeof(∂(PDMats.PDMat)))(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/PDMats/mudzk/src/pdmat.jl:19 [inlined]
  [9] (::typeof(∂(PDMats.PDMat)))(Δ::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/Distributions/O4ZJg/src/multivariate/mvnormal.jl:201 [inlined]
 [11] (::typeof(∂(MvNormal)))(Δ::NamedTuple{(:μ, :Σ), Tuple{Nothing, Diagonal{Float64, Vector{Float64}}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/Distributions/O4ZJg/src/multivariate/mvnormal.jl:218 [inlined]
 [13] (::typeof(∂(MvNormal)))(Δ::NamedTuple{(:μ, :Σ), Tuple{Nothing, Diagonal{Float64, Vector{Float64}}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/Documents/projects/n/ai-timeseries-prototypes/autodiff/kldiff.jl:9 [inlined]
 [15] (::typeof(∂(d)))(Δ::NamedTuple{(:μ, :Σ), Tuple{Nothing, Diagonal{Float64, Vector{Float64}}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./REPL[40]:1 [inlined]
 [17] (::typeof(∂(g)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./REPL[41]:1 [inlined]
 [19] (::typeof(∂(#33)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [20] (::Zygote.var"#56#57"{typeof(∂(#33))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41
 [21] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
 [22] top-level scope
    @ REPL[41]:1
```

What's going on here?

devmotion commented 2 years ago

All of them are PDMats issues, it seems.

simsurace commented 2 years ago

I was expecting the + issue to show up when I assembled the terms manually. Instead, the last two cases fail at an earlier stage, in the pullback of the `PDMat` constructor.

simsurace commented 2 years ago

I opened an issue https://github.com/JuliaStats/PDMats.jl/issues/159

st-- commented 1 year ago

Despite https://github.com/JuliaDiff/ChainRules.jl/pull/613, this still seems to be broken :(

devmotion commented 1 year ago

The PDMats issues (https://github.com/JuliaStats/PDMats.jl/issues/159) are not fixed yet.