Neat. That's really clean and should be OK to autodiff through in a direct C++ translation. You can simplify even further by replacing sum(diagonal(...)) with trace(...), which should have better autodiff properties. Or better yet, can we simplify without a call to crossprod, given that we only need the diagonal of the result? The whole crossprod function introduces a lot of autodiff overhead.
I think we only need the "b" version. The important case is where the variate is a Cholesky factor, because that's the efficient data structure for covariance matrix parameters.
Ideally, we'd implement analytic derivatives through adjoint-Jacobian product updates. That is, let
lp = wishart_cholesky_b_lpdf(Lx, n, L)
and figure out how to update Lx.adjoint and L.adjoint given lp, Lx, L, and lp.adjoint. It requires an adjoint-Jacobian product, but given that the output is univariate, that's just an adjoint-gradient product:
L.adjoint += lp.adjoint * d.lp / d.L
Lx.adjoint += lp.adjoint * d.lp / d.Lx
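For a scalar lp, the reverse pass is just the gradient scaled by the output adjoint. A minimal sketch of that update in Julia, assuming the analytic gradients dlp_dL and dlp_dLx (hypothetical names) were precomputed in the forward pass:

function chain!(L_adj, Lx_adj, lp_adj, dlp_dL, dlp_dLx)
    # scalar output, so the adjoint-Jacobian product collapses to a scaled gradient
    L_adj  .+= lp_adj .* dlp_dL    # L.adjoint  += lp.adjoint * d lp / d L
    Lx_adj .+= lp_adj .* dlp_dLx   # Lx.adjoint += lp.adjoint * d lp / d Lx
    return nothing
end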
I'll see if I can work out those adjoints. As for the crossprod, we can get rid of it because that matrix is lower triangular. In C++ we can use a lower triangular type for storage and then a lower triangular multiply.
The updated version would be
real wishart_cholesky_b_lpdf(matrix Lx, real n, matrix L) {
  int P = rows(L);
  real lp = (n - P - 1) * sum(log(diagonal(Lx)));
  matrix[P, P] LinvLx = mdivide_left_tri_low(L, Lx);
  // normalizing constant
  // lp -= 0.5 * (n * P) * log2() + lmgamma(P, 0.5 * n);
  lp -= 0.5 * trace(multiply_lower_tri_self_transpose(LinvLx));
  return lp - n * sum(log(diagonal(L)));
}
Thanks @spinkney.
Given that we only need the diagonal elements, it'll be a lot more efficient in memory to do this:
for (p in 1:P)
  lp -= 0.5 * dot_self(LinvLx[p, 1:p]);
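To see numerically why the full crossproduct isn't needed: trace(A * A') is just the sum of squared entries of A, and for lower-triangular A that's the sum of the row-wise self dot products. A throwaway Julia check of that identity (illustration only, not part of the proposed Stan code):

using LinearAlgebra
A = LowerTriangular(randn(4, 4))
# trace of the crossproduct equals the sum of squared entries, row by row
isapprox(tr(A * A'), sum(p -> dot(A[p, 1:p], A[p, 1:p]), 1:4))  # true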
@bob-carpenter the b version needs a change of variables if we want to use this directly on Lx.
real wishart_cholesky_b_lpdf(matrix Lx, real n, matrix L) {
  int P = rows(L);
  matrix[P, P] LinvLx = mdivide_left_tri_low(L, Lx);
  // normalizing constant, including the Jacobian of X = Lx * Lx'
  real lp = P * log2() * (1 - 0.5 * n) - lmgamma(P, 0.5 * n);
  for (p in 1:P) {
    lp -= 0.5 * dot_self(LinvLx[p, 1:p]) + n * log(L[p, p]) - (n - p) * log(Lx[p, p]);
  }
  return lp;
}
That's from slide 36 of https://www.maxturgeon.ca/w20-stat7200/slides/wishart-distribution.pdf. I believe the normalizing constant in that presentation is wrong (it should have |\Sigma|^{n/2}).
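For reference, writing out the density the code above targets (this is just the standard Wishart log density in X combined with the Jacobian of X = Lx Lx', which is 2^P prod_p Lx[p,p]^{P - p + 1}):

$$
\log p(L_x \mid n, L)
= -\tfrac{1}{2}\,\lVert L^{-1} L_x \rVert_F^2
+ \sum_{p=1}^{P} \Bigl[(n - p)\log L_{x,pp} - n\,\log L_{pp}\Bigr]
+ P \log 2 \,\Bigl(1 - \tfrac{n}{2}\Bigr)
- \log \Gamma_P\!\Bigl(\tfrac{n}{2}\Bigr),
$$

where Sigma = L L', X = Lx Lx', and Gamma_P is the multivariate gamma function. The Frobenius-norm term is what the row-wise dot_self loop computes, and collecting the log-2 terms gives the constant P * log2() * (1 - 0.5 * n) - lmgamma(P, 0.5 * n).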
For my notes: derivatives with respect to the parameters (written in Julia).
using Distributions
using LinearAlgebra
using ForwardDiff
using SpecialFunctions
using StatsFuns
using BenchmarkTools
using LoopVectorization

# sum of squares (self dot product) of a vector, vectorized with LoopVectorization
function jselfdotavx(a)
    s = zero(eltype(a))
    @avx for i ∈ eachindex(a)
        s += a[i] * a[i]
    end
    s
end
# log determinant of a triangular matrix as the sum of the log diagonal entries
function logdettriangle(B::Union{LowerTriangular,UpperTriangular})
    A = parent(B)  # using a triangular matrix would fall back to the default loop
    ld = zero(eltype(A))
    @avx for n ∈ axes(A, 1)
        ld += log(A[n, n])
    end
    ld
end
function wishart_cholesky_b_lpdf(Lx::LowerTriangular, n::Real, L::LowerTriangular)
    P = size(L, 1)
    # lp = (n - P - 1) * sum(log.(diag(Lx)))
    LinvLx = L \ Lx
    # normalizing constant, including the Jacobian of X = Lx * Lx'
    lp = P * logtwo * (1 - 0.5 * n) - logmvgamma(P, 0.5 * n)
    # lp -= n * logdettriangle(L)
    # lp += (n - P - 1) * logdettriangle(Lx)
    for p in 1:P
        lp -= 0.5 * jselfdotavx(LinvLx[p, 1:p]) + n * log(L[p, p]) - (n - p) * log(Lx[p, p])
    end
    return lp
end
# draw random correlation matrices via LKJ and take their lower Cholesky factors
lkj = LKJCholesky(4, 1)
Lx = rand(lkj)
L = rand(lkj)
L = cholesky(Matrix(L)).L
Lx = cholesky(Matrix(Lx)).L
df = 7.0  # degrees of freedom for the checks below; example value, any df > P - 1 = 3 works
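One way to sanity-check the density itself: under the change of variables X = Lx * Lx', the Cholesky-parameterized log density should equal Distributions' Wishart logpdf at X plus the log Jacobian P * log(2) + sum_p (P + 1 - p) * log(Lx[p, p]). A quick check along those lines (a sketch, using the example df above):

P = size(L, 1)
logjac = P * logtwo + sum((P + 1 - p) * log(Lx[p, p]) for p in 1:P)
lp_ref = logpdf(Wishart(df, Matrix(L * L')), Matrix(Lx * Lx')) + logjac
isapprox(wishart_cholesky_b_lpdf(Lx, df, L), lp_ref)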
# d/dL
# quadratic term -0.5 * tr((L \ Lx) * (L \ Lx)') only, used to check the gradient pieces below
dLx = function(L, Lx)
    LinvLx = L \ Lx
    return -0.5 * tr(LinvLx * LinvLx')
end
ForwardDiff.gradient(L -> dLx(L, Lx), L)
# analytic d/dL
LinvLx = L \ Lx
Linvt = inv(L')
mlt_LinvLx = LinvLx * LinvLx'
# the loop calculates tril(inv(L') * (LinvLx * LinvLx')),
# plus the -df / L[i, i] diagonal term from -df * sum(log(diag(L)))
P = size(L)[1]
jL = zero(L)
for i ∈ 1:P
    jL[i, i] = -df / L[i, i]
    for j ∈ i:P
        jL[j, i] += sum(mlt_LinvLx[j:P, i] .* Linvt[j, j:P])
    end
end
ad_jL = ForwardDiff.gradient(L -> wishart_cholesky_b_lpdf(Lx, df, L), L)
isapprox(jL, ad_jL)
# d/dLx
ForwardDiff.gradient(Lx -> dLx(L, Lx), Lx)
# analytic d/dLx: -tril(inv(L') * LinvLx), plus (df - i) / Lx[i, i] on the diagonal
LinvLx = L \ Lx
Linvt = inv(L')
P = size(Lx)[1]
jLx = zero(Lx)
for i ∈ 1:P
    jLx[i, i] = (df - i) / Lx[i, i]
    for j ∈ i:P
        jLx[j, i] += -sum(LinvLx[j:P, i] .* Linvt[j, j:P])
    end
end
ad_jLx = ForwardDiff.gradient(Lx -> wishart_cholesky_b_lpdf(Lx, df, L), Lx)
isapprox(jLx, ad_jLx)
# d/dn
# derivative of logmvgamma(P, 0.5 * df) with respect to df
df_logmvg = sum(0.5 * digamma(0.5 * df + (1 - p) * 0.5) for p in 1:P)
ad_df_logmvg = ForwardDiff.derivative(df -> logmvgamma(P, 0.5 * df), df)
isapprox(df_logmvg, ad_df_logmvg)
ad_jn = ForwardDiff.derivative(df -> wishart_cholesky_b_lpdf(Lx, df, L), df)
jn = -sum(log.(diag(L))) + sum(log.(diag(Lx))) - 0.5 * P * logtwo - df_logmvg
isapprox(ad_jn, jn)
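For the eventual adjoint implementation, reading the loops above back into matrix form (with A = L^{-1} L_x and the matrix gradients restricted to the lower triangle, diagonal included) gives:

$$
\begin{aligned}
\frac{\partial \log p}{\partial L} &= \mathrm{tril}\!\left(L^{-\top} A A^{\top}\right) - n\,\mathrm{diag}\!\left(1/L_{11}, \ldots, 1/L_{PP}\right), \\
\frac{\partial \log p}{\partial L_x} &= \mathrm{diag}\!\left(\frac{n - 1}{L_{x,11}}, \ldots, \frac{n - P}{L_{x,PP}}\right) - \mathrm{tril}\!\left(L^{-\top} A\right), \\
\frac{\partial \log p}{\partial n} &= \sum_{p=1}^{P} \log L_{x,pp} - \sum_{p=1}^{P} \log L_{pp} - \frac{P}{2}\log 2 - \frac{1}{2}\sum_{p=1}^{P} \psi\!\left(\frac{n + 1 - p}{2}\right),
\end{aligned}
$$

where psi is the digamma function. These are exactly the jL, jLx, and jn computed and checked against ForwardDiff above.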
We can add the Cholesky parameterization on both X and Sigma, or just on Sigma. In C++ we'd use LDLt so we don't have to add_diag. The results are equal to what the CholWishart R package has.