JuliaStats / Distributions.jl

A Julia package for probability distributions and associated functions.

Inconsistent performance between logpdf and logpdf! for MvNormal. #1775

Open Sahel13 opened 1 year ago

Sahel13 commented 1 year ago

Problem

Consider the following minimal example.

using LinearAlgebra, Random, Distributions, BenchmarkTools, Test

samples = randn(5, 1000)
dist = MvNormal(zeros(5), I)

function mod_logpdf(dist, samples)
    out = Array{Float64}(undef, size(samples, 2))
    logpdf!(out, dist, samples)
end
julia> @test logpdf(dist, samples) ≈ mod_logpdf(dist, samples)
Test Passed
julia> @btime logpdf($dist, $samples);
  89.290 μs (1001 allocations: 101.69 KiB)
julia> @btime mod_logpdf($dist, $samples);
  12.672 μs (3 allocations: 47.05 KiB)

logpdf is around 7x slower than mod_logpdf, even though both compute the same result.

Possible solution

Add a method

logpdf(d::AbstractMvNormal, x::AbstractMatrix{<:Real})

that does something like mod_logpdf.
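A minimal sketch of what such a method might look like, written here as a separate helper so that it does not overwrite anything in the package. The name `matrix_logpdf` is hypothetical, and the `Float64` output element type is an assumption carried over from `mod_logpdf` above; only `logpdf!`, `MvNormal` and `length(d)` are existing Distributions.jl functions.

```julia
using Distributions, LinearAlgebra

# Hypothetical helper, not part of Distributions.jl: evaluates the log-density
# of every column of `x` through the in-place matrix method of `logpdf!`.
function matrix_logpdf(d::AbstractMvNormal, x::AbstractMatrix{<:Real})
    size(x, 1) == length(d) ||
        throw(DimensionMismatch("x has $(size(x, 1)) rows, but the distribution has dimension $(length(d))"))
    out = Vector{Float64}(undef, size(x, 2))  # assumption: Float64 output
    logpdf!(out, d, x)
    return out
end

d = MvNormal(zeros(5), I)
x = randn(5, 1000)
matrix_logpdf(d, x) ≈ logpdf(d, x)  # both paths should agree
```

Because this helper allocates a mutable output and writes into it, it assumes plain dense arrays and may not suit every input type.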

ParadaCarleton commented 1 year ago

That's extremely weird. Do you know what's causing the performance difference?

Sahel13 commented 1 year ago

It's interesting to note that this difference does not exist if samples is a Vector{Vector} instead of a matrix.

using LinearAlgebra, Random, Distributions, BenchmarkTools, Test

samples = [randn(5) for _ in 1:1000]
dist = MvNormal(zeros(5), I)

function mod_logpdf(dist, samples)
    out = Array{Float64}(undef, length(samples))
    logpdf!(out, dist, samples)
end
julia> @btime logpdf($dist, $samples);
  89.132 μs (1001 allocations: 101.69 KiB)
julia> @btime mod_logpdf($dist, $samples);
  88.147 μs (1001 allocations: 101.69 KiB)

So solely for the case where x is a matrix, logpdf! is unusually fast.

logpdf!(..., x::AbstractMatrix{<:Real}) calls this method in mvnormal.jl:

function _logpdf!(r::AbstractArray{<:Real}, d::AbstractMvNormal, x::AbstractMatrix{<:Real})
    sqmahal!(r, d, x)
    c0 = mvnormal_c0(d)
    for i = 1:size(x, 2)
        @inbounds r[i] = c0 - r[i]/2
    end
    r
end

I'm guessing this function is the cause of this discrepancy, although why this is faster than the other methods, I do not know.

simsurace commented 1 year ago

Could you run a profiler on this? You should then see where it is spending the additional time.

Sahel13 commented 1 year ago

For logpdf

julia> Profile.clear()

julia> @profile (for _ in 1:1000; logpdf(dist, samples); end)

julia> Profile.print()
Overhead ╎ [+additional indent] Count File:Line; Function
=========================================================
  ╎86 @Base/task.jl:514; (::VSCodeServer.var"#62#63")()
  ╎ 86 @VSCodeServer/src/eval.jl:34; macro expansion
  ╎  86 @Base/essentials.jl:816; invokelatest(::Any)
  ╎   86 @Base/essentials.jl:819; #invokelatest#2
  ╎    86 @VSCodeServer/src/repl.jl:193; (::VSCodeServer.var"#109#111"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎     86 @Base/logging.jl:626; with_logger
  ╎    ╎ 86 @Base/logging.jl:514; with_logstate(f::Function, logstate::Any)
  ╎    ╎  86 @VSCodeServer/src/repl.jl:192; (::VSCodeServer.var"#110#112"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎    ╎   86 @VSCodeServer/src/repl.jl:229; repleval(m::Module, code::Expr, #unused#::String)
  ╎    ╎    86 @Base/Base.jl:68; eval
  ╎    ╎     86 @Base/boot.jl:370; eval
  ╎    ╎    ╎ 86 ...ot-9/usr/share/julia/stdlib/v1.9/Profile/src/Profile.jl:27; top-level scope
  ╎    ╎    ╎  86 REPL[20]:1; macro expansion
  ╎    ╎    ╎   86 @Distributions/src/common.jl:319; logpdf(d::IsoNormal, x::Matrix{Float64})
  ╎    ╎    ╎    86 @Base/abstractarray.jl:3263; map
  ╎    ╎    ╎     86 @Base/array.jl:711; collect_similar
  ╎    ╎    ╎    ╎ 86 @Base/array.jl:812; _collect(c::Distributions.EachVariate{1, Matrix{Float64}, Tuple{Base.OneTo{Int64}...
  ╎    ╎    ╎    ╎  86 @Base/array.jl:818; collect_to_with_first!
  ╎    ╎    ╎    ╎   86 @Base/array.jl:840; collect_to!(dest::Vector{Float64}, itr::Base.Generator{Distributions.EachVariate...
  ╎    ╎    ╎    ╎    1  @Base/generator.jl:44; iterate
  ╎    ╎    ╎    ╎     1  @Base/abstractarray.jl:1220; iterate
  ╎    ╎    ╎    ╎    ╎ 1  @Base/range.jl:891; iterate
 1╎    ╎    ╎    ╎    ╎  1  @Base/promotion.jl:499; ==
  ╎    ╎    ╎    ╎    85 @Base/generator.jl:47; iterate
 1╎    ╎    ╎    ╎     85 @Base/operators.jl:1108; (::Base.Fix1{typeof(logpdf), IsoNormal})(y::SubArray{Float64, 1, Matrix{Float6...
  ╎    ╎    ╎    ╎    ╎ 84 @Distributions/src/common.jl:250; logpdf
  ╎    ╎    ╎    ╎    ╎  84 @Distributions/src/multivariate/mvnormal.jl:143; _logpdf
 1╎    ╎    ╎    ╎    ╎   1  @Base/float.jl:409; -
  ╎    ╎    ╎    ╎    ╎   9  @Distributions/src/multivariate/mvnormal.jl:101; mvnormal_c0
  ╎    ╎    ╎    ╎    ╎    9  @Distributions/src/multivariate/mvnormal.jl:263; logdetcov
  ╎    ╎    ╎    ╎    ╎     9  @PDMats/src/scalmat.jl:65; logdet
  ╎    ╎    ╎    ╎    ╎    ╎ 9  @Base/special/log.jl:267; log
 1╎    ╎    ╎    ╎    ╎    ╎  1  @Base/special/log.jl:0; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
 1╎    ╎    ╎    ╎    ╎    ╎  1  @Base/special/log.jl:270; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
 1╎    ╎    ╎    ╎    ╎    ╎  1  @Base/special/log.jl:275; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
  ╎    ╎    ╎    ╎    ╎    ╎  6  @Base/special/log.jl:277; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
  ╎    ╎    ╎    ╎    ╎    ╎   2  @Base/special/log.jl:196; log_proc2
  ╎    ╎    ╎    ╎    ╎    ╎    2  @Base/operators.jl:578; *
 2╎    ╎    ╎    ╎    ╎    ╎     2  @Base/float.jl:410; *
  ╎    ╎    ╎    ╎    ╎    ╎   4  @Base/special/log.jl:215; log_proc2
  ╎    ╎    ╎    ╎    ╎    ╎    4  @Base/floatfuncs.jl:426; fma
 4╎    ╎    ╎    ╎    ╎    ╎     4  @Base/floatfuncs.jl:421; fma_llvm
  ╎    ╎    ╎    ╎    ╎   3  @Distributions/src/multivariate/mvnormal.jl:102; mvnormal_c0
 2╎    ╎    ╎    ╎    ╎    2  @Base/float.jl:408; +
  ╎    ╎    ╎    ╎    ╎    1  @Base/promotion.jl:413; /
 1╎    ╎    ╎    ╎    ╎     1  @Base/float.jl:411; /
  ╎    ╎    ╎    ╎    ╎   71 @Distributions/src/multivariate/mvnormal.jl:267; sqmahal(d::IsoNormal, x::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Sli...
 1╎    ╎    ╎    ╎    ╎    1  @Base/Base.jl:37; getproperty
  ╎    ╎    ╎    ╎    ╎    60 @Base/broadcast.jl:873; materialize
  ╎    ╎    ╎    ╎    ╎     58 @Base/broadcast.jl:898; copy
  ╎    ╎    ╎    ╎    ╎    ╎ 6  @Base/broadcast.jl:926; copyto!
  ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:970; copyto!
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/broadcast.jl:953; preprocess
  ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/broadcast.jl:956; preprocess_args
  ╎    ╎    ╎    ╎    ╎    ╎     1  @Base/broadcast.jl:957; preprocess_args
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 1  @Base/broadcast.jl:954; preprocess
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:947; broadcast_unalias
  ╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/abstractarray.jl:1482; unalias
  ╎    ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/abstractarray.jl:1517; mightalias
  ╎    ╎    ╎    ╎    ╎    ╎    ╎     1  @Base/abstractarray.jl:1541; dataids
  ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ 1  @Base/abstractarray.jl:1242; pointer
 1╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/pointer.jl:65; unsafe_convert
  ╎    ╎    ╎    ╎    ╎    ╎  5  @Base/broadcast.jl:973; copyto!
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/simdloop.jl:72; macro expansion
 1╎    ╎    ╎    ╎    ╎    ╎    1  @Base/int.jl:83; <
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/simdloop.jl:76; macro expansion
  ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/simdloop.jl:54; simd_index
 1╎    ╎    ╎    ╎    ╎    ╎     1  @Base/int.jl:87; +
  ╎    ╎    ╎    ╎    ╎    ╎   3  @Base/simdloop.jl:77; macro expansion
  ╎    ╎    ╎    ╎    ╎    ╎    3  @Base/broadcast.jl:974; macro expansion
 2╎    ╎    ╎    ╎    ╎    ╎     2  @Base/array.jl:969; setindex!
  ╎    ╎    ╎    ╎    ╎    ╎     1  @Base/broadcast.jl:610; getindex
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 1  @Base/broadcast.jl:656; _broadcast_getindex
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:683; _broadcast_getindex_evalf
 1╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/float.jl:409; -
  ╎    ╎    ╎    ╎    ╎    ╎ 52 @Base/broadcast.jl:211; similar
  ╎    ╎    ╎    ╎    ╎    ╎  52 @Base/broadcast.jl:212; similar
  ╎    ╎    ╎    ╎    ╎    ╎   52 @Base/abstractarray.jl:883; similar
  ╎    ╎    ╎    ╎    ╎    ╎    52 @Base/abstractarray.jl:884; similar
  ╎    ╎    ╎    ╎    ╎    ╎     52 @Base/boot.jl:494; Array
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 52 @Base/boot.jl:486; Array
51╎    ╎    ╎    ╎    ╎    ╎    ╎  52 @Base/boot.jl:477; Array
  ╎    ╎    ╎    ╎    ╎     2  @Base/broadcast.jl:294; instantiate
  ╎    ╎    ╎    ╎    ╎    ╎ 2  @Base/broadcast.jl:512; combine_axes
  ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/abstractarray.jl:98; axes
 1╎    ╎    ╎    ╎    ╎    ╎   1  @Base/array.jl:149; size
  ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:517; broadcast_shape
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/broadcast.jl:523; _bcs
 1╎    ╎    ╎    ╎    ╎    ╎    1  @Base/broadcast.jl:529; _bcs1
  ╎    ╎    ╎    ╎    ╎    10 @PDMats/src/scalmat.jl:87; invquad
 3╎    ╎    ╎    ╎    ╎     3  @Base/float.jl:411; /
  ╎    ╎    ╎    ╎    ╎     7  @Base/reducedim.jl:995; sum
  ╎    ╎    ╎    ╎    ╎    ╎ 7  @Base/reducedim.jl:995; #sum#808
  ╎    ╎    ╎    ╎    ╎    ╎  7  @Base/reducedim.jl:999; _sum
  ╎    ╎    ╎    ╎    ╎    ╎   7  @Base/reducedim.jl:999; #_sum#810
  ╎    ╎    ╎    ╎    ╎    ╎    7  @Base/reducedim.jl:357; mapreduce
  ╎    ╎    ╎    ╎    ╎    ╎     7  @Base/reducedim.jl:357; #mapreduce#800
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 7  @Base/reducedim.jl:365; _mapreduce_dim
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/reduce.jl:433; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
 1╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/essentials.jl:13; getindex
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  3  @Base/reduce.jl:435; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
  ╎    ╎    ╎    ╎    ╎    ╎    ╎   2  @Base/number.jl:189; abs2
 2╎    ╎    ╎    ╎    ╎    ╎    ╎    2  @Base/float.jl:410; *
  ╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/reduce.jl:27; add_sum
 1╎    ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/float.jl:408; +
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  3  @Base/reduce.jl:436; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
 3╎    ╎    ╎    ╎    ╎    ╎    ╎   3  @Base/int.jl:83; <
Total snapshots: 89. Utilization: 100% across all threads and tasks. Use the `groupby` kwarg to break down by thread and/or task.

For mod_logpdf

julia> Profile.clear()

julia> @profile (for _ in 1:1000; mod_logpdf(dist, samples); end)

julia> Profile.print()
Overhead ╎ [+additional indent] Count File:Line; Function
=========================================================
  ╎32 @Base/task.jl:514; (::VSCodeServer.var"#62#63")()
  ╎ 32 @VSCodeServer/src/eval.jl:34; macro expansion
  ╎  32 @Base/essentials.jl:816; invokelatest(::Any)
  ╎   32 @Base/essentials.jl:819; #invokelatest#2
  ╎    32 @VSCodeServer/src/repl.jl:193; (::VSCodeServer.var"#109#111"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎     32 @Base/logging.jl:626; with_logger
  ╎    ╎ 32 @Base/logging.jl:514; with_logstate(f::Function, logstate::Any)
  ╎    ╎  32 @VSCodeServer/src/repl.jl:192; (::VSCodeServer.var"#110#112"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎    ╎   32 @VSCodeServer/src/repl.jl:229; repleval(m::Module, code::Expr, #unused#::String)
  ╎    ╎    32 @Base/Base.jl:68; eval
  ╎    ╎     32 @Base/boot.jl:370; eval
  ╎    ╎    ╎ 32 ...ot-9/usr/share/julia/stdlib/v1.9/Profile/src/Profile.jl:27; top-level scope
 1╎    ╎    ╎  32 REPL[23]:1; macro expansion
  ╎    ╎    ╎   2  /Users/sahel/Code/InsideOutSMC.jl/experiments/testing.jl:7; mod_logpdf(dist::IsoNormal, samples::Matrix{Float64})
  ╎    ╎    ╎    2  @Base/boot.jl:491; Array
 2╎    ╎    ╎     2  @Base/boot.jl:477; Array
  ╎    ╎    ╎   29 /Users/sahel/Code/InsideOutSMC.jl/experiments/testing.jl:8; mod_logpdf(dist::IsoNormal, samples::Matrix{Float64})
  ╎    ╎    ╎    29 @Distributions/src/common.jl:424; logpdf!
  ╎    ╎    ╎     29 @Distributions/src/multivariate/mvnormal.jl:146; _logpdf!(r::Vector{Float64}, d::IsoNormal, x::Matrix{Float64})
  ╎    ╎    ╎    ╎ 29 @Distributions/src/multivariate/mvnormal.jl:269; sqmahal!(r::Vector{Float64}, d::IsoNormal, x::Matrix{Float64})
  ╎    ╎    ╎    ╎  28 @Base/broadcast.jl:873; materialize
  ╎    ╎    ╎    ╎   28 @Base/broadcast.jl:898; copy
  ╎    ╎    ╎    ╎    20 @Base/broadcast.jl:926; copyto!
  ╎    ╎    ╎    ╎     20 @Base/broadcast.jl:973; copyto!
  ╎    ╎    ╎    ╎    ╎ 18 @Base/simdloop.jl:77; macro expansion
  ╎    ╎    ╎    ╎    ╎  18 @Base/broadcast.jl:974; macro expansion
  ╎    ╎    ╎    ╎    ╎   18 @Base/multidimensional.jl:670; setindex!
17╎    ╎    ╎    ╎    ╎    18 @Base/array.jl:971; setindex!
  ╎    ╎    ╎    ╎    ╎ 2  @Base/simdloop.jl:78; macro expansion
 2╎    ╎    ╎    ╎    ╎  2  @Base/int.jl:87; +
  ╎    ╎    ╎    ╎    8  @Base/broadcast.jl:211; similar
  ╎    ╎    ╎    ╎     8  @Base/broadcast.jl:212; similar
  ╎    ╎    ╎    ╎    ╎ 8  @Base/abstractarray.jl:883; similar
  ╎    ╎    ╎    ╎    ╎  8  @Base/abstractarray.jl:884; similar
  ╎    ╎    ╎    ╎    ╎   8  @Base/boot.jl:494; Array
  ╎    ╎    ╎    ╎    ╎    8  @Base/boot.jl:487; Array
 8╎    ╎    ╎    ╎    ╎     8  @Base/boot.jl:479; Array
  ╎    ╎    ╎    ╎  1  @PDMats/src/scalmat.jl:90; invquad!
  ╎    ╎    ╎    ╎   1  @PDMats/src/utils.jl:103; colwise_sumsqinv!(r::Vector{Float64}, a::Matrix{Float64}, c::Float64)
 1╎    ╎    ╎    ╎    1  @Base/range.jl:891; iterate
Total snapshots: 32. Utilization: 100% across all threads and tasks. Use the `groupby` kwarg to break down by thread and/or task.

There seem to be many more backtraces at materialize in logpdf, but I do not know how to work out why.
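Tree printouts like the ones above can be hard to scan. One way to condense them, sketched here with only the Profile standard library, is the flat format sorted by sample count. The function `colwise_sq` is a hypothetical stand-in workload, not code from Distributions.jl, and the `mincount` threshold of 5 is an arbitrary choice.

```julia
using Profile

# Hypothetical stand-in workload: one temporary array per column.
colwise_sq(μ, x) = [sum(abs2, view(x, :, i) .- μ) for i in axes(x, 2)]

μ = zeros(5)
x = randn(5, 1000)
colwise_sq(μ, x)  # run once so that compilation is not profiled

Profile.clear()
@profile for _ in 1:1000
    colwise_sq(μ, x)
end

# Flat listing sorted by sample count; hide entries with fewer than 5 samples.
Profile.print(format = :flat, sortedby = :count, mincount = 5)
```

In the flat listing, lines that allocate arrays (such as `Array` in boot.jl) tend to stand out when allocation dominates the run time.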

ParadaCarleton commented 1 year ago

Hmm, do you have a flamegraph?

Sahel13 commented 1 year ago

I'm assuming the generated images are what you asked for (I'm currently working on my first project in Julia, so there's a lot to learn).

[Image: logpdf]

[Image: mod_logpdf]

ParadaCarleton commented 1 year ago

Ahh, I was suggesting you might want to look at the flamegraphs to see which lines specifically are the ones slowing down logpdf, sorry for not being clear about that. :sweat_smile:

Where is it spending most of its time?

Sahel13 commented 1 year ago

Sorry, my bad XD. This is the profile view plot for logpdf:

[Image: profile_view_logpdf]

Most of the time seems to be taken up by sqmahal. Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)? I can think of the Cholesky factorization of the covariance matrix having to be computed only once in the latter case, for example.

devmotion commented 1 year ago

> I can think of the Cholesky factorization of the covariance matrix having to be computed only once in the latter case, for example.

The factorization is only computed once upfront when you construct an MvNormal object.

ParadaCarleton commented 1 year ago

> Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)?

Yep, that would be it. It's creating way more arrays. Could you make a PR to fix this?
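The effect of "creating way more arrays" can be illustrated without Distributions.jl. The sketch below uses two hypothetical helpers, `percolumn_sqdist` and `batched_sqdist`, that compute the same squared distances from a mean vector: the first materializes one temporary per column, the second a single temporary for the whole matrix. It is only an analogy for the two code paths in the profiles above, not the package's implementation.

```julia
# Squared distance of each column of x from μ, one column at a time.
# Each `view(x, :, i) .- μ` materializes a fresh temporary vector.
percolumn_sqdist(μ, x) = [sum(abs2, view(x, :, i) .- μ) for i in axes(x, 2)]

# Same quantity for all columns at once: a single temporary matrix.
batched_sqdist(μ, x) = vec(sum(abs2, x .- μ; dims = 1))

function compare_allocations()
    μ = zeros(5)
    x = randn(5, 1000)
    percolumn_sqdist(μ, x); batched_sqdist(μ, x)  # warm up (compile first)
    a = @allocated percolumn_sqdist(μ, x)
    b = @allocated batched_sqdist(μ, x)
    return a, b
end

a, b = compare_allocations()
println("per-column: $a bytes, batched: $b bytes")
```

The exact byte counts depend on the Julia version, so none are quoted here.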

Sahel13 commented 1 year ago

Yes, I can. But just a question: is it problematic if logpdf calls the mutating version underneath? I don't know the design principles behind this package, but if we want logpdf to play well with autodiff, for example, we wouldn't want it to perform any in-place operations.

devmotion commented 1 year ago

No, ideally we would not mix both paths, also e.g. for better compatibility with static arrays, even though many methods probably don't currently work (in an optimized way) with static arrays.

devmotion commented 1 year ago

Another reason is that generally it is quite challenging and brittle when one starts to come up with heuristics for the output type.

Sahel13 commented 1 year ago

> No, ideally we would not mix both paths, also e.g. for better compatibility with static arrays, even though many methods probably don't currently work (in an optimized way) with static arrays.

Sorry, I don't understand whether you meant that it's not a problem for logpdf to call the mutating version, or that it's better if it doesn't. If you would prefer a completely non-mutating version, then I do not know how to write a faster implementation.

devmotion commented 1 year ago

I meant that generally logpdf should be non-mutating, and in particular it should not make any assumptions about the type of the arrays it is called with, e.g. whether they are mutable or not.

Sahel13 commented 1 year ago

Ok, thanks for the clarification. Then I'm afraid I don't know a solution to this.

ParadaCarleton commented 1 year ago

> Most of the time seems to be taken up by sqmahal. Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)?

Quick question, is sqmahal the main difference in time spent between logpdf and logpdf!? (You can benchmark both to see which lines make up most of the difference.) If it is, I think it should be possible to correct this by just doing all the calculations at once.
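As a rough illustration of "doing all the calculations at once" without mutation, here is a sketch restricted to the isotropic case N(μ, σ²I), which covers the `MvNormal(zeros(5), I)` example from the top of the thread. The function name `iso_logpdf_columns` and its argument layout are hypothetical; this is not how Distributions.jl implements `logpdf`, and general covariance matrices would need more work.

```julia
# Non-mutating, batched log-density of N(μ, σ2 * I) for each column of x.
# Hypothetical helper for illustration only.
function iso_logpdf_columns(μ::AbstractVector{<:Real}, σ2::Real, x::AbstractMatrix{<:Real})
    d = length(μ)
    size(x, 1) == d || throw(DimensionMismatch("x must have $d rows"))
    # Normalizing constant: -(d*log(2π) + logdet(σ2*I)) / 2, with logdet = d*log(σ2).
    c0 = -(d * (log(2π) + log(σ2))) / 2
    # Squared Mahalanobis distance of every column, using one temporary matrix.
    sq = vec(sum(abs2, x .- μ; dims = 1)) ./ σ2
    return c0 .- sq ./ 2
end

# Works on an immutable, range-backed matrix too, since nothing is written in place.
x = reshape(1.0:6.0, 2, 3)
iso_logpdf_columns(zeros(2), 1.0, x)
```

Since every step returns a new array instead of writing into an existing one, a version along these lines makes no assumption about whether the input is mutable.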