JuliaStats / Distributions.jl

A Julia package for probability distributions and associated functions.

MvNormal constructor unnecessarily recomputes cholesky every time with ForwardDiff #1781

Closed: marius311 closed this issue 9 months ago.

marius311 commented 9 months ago

In a typical Turing model, even if you pre-convert an MvNormal covariance to a PDMat, it still gets Cholesky-factorized on every gradient call when using ForwardDiff:

```julia
using Turing, Distributions, ForwardDiff, LinearAlgebra
Turing.setadbackend(:forwarddiff)

# Pirate LinearAlgebra.cholesky just to see every call to it
@eval LinearAlgebra function cholesky(A::AbstractMatrix, ::NoPivot=NoPivot(); check::Bool = true)
    println("here")
    cholesky!(cholcopy(A); check)
end

@model function foo(Σ, d)
    μ ~ filldist(Uniform(), 2)
    d ~ MvNormal(μ, Σ)
end

Σ = Distributions.PDMat([1 0; 0 2])  # covariance pre-converted to a PDMat
d = rand(2)
model = foo(Σ, d)

ForwardDiff.gradient(μ -> logjoint(model, (;μ)), rand(2))  # "here" is printed on every call
```

IMO the problem is that the MvNormal constructor is a little too greedy about promoting its arguments when their eltypes don't match, and in the process it recomputes the Cholesky factorization (in this case the mean contains Duals but the covariance contains Floats).

A workaround is to call the MvNormal{T,Cov,Mean}(...) constructor directly, but this seems like an easy-to-hit and potentially large performance footgun, even for users who were smart enough to construct the PDMat manually. It would be nice to fix (in some package; my sense is here, but maybe elsewhere).
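The mechanism described above can be reproduced in a small sketch that needs neither Turing nor ForwardDiff. `ToyPDMat`, `ToyMvNormal`, `naive_convert` and the `NFACTOR` counter are hypothetical names, not the real PDMats or Distributions types, and `BigFloat` is assumed to stand in for `ForwardDiff.Dual` as a mean eltype that differs from the covariance's `Float64`. Under these assumptions, the greedy promoting constructor refactorizes the covariance, while the explicitly parameterized constructor keeps it as is.

```julia
using LinearAlgebra

# Hypothetical stand-ins for PDMats.PDMat and Distributions.MvNormal (not the real types).
const NFACTOR = Ref(0)  # counts Cholesky factorizations, for illustration only

struct ToyPDMat{T<:Real,S<:AbstractMatrix{T}}
    mat::S
    chol::Cholesky{T,S}
end

# Building from a bare matrix factorizes it (and bumps the counter).
function ToyPDMat(mat::AbstractMatrix{<:Real})
    NFACTOR[] += 1
    return ToyPDMat(mat, cholesky(mat))
end

# Naive eltype conversion: no-op if the eltype already matches,
# otherwise convert the matrix and factorize it again.
naive_convert(::Type{R}, Σ::ToyPDMat{R}) where {R<:Real} = Σ
naive_convert(::Type{R}, Σ::ToyPDMat) where {R<:Real} =
    ToyPDMat(convert(AbstractMatrix{R}, Σ.mat))

struct ToyMvNormal{T<:Real,Cov,Mean<:AbstractVector}
    μ::Mean
    Σ::Cov
end

# "Greedy" outer constructor: promote both arguments to a common eltype.
function ToyMvNormal(μ::AbstractVector{<:Real}, Σ::ToyPDMat)
    R = promote_type(eltype(μ), eltype(Σ.mat))
    m = convert(AbstractVector{R}, μ)
    S = naive_convert(R, Σ)
    return ToyMvNormal{R,typeof(S),typeof(m)}(m, S)
end

Σ = ToyPDMat([1.0 0.0; 0.0 2.0])  # factorized once, up front
μ = BigFloat[0.5, 0.5]            # BigFloat stands in for ForwardDiff.Dual

NFACTOR[] = 0
d_greedy = ToyMvNormal(μ, Σ)      # promotes Σ to BigFloat and refactorizes it
greedy_count = NFACTOR[]

NFACTOR[] = 0
d_explicit = ToyMvNormal{eltype(μ),typeof(Σ),typeof(μ)}(μ, Σ)  # stores Σ unchanged
explicit_count = NFACTOR[]
```

With these definitions `greedy_count` is 1 and `explicit_count` is 0. The price of the explicit form is that the mean and covariance keep different element types inside one struct, which every method operating on it then has to handle.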

devmotion commented 9 months ago

Of course, allocations would be minimized if MvNormal just stored whatever types the user passes in. But I assume it would then be much more likely to run into type instability issues, and you would have to be much more careful when implementing methods that operate on MvNormal.

Regardless, the unnecessary Cholesky decompositions are not caused by inefficiencies or bugs in Distributions but by inefficient and missing convert definitions in PDMats. https://github.com/JuliaStats/PDMats.jl/pull/179 fixes your example.
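The kind of `convert` fix described above can be illustrated with a hypothetical toy type. This is not the real `PDMat` and not the actual contents of PDMats.jl PR #179. The sketch changes the element type by converting the cached Cholesky factor elementwise instead of rebuilding from the converted matrix, which avoids a second factorization. `ToyPDMat`, `slow_convert`, `fast_convert` and `NFACTOR` are illustrative names, and `BigFloat` is assumed to stand in for `ForwardDiff.Dual`.

```julia
using LinearAlgebra

# Hypothetical stand-in for PDMats.PDMat (not the real type or the PR's code).
const NFACTOR = Ref(0)  # counts Cholesky factorizations, for illustration only

struct ToyPDMat{T<:Real,S<:AbstractMatrix{T}}
    mat::S
    chol::Cholesky{T,S}
end

# Building from a bare matrix factorizes it (and bumps the counter).
function ToyPDMat(mat::AbstractMatrix{<:Real})
    NFACTOR[] += 1
    return ToyPDMat(mat, cholesky(mat))
end

# Slow conversion: convert the matrix, then factorize it again.
slow_convert(::Type{T}, a::ToyPDMat) where {T<:Real} =
    ToyPDMat(convert(AbstractMatrix{T}, a.mat))

# Fast conversion: convert the cached factor elementwise instead of refactorizing.
function fast_convert(::Type{T}, a::ToyPDMat) where {T<:Real}
    mat = convert(AbstractMatrix{T}, a.mat)
    factors = convert(AbstractMatrix{T}, a.chol.factors)
    return ToyPDMat(mat, Cholesky(factors, a.chol.uplo, a.chol.info))
end

Σ = ToyPDMat([1.0 0.0; 0.0 2.0])  # one factorization up front

NFACTOR[] = 0
Σ_slow = slow_convert(BigFloat, Σ)
slow_count = NFACTOR[]

NFACTOR[] = 0
Σ_fast = fast_convert(BigFloat, Σ)
fast_count = NFACTOR[]
```

Here `slow_count` is 1 and `fast_count` is 0, and the two upper factors agree up to `Float64` rounding. Converting the factor is reasonable in the ForwardDiff case because turning a constant `Float64` into a `Dual` only attaches zero partials, so nothing about the factorization changes.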