stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

Wishart Cholesky parameterization #2705

spinkney closed this issue 2 years ago

spinkney commented 2 years ago

We can add the Cholesky parameterization on both X and Sigma, or just on Sigma.

In C++ we'd use LDLt so we don't have to add_diag.

  real wishart_cholesky_a_lpdf(matrix X, real n, matrix L) {
    int P = rows(L);
    real lp = (n - P - 1) * sum(log(diagonal(cholesky_decompose(add_diag(X, 1e-7)))));

    // normalizing constant
    // lp -= 0.5 * (n * P) * log2() + lmgamma(P, 0.5 * n);
    lp -= 0.5 * sum(diagonal(mdivide_left_tri_low(L, mdivide_left_tri_low(L, X)')));

    return lp - n * sum(log(diagonal(L)));
  }

  real wishart_cholesky_b_lpdf(matrix Lx, real n, matrix L) {
    int P = rows(L);
    real lp = (n - P - 1) * sum(log(diagonal(Lx)));
    matrix[P, P] LinvLx = mdivide_left_tri_low(L, Lx);

    // normalizing constant
    // lp -= 0.5 * (n * P) * log2() + lmgamma(P, 0.5 * n);
    lp -= 0.5 * sum(diagonal(crossprod(LinvLx)));

    return lp - n * sum(log(diagonal(L)));
  }
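
For reference (as I read the code), both functions evaluate the same Wishart log density, minus the commented-out constant, written in terms of the Cholesky factors $X = L_X L_X^\top$ and $\Sigma = L L^\top$ (so $L_X$ is the code's Lx):

$$
\log p(X \mid n, \Sigma) = (n - P - 1)\sum_{i=1}^{P}\log L_{X,ii} \;-\; \tfrac{1}{2}\,\lVert L^{-1}L_X\rVert_F^2 \;-\; n\sum_{i=1}^{P}\log L_{ii} \;-\; \tfrac{nP}{2}\log 2 \;-\; \log\Gamma_P\!\left(\tfrac{n}{2}\right),
$$

using $\log\lvert X\rvert = 2\sum_i \log L_{X,ii}$, $\log\lvert\Sigma\rvert = 2\sum_i \log L_{ii}$, and $\operatorname{tr}(\Sigma^{-1}X) = \lVert L^{-1}L_X\rVert_F^2$. The a version recovers $L_X$ by decomposing X; the b version takes it directly.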

The results equal what the CholWishart R package gives:

> library(rethinking)   # rlkjcorr
> library(CholWishart)  # dWishart
> L <- t(chol(rlkjcorr(1, 8, eta = 0.1)))
> Lx <- t(chol(rlkjcorr(1, 8, eta = 0.1)))

> dWishart(x = tcrossprod(Lx), df = 9, Sigma = tcrossprod(L), log = TRUE)
[1] -443.0111
> wishart_cholesky_a_lpdf(X = tcrossprod(Lx), n = 9, L = L)
[1] -443.0111
> wishart_cholesky_b_lpdf(Lx = Lx, n = 9, L = L)
[1] -443.0111
bob-carpenter commented 2 years ago

Neat. That's really clean and should be OK to autodiff through in a direct C++ translation.

You can simplify even further by replacing sum(diagonal(...)) with trace(...), which should have better autodiff properties. Or better yet, can we simplify without a call to crossprod, given that we only need the diagonal of the result? The whole crossprod function introduces a lot of autodiff overhead.

I think we only need the "b" version. The important case is the one where the variate is a Cholesky factor, because that's the efficient data structure for covariance matrix parameters.

Ideally, we'd implement analytic derivatives through adjoint-Jacobian product updates. That is, let

lp = wishart_cholesky_b_lpdf(Lx, n, L)

and figure out how to update Lx.adjoint and L.adjoint given lp, Lx, L, and lp.adjoint. It requires an adjoint-Jacobian product, but given the output is univariate, that's just an adjoint-gradient product,

L.adjoint += lp.adjoint * d.lp / d.L
Lx.adjoint += lp.adjoint * d.lp / d.Lx
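
For concreteness, a minimal Julia sketch of that scalar-output reverse pass (Julia just to match the gradient checks further down; chain!, the *_adj arguments, and grad_L / grad_Lx are hypothetical stand-ins for the C++ vari's chain() method, the stored adjoints, and precomputed analytic gradients of lp):

function chain!(L_adj, Lx_adj, lp_adj, grad_L, grad_Lx)
    # adjoint-Jacobian product for a univariate output: scale each gradient
    # by the output adjoint and accumulate into the input adjoints
    L_adj  .+= lp_adj .* grad_L    # L.adjoint  += lp.adjoint * d.lp / d.L
    Lx_adj .+= lp_adj .* grad_Lx   # Lx.adjoint += lp.adjoint * d.lp / d.Lx
    return L_adj, Lx_adj
end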
spinkney commented 2 years ago

I'll see if I can work out those adjoints. As for the crossprod, we can get rid of it because that matrix is lower triangular. In C++ we can use a lower-triangular type for storage and then a lower-triangular multiply.

The updated version would be

  real wishart_cholesky_b_lpdf(matrix Lx, real n, matrix L) {
    int P = rows(L);
    real lp = (n - P - 1) * sum(log(diagonal(Lx)));
    matrix[P, P] LinvLx = mdivide_left_tri_low(L, Lx);

    // normalizing constant
    // lp -= 0.5 * (n * P) * log2() + lmgamma(P, 0.5 * n);
    lp -= 0.5 * trace(multiply_lower_tri_self_transpose(LinvLx));

    return lp - n * sum(log(diagonal(L)));
  }
bob-carpenter commented 2 years ago

> we can use a lower triangular type to store and then a lower triangular multiply.

Thanks @spinkney.

Given that we only need the diagonal elements, it'll be a lot more efficient in memory to do this

for (p in 1:P)
  lp -= 0.5 * dot_self(LinvLx[p, 1:p]);
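
That loop relies on $A = L^{-1}L_X$ being lower triangular, so

$$
\operatorname{tr}\!\left(AA^\top\right) = \sum_{p=1}^{P}\sum_{j=1}^{p} A_{pj}^{2},
$$

i.e. each row contributes only the dot product of its nonzero entries, and the full crossproduct never has to be formed.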
spinkney commented 2 years ago

@bob-carpenter the b version needs a change of variables if we want to use this directly on Lx.

real wishart_cholesky_b_lpdf(matrix Lx, real n, matrix L) {
  int P = rows(L);
  matrix[P, P] LinvLx = mdivide_left_tri_low(L, Lx);
  // normalizing constant
  real lp = P * log2() * (1 - 0.5 * n) - lmgamma(P, 0.5 * n);

  for (p in 1:P) {
    lp -= 0.5 * dot_self(LinvLx[p, 1:p]) + n * log(L[p, p]) - (n - p) * log(Lx[p, p]);
  }

  return lp;
}

That's from slide 36 of https://www.maxturgeon.ca/w20-stat7200/slides/wishart-distribution.pdf. I believe the normalizing constant in that presentation is wrong (it should have $|\Sigma|^{n/2}$).
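
For completeness, the change of variables that the (n - p) exponents come from: for $X = L_X L_X^\top$ with $L_X$ lower triangular and positive diagonal, the Jacobian determinant is

$$
\left|\frac{\partial X}{\partial L_X}\right| = 2^{P}\prod_{i=1}^{P} L_{X,ii}^{\,P+1-i},
$$

so the log density of $L_X$ picks up $P\log 2 + \sum_i (P+1-i)\log L_{X,ii}$ on top of the Wishart log density of $X$. Combining that with the $(n-P-1)\log L_{X,ii}$ terms gives the $(n-i)\log L_{X,ii}$ exponents in the loop, and the constant becomes $P\log 2\,(1 - n/2) - \log\Gamma_P(n/2)$.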

For my notes: derivatives with respect to the parameters (written in Julia).

using Distributions
using LinearAlgebra
using ForwardDiff
using SpecialFunctions
using StatsFuns
using BenchmarkTools
using LoopVectorization

# sum of squares (self dot product) of a vector, vectorized with LoopVectorization
function jselfdotavx(a)
    s = zero(eltype(a))
    @avx for i ∈ eachindex(a)
        s += a[i] * a[i]
    end
    s
end

function logdettriangle(B::Union{LowerTriangular,UpperTriangular})
    A = parent(B) # using a triangular matrix would fall back to the default loop.
    ld = zero(eltype(A))
    @avx for n ∈ axes(A,1)
        ld += log(A[n,n])
    end
    ld
end

wishart_cholesky_b_lpdf = function(Lx::LowerTriangular, n::Real, L::LowerTriangular)
    P::Int64 = size(L)[1]
    # lp = (n - P - 1) * sum(log.(diag(Lx)))
    LinvLx::LowerTriangular = L \ Lx

    # normalizing constant
    lp = P * logtwo * (1 - 0.5 * n) - logmvgamma(P, 0.5 * n)
    # lp -= n * logdettriangle(L)
    # lp += (n - P - 1) * logdettriangle(Lx)
    for p in 1:P
        lp -= 0.5 * jselfdotavx(LinvLx[p, 1:p]) + n * log(L[p, p]) - (n - p) * log(Lx[p, p])
    end

    return lp 
end

lkj = LKJCholesky(4, 1)
Lx = rand(lkj)
L = rand(lkj)

L = cholesky(Matrix(L)).L
Lx = cholesky(Matrix(Lx)).L
df = 9.0  # degrees of freedom; example value (any df > P - 1 works)

# d/dL
# helper: just the trace term of the lpdf, to check its gradient in isolation
dLx = function(L, Lx)
    LinvLx = L \ Lx
    return -0.5 * tr(LinvLx * LinvLx')
end

ForwardDiff.gradient(L -> dLx(L, Lx), L)

LinvLx = L \ Lx
Linvt = inv(L')
mlt_LinvLx = LinvLx * LinvLx'

# analytic gradient wrt L:
# tril(inv(L') * (LinvLx * LinvLx')), with -df / L[i, i] added on the diagonal
P = size(L)[1]
jL = zero(L)
for i ∈ 1:P
    jL[i, i] = -df / L[i, i]
    for j ∈ i:P
        jL[j, i] += sum(mlt_LinvLx[j:P, i] .* Linvt[j, j:P])
    end
end

ad_jL = ForwardDiff.gradient(L -> wishart_cholesky_b_lpdf(Lx, df, L), L)

isapprox(jL, ad_jL)

# d/dLx
ForwardDiff.gradient(Lx -> dLx(L, Lx), Lx)

LinvLx = L \ Lx
Linvt = inv(L')

# analytic gradient wrt Lx:
# -tril(inv(L') * LinvLx), with (df - i) / Lx[i, i] added on the diagonal
P = size(Lx)[1]
jLx = zero(Lx)
for i ∈ 1:P
    jLx[i, i] = (df - i) / Lx[i, i]
    for j ∈ i:P
        jLx[j, i] += -sum(LinvLx[j:P, i] .* Linvt[j, j:P])
    end
end

ad_jLx = ForwardDiff.gradient(Lx -> wishart_cholesky_b_lpdf(Lx, df, L), Lx)
isapprox(jLx, ad_jLx)

# d/dn

# derivative of logmvgamma(P, 0.5 * df) with respect to df
df_logmvg = 0
for p in 1:P
    df_logmvg += 0.5 * digamma(0.5 * df + (1 - p) * 0.5)
end
df_logmvg

ad_df_logmvg = ForwardDiff.derivative(df -> logmvgamma(P, 0.5 * df), df)
isapprox(df_logmvg, ad_df_logmvg)

ad_jn = ForwardDiff.derivative(df -> wishart_cholesky_b_lpdf(Lx, df, L), df)
jn = -sum(log.(diag(L))) + sum(log.(diag(Lx))) - 0.5 * P * logtwo - df_logmvg

isapprox(ad_jn, jn)
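
Collecting the checks above, with $A = L^{-1}L_X$ and $\operatorname{tril}(\cdot)$ the lower-triangular part, the analytic gradients appear to be

$$
\frac{\partial\,\mathrm{lp}}{\partial L} = \operatorname{tril}\!\left(L^{-\top}AA^\top\right) - n\,\operatorname{diag}\!\left(1/L_{ii}\right),
\qquad
\frac{\partial\,\mathrm{lp}}{\partial L_X} = -\operatorname{tril}\!\left(L^{-\top}A\right) + \operatorname{diag}\!\left((n-i)/L_{X,ii}\right),
$$

$$
\frac{\partial\,\mathrm{lp}}{\partial n} = \sum_i \log L_{X,ii} - \sum_i \log L_{ii} - \tfrac{P}{2}\log 2 - \tfrac{1}{2}\sum_{j=1}^{P}\psi\!\left(\tfrac{n+1-j}{2}\right),
$$

which would be the adjoint-gradient products needed for the C++ reverse pass.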