JuliaStats / PDMats.jl

Uniform Interface for positive definite matrices of various structures

Problems with Zygote and the `PDMat` constructor #159

Open simsurace opened 2 years ago

simsurace commented 2 years ago

An error is thrown when differentiating a trace of a matrix division with a PDMat:

using LinearAlgebra
using PDMats
using Zygote

function kernel(x)
    return [1. x; x 1.]
end

# PDMat is basically a wrapper for a cholesky decomposition.
# However, using `cholesky` explicitly does not throw any errors:
f(x) = tr(cholesky(kernel(0.1)) \ kernel(x))
Zygote.gradient(x->f(only(x)), [.2]) # works

# Trying to perform the same operation through the `PDMat` constructor fails:
g(x) = tr(PDMat(kernel(0.1)) \ kernel(x))
Zygote.gradient(x->g(only(x)), [.2]) # ERROR
ERROR: Need an adjoint for constructor PDMat{Float64, Matrix{Float64}}. Gradient is of type Matrix{Float64}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{PDMat{Float64, Matrix{Float64}}, Nothing, false})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/lib/lib.jl:324
  [3] (::Zygote.var"#1784#back#228"{Zygote.Jnew{PDMat{Float64, Matrix{Float64}}, Nothing, false}})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ~/.julia/packages/PDMats/ovlmf/src/pdmat.jl:9 [inlined]
  [5] (::typeof(∂(PDMat{Float64, Matrix{Float64}})))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/PDMats/ovlmf/src/pdmat.jl:16 [inlined]
  [7] (::typeof(∂(PDMat)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/PDMats/ovlmf/src/pdmat.jl:19 [inlined]
  [9] (::typeof(∂(PDMat)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[7]:1 [inlined]
 [11] (::typeof(∂(f)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [12] Pullback
    @ ./REPL[8]:1 [inlined]
 [13] (::typeof(∂(#3)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [14] (::Zygote.var"#56#57"{typeof(∂(#3))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:41
 [15] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:76
 [16] top-level scope
    @ REPL[8]:1

This seems like a strange error. I tried to reproduce with my own type, but couldn't:

# The following is basically copy-pasted from master:
struct MyPDMat{T<:Real,S<:AbstractMatrix}
    dim::Int
    mat::S
    chol::Cholesky{T,S}

    MyPDMat{T,S}(d::Int,m::AbstractMatrix{T},c::Cholesky{T,S}) where {T,S} = new{T,S}(d,m,c)
end

function MyPDMat(mat::AbstractMatrix,chol::Cholesky{T,S}) where {T,S}
    d = size(mat, 1)
    size(chol, 1) == d ||
        throw(DimensionMismatch("Dimensions of mat and chol are inconsistent."))
    MyPDMat{T,S}(d, convert(S, mat), chol)
end

MyPDMat(mat::AbstractMatrix) = MyPDMat(mat, cholesky(mat))

Base.:\(a::MyPDMat, x::AbstractVecOrMat) = cholesky(a) \ x
LinearAlgebra.cholesky(a::MyPDMat) = a.chol

h(x) = tr(MyPDMat(kernel(0.1)) \ kernel(x))
Zygote.gradient(x->h(only(x)), [.2]) # works

BTW, all of these functions can be differentiated with ForwardDiff.
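
The quantity differentiated in f and g is linear in x, so its derivative can be written down directly. It is d/dx tr(K⁻¹ kernel(x)) = tr(K⁻¹ ∂kernel/∂x), with ∂kernel/∂x = [0 1; 1 0], which gives 2(K⁻¹)₁₂ = -0.2/0.99 for K = kernel(0.1). The following stdlib-only sketch checks that value against a central finite difference and gives a reference for the AD results. It does not involve Zygote or PDMats.

```julia
using LinearAlgebra

kernel(x) = [1.0 x; x 1.0]
K = kernel(0.1)
f(x) = tr(cholesky(K) \ kernel(x))

# d/dx tr(K⁻¹ kernel(x)) = tr(K⁻¹ * dkernel/dx), with dkernel/dx = [0 1; 1 0].
analytic = tr(K \ [0.0 1.0; 1.0 0.0])

# Central finite difference at x = 0.2 (f is linear in x, so only rounding error remains).
h = 1e-6
fd = (f(0.2 + h) - f(0.2 - h)) / (2h)
```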

devmotion commented 2 years ago

> This seems like a strange error.

That doesn't help, but I've seen these errors quite often.

> I tried to reproduce with my own type, but couldn't:

A major difference is that your type is not a subtype of AbstractMatrix, and hence defaults for AbstractMatrix in Zygote and ChainRules do not affect it.

devmotion commented 2 years ago

The usual approach for fixing these errors is defining an rrule or a projector with CR, as Will discussed in the linked PR.

simsurace commented 2 years ago

This is a sure sign that I don't actually understand how Zygote works. Where in the call stack would it matter whether or not MyPDMat is a subtype of AbstractMatrix? After all, the only way this has any bearing on the result of h is through cholesky, which has a method for MyPDMat anyway.

EDIT: In other words, I find it mysterious that Zygote can differentiate h, yet suddenly can't once we add extra information by making MyPDMat a subtype of AbstractMatrix. That information seems irrelevant to this specific function call.

simsurace commented 2 years ago

Making it a subtype of AbstractMatrix indeed makes it fail:

struct MyOtherPDMat{T<:Real,S<:AbstractMatrix} <: AbstractMatrix{T}
    dim::Int
    mat::S
    chol::Cholesky{T,S}

    MyOtherPDMat{T,S}(d::Int,m::AbstractMatrix{T},c::Cholesky{T,S}) where {T,S} = new{T,S}(d,m,c)
end

function MyOtherPDMat(mat::AbstractMatrix,chol::Cholesky{T,S}) where {T,S}
    d = size(mat, 1)
    size(chol, 1) == d ||
        throw(DimensionMismatch("Dimensions of mat and chol are inconsistent."))
    MyOtherPDMat{T,S}(d, convert(S, mat), chol)
end

MyOtherPDMat(mat::AbstractMatrix) = MyOtherPDMat(mat, cholesky(mat))

Base.:\(a::MyOtherPDMat, x::AbstractVecOrMat) = cholesky(a) \ x
LinearAlgebra.cholesky(a::MyOtherPDMat) = a.chol

h(x) = tr(MyOtherPDMat(kernel(0.1)) \ kernel(x))
Zygote.gradient(x->h(only(x)), [.2]) # ERROR
ERROR: MethodError: no method matching size(::MyOtherPDMat{Float64, Matrix{Float64}})
Closest candidates are:
  size(::AbstractArray{T, N}, ::Any) where {T, N} at ~/.julia/juliaup/julia-1.7.2+0~x64/share/julia/base/abstractarray.jl:42
  size(::Union{QR, LinearAlgebra.QRCompactWY, QRPivoted}) at ~/.julia/juliaup/julia-1.7.2+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/qr.jl:567
  size(::Union{QR, LinearAlgebra.QRCompactWY, QRPivoted}, ::Integer) at ~/.julia/juliaup/julia-1.7.2+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/qr.jl:566
  ...
Stacktrace:
  [1] axes
    @ ./abstractarray.jl:95 [inlined]
  [2] axes(A::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.7.2+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/adjtrans.jl:175
  [3] has_offset_axes(A::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}})
    @ Base ./abstractarray.jl:105
  [4] _tuple_any(f::typeof(Base.has_offset_axes), tf::Bool, a::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}}, b::Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Base ./tuple.jl:516
  [5] _tuple_any(f::Function, t::Tuple{Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}}, Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}})
    @ Base ./tuple.jl:513
  [6] has_offset_axes(::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}}, ::Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Base ./abstractarray.jl:107
  [7] require_one_based_indexing(::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}}, ::Vararg{Any})
    @ Base ./abstractarray.jl:110
  [8] \(A::Adjoint{Float64, MyOtherPDMat{Float64, Matrix{Float64}}}, B::Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.7.2+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/generic.jl:1129
  [9] (::Zygote.var"#752#753"{MyOtherPDMat{Float64, Matrix{Float64}}, Matrix{Float64}, Matrix{Float64}})(Z̄::Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/lib/array.jl:494
 [10] (::Zygote.var"#3058#back#754"{Zygote.var"#752#753"{MyOtherPDMat{Float64, Matrix{Float64}}, Matrix{Float64}, Matrix{Float64}}})(Δ::Diagonal{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [11] Pullback
    @ ./REPL[21]:1 [inlined]
 [12] (::typeof(∂(h)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [13] Pullback
    @ ./REPL[22]:1 [inlined]
 [14] (::typeof(∂(#7)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [15] (::Zygote.var"#56#57"{typeof(∂(#7))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:41
 [16] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:76
 [17] top-level scope
    @ REPL[22]:1
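
The trace shows where the subtype matters. Zygote's generic pullback for `\` (frame [9]) calls `\` with the adjoint of the wrapper (frame [8]). The generic LinearAlgebra fallback then needs `size`, which MyOtherPDMat never defines. The sketch below uses a hypothetical `MyPDWrapper` type and the standard library only. It implements the two-method AbstractArray interface (`size` and `getindex`) so that such generic fallbacks can run. Whether Zygote then returns a correct gradient is not checked here.

```julia
using LinearAlgebra

# Hypothetical wrapper mirroring MyOtherPDMat, but implementing the
# AbstractArray interface that generic AbstractMatrix fallbacks rely on.
struct MyPDWrapper{T<:Real,S<:AbstractMatrix{T}} <: AbstractMatrix{T}
    mat::S
    chol::Cholesky{T,S}
end

MyPDWrapper(mat::AbstractMatrix) = MyPDWrapper(mat, cholesky(Symmetric(mat)))

# Required AbstractArray interface: without `size`, calls such as `axes`
# (reached from `\` via `require_one_based_indexing`) throw a MethodError.
Base.size(a::MyPDWrapper) = size(a.mat)
Base.getindex(a::MyPDWrapper, i::Int, j::Int) = a.mat[i, j]

# Specialized methods, as in the thread, that reuse the stored factorization.
LinearAlgebra.cholesky(a::MyPDWrapper) = a.chol
Base.:\(a::MyPDWrapper, x::AbstractVecOrMat) = cholesky(a) \ x

w = MyPDWrapper([1.0 0.1; 0.1 1.0])
B = [1.0 0.2; 0.2 1.0]

w \ B    # uses the stored Cholesky factorization
w' \ B   # generic AbstractMatrix solve on the adjoint; needs size/getindex
```

Once `size` and `getindex` are defined, the adjoint solve that failed in frame [8] at least has the methods it needs.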

simsurace commented 2 years ago

> The usual approach for fixing these errors is defining an rrule or a projector with CR, as Will discussed in the linked PR.

Which is the function that needs an rrule? Is it

PDMat(mat::AbstractMatrix) = PDMat(mat, cholesky(mat))

?

oxinabox commented 2 years ago

If it doesn't happen for your own type that doesn't subtype AbstractMatrix, then there should be a solution that involves opting out of some problematic rrule: https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/superpowers/opt_out.html

simsurace commented 2 years ago

Thanks, this sounds great! Any tips on how to find out efficiently which rrule is problematic?

simsurace commented 2 years ago

I wasn't able to figure out which rrule to opt out of. However, I came up with an rrule that allows me to differentiate through the rrule interface directly.

I started with det(PDMat(x)) because it is throwing the same error as the more complicated example involving matrix division.

However, Zygote does not seem to recognize my rrule and still complains about a missing adjoint:

using LinearAlgebra
using PDMats
using Zygote

x = [1. 0.2; 0.2 1.]
y = [1. 0.1; 0.1 1.]

Zygote.gradient(logdet ∘ PDMat, x) |> only # works
Zygote.gradient(det ∘ PDMat, x) |> only # ERROR
# ERROR: Need an adjoint for constructor PDMat{Float64, Matrix{Float64}}. Gradient is of type Matrix{Float64}

# Since `logdet` has an overload for `PDMat` and `det` doesn't, 
# and `logdet` above works while `det` doesn't, try to add an overload
# for `det` and see what happens:
LinearAlgebra.det(A::PDMat) = det(A.chol)
Zygote.gradient(det ∘ PDMat, x) |> only
# ERROR: Need an adjoint for constructor PDMat{Float64, Matrix{Float64}}. Gradient is of type Matrix{Float64}

# While the overload may be useful for efficiency, it does not solve the AD issue
# Next, we will try to define an rrule
using ChainRules, ChainRulesCore

# rrule draft for constructor of PDMat
# This is probably not completely correct, but should be a good start
function ChainRulesCore.rrule(::Type{PDMat}, mat)
    chol, chol_pullback = rrule(cholesky, mat)
    y = PDMat(mat, chol)
    function PDMat_pullbackCR(m̄at::AbstractMatrix)
        @info "Using CR for PDMat, AbstractMatrix tangent"
        return NoTangent(), m̄at
    end
    function PDMat_pullbackCR(m̄at::Tangent)
        @info "Using CR for PDMat, Tangent type"
        # `chol_pullback` returns a tuple of cotangents; the second entry is the one for `mat`
        return NoTangent(), chol_pullback(m̄at.chol)[2]
    end
    return y, PDMat_pullbackCR
end

# Perform the individual forward and backward steps manually:
a, a_pullback = rrule(PDMat, x)
b, b_pullback = rrule(det, a)

b̄ = 1.
_, ā = b_pullback(b̄)
_, x̄ = a_pullback(ā)

# Compare to the result without the `PDMat` wrapper
unthunk(x̄) ≈ Zygote.gradient(det, x) |> only # true
Zygote.gradient(det ∘ PDMat, x) |> only
# ERROR: Need an adjoint for constructor PDMat{Float64, Matrix{Float64}}. Gradient is of type Matrix{Float64}

oxinabox commented 2 years ago

You often need to call Zygote.refresh() to get it to pick up new rrules after it has been used once.

theogf commented 2 years ago

Just tried this code, and Zygote.gradient(det ∘ PDMat, x) works but then Zygote.gradient(logdet ∘ PDMat, x) fails...

theogf commented 2 years ago

Interestingly ForwardDiff also returns the wrong thing if x is not a PDMat already:

julia> ForwardDiff.gradient(det ∘ PDMat, x)
2×2 Matrix{Float64}:
 1.0  -0.4
 0.0   1.0

julia> ForwardDiff.gradient(det, PDMat(x))
2×2 Matrix{Float64}:
  1.0  -0.2
 -0.2   1.0

julia> ForwardDiff.gradient(det, x)
2×2 Matrix{Float64}:
  1.0  -0.2
 -0.2   1.0

devmotion commented 2 years ago

> I started with det(PDMat(x)) because it is throwing the same error as the more complicated example involving matrix division.

Probably one should restrict rrule(::typeof(det), ...) in the same way as the one for logdet: https://github.com/JuliaDiff/ChainRules.jl/pull/245

Regardless of AD, I think it would be useful to add definitions of det(::AbstractPDMat) = .... since otherwise these calls will fall back to the generic definitions based on the LU decomposition in LinearAlgebra: https://github.com/JuliaLang/julia/blob/6e061322438f13c6548200f115f3c31b20860a30/stdlib/LinearAlgebra/src/generic.jl#L1598-L1604
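
Such definitions can reuse the stored factor. Because A = U'U, det(A) = prod(diag(U))^2 and logdet(A) = 2 sum(log.(diag(U))), so no LU factorization is needed. The following minimal sketch uses a hypothetical `PDWrap` stand-in, not the PDMats type. LinearAlgebra's own `det(::Cholesky)` and `logdet(::Cholesky)` compute the same quantities, which is what an overload such as `det(A::PDMat) = det(A.chol)` would rely on.

```julia
using LinearAlgebra

# Hypothetical stand-in for a PDMat-like type that stores a Cholesky factorization.
struct PDWrap{T<:Real}
    chol::Cholesky{T,Matrix{T}}
end
PDWrap(mat::Matrix{<:Real}) = PDWrap(cholesky(Symmetric(mat)))

# A = U'U  ⇒  det(A) = prod(diag(U))^2 and logdet(A) = 2 * sum(log, diag(U)).
LinearAlgebra.det(a::PDWrap) = prod(diag(a.chol.U))^2
LinearAlgebra.logdet(a::PDWrap) = 2 * sum(log, diag(a.chol.U))

x = [1.0 0.2; 0.2 1.0]
a = PDWrap(x)
det(a), det(x)        # both ≈ 0.96
logdet(a), logdet(x)  # both ≈ log(0.96)
```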

Probably logdet(::PDMat) does not error because Zygote defines an adjoint for logdet(::Cholesky) (this should be moved to ChainRules, I guess, together with a rule for det(::Cholesky)): https://github.com/FluxML/Zygote.jl/blob/a392eabdc0217f2f34d77ce19d4167c3cd4abbcf/src/lib/array.jl#L744-L748. However, as in the ForwardDiff example, the derivative seems to be wrong, since the gradient it returns is a triangular rather than a Hermitian matrix.
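
For comparison, consider a symmetric positive definite x and symmetric perturbations only. The gradient of logdet is then the symmetric matrix inv(x), since d/dt logdet(x + tV) at t = 0 equals tr(x⁻¹V). The stdlib-only sketch below checks that identity with a finite difference along one arbitrary symmetric direction V. For this x, inv(x) ≈ [1.04167 -0.208333; -0.208333 1.04167], the same values Zygote reports in this thread for a Symmetric input.

```julia
using LinearAlgebra

x = [1.0 0.2; 0.2 1.0]
V = [0.3 0.5; 0.5 -0.2]   # arbitrary symmetric direction

# d/dt logdet(x + t*V) at t = 0 equals tr(inv(x) * V) = dot(inv(x), V) for symmetric x.
h = 1e-6
fd = (logdet(x + h * V) - logdet(x - h * V)) / (2h)
expected = dot(inv(x), V)

G = inv(x)   # symmetric gradient of logdet at x
```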

I think the right approach would be to add a projection mechanism for PDMat and rrules for the constructor similar to the one for Hermitian and Symmetric matrices (https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2d75b4be102bb41ba3ac6df6dec8bb9617b20f0f/src/projection.jl#L425-L451 and https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/symmetric.jl#L5-L92).
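
The core idea of such a projection can be sketched in a few lines: keep only the symmetric part (G + G')/2 of a matrix cotangent. The helper `project_symmetric` below is hypothetical and is not the ChainRulesCore `ProjectTo` API. Applied to the triangular ForwardDiff result [1.0 -0.4; 0.0 1.0] for det from this thread, it recovers the symmetric gradient det(x) * inv(x).

```julia
using LinearAlgebra

# Hypothetical helper: Frobenius-orthogonal projection of a matrix cotangent
# onto the subspace of symmetric matrices (keep the symmetric part).
project_symmetric(G::AbstractMatrix) = Symmetric((G + G') / 2)

x = [1.0 0.2; 0.2 1.0]
G_tri = [1.0 -0.4; 0.0 1.0]   # ForwardDiff.gradient(det ∘ PDMat, x) from the thread
P = project_symmetric(G_tri)  # ≈ det(x) * inv(x)
```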

simsurace commented 2 years ago

> Interestingly ForwardDiff also returns the wrong thing if x is not a PDMat already:

I think this would be a great test to do for all the functions overloaded for PDMat. Basically, the wrapper should not change any derivatives wrt. x.
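
A finite-difference version of such a test can be sketched with the standard library alone. Here `cholesky(Symmetric(x))` is an assumed stand-in for the PDMat constructor, since a Cholesky factorization is what PDMat stores. The directions are restricted to symmetric V because the wrapped version only ever sees symmetric inputs. The names `dirderiv`, `wrap`, and `same_derivative` are hypothetical.

```julia
using LinearAlgebra

# Central finite difference of t -> g(x + t*V) at t = 0.
dirderiv(g, x, V; h=1e-6) = (g(x + h * V) - g(x - h * V)) / (2h)

# Stand-in for the PDMat constructor (assumption): PDMat stores cholesky(x).
wrap(x) = cholesky(Symmetric(x))

# Does wrapping change the derivative of f along the symmetric direction V?
same_derivative(f, x, V; atol=1e-6) =
    isapprox(dirderiv(f ∘ wrap, x, V), dirderiv(f, x, V); atol=atol)

x = [1.0 0.2; 0.2 1.0]
V = [0.3 0.5; 0.5 -0.2]   # symmetric direction
B = [2.0 1.0; 0.0 1.0]    # arbitrary right-hand side for the solve

funcs = [det, logdet, A -> tr(A \ B)]
results = [same_derivative(f, x, V) for f in funcs]
```

A real test suite would replace the finite differences with the AD backends under test.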

devmotion commented 2 years ago

> Regardless of AD, I think it would be useful to add definitions of det(::AbstractPDMat) = .... since otherwise these calls will fall back to the generic definitions based on the LU decomposition in LinearAlgebra: https://github.com/JuliaLang/julia/blob/6e061322438f13c6548200f115f3c31b20860a30/stdlib/LinearAlgebra/src/generic.jl#L1598-L1604

I opened #161.

devmotion commented 2 years ago

The ForwardDiff issue is not related to PDMats:

julia> using PDMats, ForwardDiff, LinearAlgebra

julia> x = [1. 0.2; 0.2 1.];

julia> ForwardDiff.gradient(det ∘ PDMat, x)
2×2 Matrix{Float64}:
 1.0  -0.4
 0.0   1.0

julia> ForwardDiff.gradient(det ∘ cholesky, x)
2×2 Matrix{Float64}:
 1.0  -0.4
 0.0   1.0

From the perspective of det etc., PDMat does not wrap x but cholesky(x), so the comparison with ForwardDiff.gradient(det, x) is not correct. Thus the incorrect derivative is caused by cholesky and actually expected.
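
The numbers above can be reproduced without ForwardDiff or PDMats. This assumes, consistently with the ForwardDiff output, that `cholesky` reads only the upper triangle of its argument, as `Symmetric` does by default. Differentiating entry by entry then treats x[1,2] as appearing twice and x[2,1] as unused, which yields the triangular gradient. The helpers `det_via_chol` and `entrywise_fd` are hypothetical.

```julia
using LinearAlgebra

x = [1.0 0.2; 0.2 1.0]

# Assumption (consistent with the ForwardDiff output above): cholesky only reads
# the upper triangle, which is what Symmetric's default uplo = :U does as well.
det_via_chol(A) = det(cholesky(Symmetric(A)))

# Hypothetical helper: central finite differences, one matrix entry at a time.
function entrywise_fd(g, A; h=1e-6)
    G = zero(A)
    for idx in eachindex(A)
        E = zero(A)
        E[idx] = h
        G[idx] = (g(A + E) - g(A - E)) / (2h)
    end
    return G
end

G_chol = entrywise_fd(det_via_chol, x)   # x[2,1] unused, x[1,2] counted twice
G_plain = entrywise_fd(det, x)           # every entry is an independent input
```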

simsurace commented 2 years ago

So cholesky is not defined for non-symmetric x, which means that a gradient step along this gradient can leave the subspace of symmetric matrices. This is an instance of the gradient being correct as a differential, i.e. for any choice of tangent vector the differential applied to that vector gives the correct result, but it is not a tangent vector itself. Should I open an issue in DiffRules.jl?
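
The distinction can be checked directly. Treat a gradient G as the differential V ↦ dot(G, V), using the Frobenius inner product. The triangular ForwardDiff result and the symmetric gradient det(x) * inv(x)' from Jacobi's formula then agree on every symmetric direction but not on non-symmetric ones. The stdlib-only sketch below uses one direction of each kind, both chosen arbitrarily.

```julia
using LinearAlgebra

x = [1.0 0.2; 0.2 1.0]
G_tri = [1.0 -0.4; 0.0 1.0]   # ForwardDiff.gradient(det ∘ PDMat, x) from the thread
G_full = det(x) * inv(x)'     # Jacobi's formula: gradient of det over all matrices

V = [0.3 0.5; 0.5 -0.2]       # symmetric direction
W = [0.0 1.0; 0.0 0.0]        # non-symmetric direction

# Same differential on the symmetric direction ...
on_symmetric = (dot(G_tri, V), dot(G_full, V))
# ... but different values on the non-symmetric one.
on_general = (dot(G_tri, W), dot(G_full, W))
```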

simsurace commented 2 years ago

Especially because using a Symmetric wrapper fails:

julia> ForwardDiff.gradient(logdet ∘ PDMat, Symmetric(x))
ERROR: ArgumentError: Cannot set a non-diagonal index in a symmetric matrix

whereas it works for Zygote:

julia> Zygote.gradient(logdet ∘ PDMat, Symmetric(x)) |> only
2×2 Symmetric{Float64, Matrix{Float64}}:
  1.04167   -0.208333
 -0.208333   1.04167

devmotion commented 2 years ago

Actually, I don't think there's anything wrong with the derivatives of ForwardDiff and cholesky: the forward-mode sensitivities of cholesky are correct (compare e.g. with https://arxiv.org/abs/1602.07527). It's just that ForwardDiff.gradient(det ∘ cholesky, x) is fundamentally different from ForwardDiff.gradient(det, x). The first one assumes that x AND the sensitivities/perturbations/... dx are symmetric matrices, whereas the second one makes neither assumption.
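
The standard forward-mode rule (as given, e.g., in the linked paper) applies to A = LL' and a symmetric perturbation dA. It reads dL = L Φ(L⁻¹ dA L⁻ᵀ), where Φ keeps the lower triangle and halves the diagonal. The symmetry of dA is exactly the assumption mentioned above. The stdlib-only sketch below checks the formula against finite differences. The names `Φ`, `chol_forward`, and `lower` are hypothetical helpers.

```julia
using LinearAlgebra

# Φ keeps the lower triangle of M and halves its diagonal.
Φ(M) = tril(M) - Diagonal(M) / 2

# Forward-mode sensitivity of the lower Cholesky factor L of A = L * L'
# along a symmetric perturbation dA: dL = L * Φ(L⁻¹ * dA * L⁻ᵀ).
function chol_forward(A, dA)
    L = cholesky(Symmetric(A)).L
    return L * Φ(L \ dA / L')
end

# Lower Cholesky factor as a plain Matrix, for the finite-difference check.
lower(M) = Matrix(cholesky(Symmetric(M)).L)

A = [1.0 0.2; 0.2 1.0]
dA = [0.3 0.5; 0.5 -0.2]   # symmetric, as the rule assumes

h = 1e-6
dL_fd = (lower(A + h * dA) - lower(A - h * dA)) / (2h)
dL = chol_forward(A, dA)
```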

devmotion commented 2 years ago

> Especially because using a Symmetric wrapper fails:

This is a general issue with ForwardDiff.seed!, not limited to Symmetric and not related to cholesky.
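
The error message itself comes from LinearAlgebra: `setindex!` on a `Symmetric` only accepts diagonal writes. Suppose ForwardDiff's seeding writes dual numbers entry by entry into an array with the same structure as the input. That is an assumption about its internals, not shown here. Then other wrappers with a restricted `setindex!` would presumably fail the same way. The stdlib-only demonstration below also shows that a plain `Matrix` copy accepts off-diagonal writes.

```julia
using LinearAlgebra

S = Symmetric([1.0 0.2; 0.2 1.0])

# Diagonal writes are allowed on a Symmetric ...
S[1, 1] = 2.0

# ... but off-diagonal writes throw, since they would break symmetry.
err = try
    S[1, 2] = 0.5
    nothing
catch e
    e
end

# A plain Matrix copy accepts any write.
M = Matrix(S)
M[1, 2] = 0.5
```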

simsurace commented 2 years ago

Yes, as I said I do believe that the differential is correct because it is only defined for a symmetric tangent. I think we can focus on getting the gradients working in Zygote.

devmotion commented 2 years ago

The only things to make it work with Zygote are (copied from above):

> Probably one should restrict rrule(::typeof(det), ...) in the same way as the one for logdet: https://github.com/JuliaDiff/ChainRules.jl/pull/245
>
> Zygote defines an adjoint for logdet(::Cholesky) (should be moved to ChainRules I guess together with a rule for det(::Cholesky)): https://github.com/FluxML/Zygote.jl/blob/a392eabdc0217f2f34d77ce19d4167c3cd4abbcf/src/lib/array.jl#L744-L748

I just checked it locally: with these changes, det works as well. I'll open a PR in ChainRules.