Open Sahel13 opened 1 year ago
That's extremely weird. Do you know what's causing the performance difference?
It's interesting to note that this difference does not exist if samples is a Vector{Vector} instead of a matrix.
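using LinearAlgebra, Random, Distributions, BenchmarkTools, Test
samples = [randn(5) for _ in 1:1000]  # Vector{Vector{Float64}}: 1000 samples of dimension 5
dist = MvNormal(zeros(5), I)
function mod_logpdf(dist, samples)
    # Preallocate one output slot per sample, then fill it in place.
    out = Array{Float64}(undef, length(samples))
    logpdf!(out, dist, samples)
end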
julia> @btime logpdf($dist, $samples);
89.132 μs (1001 allocations: 101.69 KiB)
julia> @btime mod_logpdf($dist, $samples);
88.147 μs (1001 allocations: 101.69 KiB)
So solely for the case where x is a matrix, logpdf! is unusually fast.
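For comparison, here is a minimal sketch of the matrix case, assuming each column of the matrix is one sample. The name mod_logpdf_matrix is hypothetical and only used for illustration; it sizes the output by the number of columns rather than by length(samples).

```julia
using LinearAlgebra, Distributions

dist = MvNormal(zeros(5), I)
samples_mat = randn(5, 1000)   # each column is one 5-dimensional sample

# Hypothetical matrix variant of `mod_logpdf`: one output entry per column.
function mod_logpdf_matrix(dist, samples::AbstractMatrix)
    out = Vector{Float64}(undef, size(samples, 2))
    logpdf!(out, dist, samples)
    return out
end

a = logpdf(dist, samples_mat)             # non-mutating path
b = mod_logpdf_matrix(dist, samples_mat)  # preallocate, then fill in place
```

Both calls should return the same values, so timing them with @btime isolates the cost of the two code paths.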
logpdf!(..., x::AbstractMatrix{<:Real}) calls a method in mvnormal.jl:
function _logpdf!(r::AbstractArray{<:Real}, d::AbstractMvNormal, x::AbstractMatrix{<:Real})
    sqmahal!(r, d, x)
    c0 = mvnormal_c0(d)
    for i = 1:size(x, 2)
        @inbounds r[i] = c0 - r[i]/2
    end
    r
end
I'm guessing this function is the cause of this discrepancy, although why this is faster than the other methods, I do not know.
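For reference, the loop in that method appears to evaluate the usual multivariate normal log-density for each column x_i. This is written under the assumption that sqmahal! fills r with squared Mahalanobis distances and that mvnormal_c0 returns the normalising constant, with d the dimension:

```latex
\log p(x_i) = c_0 - \tfrac{1}{2}\,(x_i - \mu)^\top \Sigma^{-1} (x_i - \mu),
\qquad
c_0 = -\tfrac{1}{2}\left( d \log 2\pi + \log\det\Sigma \right)
```

The constant c_0 depends only on the distribution, so it is computed once in the batched method and reused for every column.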
Could you run a profiler on this? You should then see where it is spending the additional time.
logpdf
julia> Profile.clear()
julia> @profile (for _ in 1:1000; logpdf(dist, samples); end)
julia> Profile.print()
Overhead ╎ [+additional indent] Count File:Line; Function
=========================================================
╎86 @Base/task.jl:514; (::VSCodeServer.var"#62#63")()
╎ 86 @VSCodeServer/src/eval.jl:34; macro expansion
╎ 86 @Base/essentials.jl:816; invokelatest(::Any)
╎ 86 @Base/essentials.jl:819; #invokelatest#2
╎ 86 @VSCodeServer/src/repl.jl:193; (::VSCodeServer.var"#109#111"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
╎ 86 @Base/logging.jl:626; with_logger
╎ ╎ 86 @Base/logging.jl:514; with_logstate(f::Function, logstate::Any)
╎ ╎ 86 @VSCodeServer/src/repl.jl:192; (::VSCodeServer.var"#110#112"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
╎ ╎ 86 @VSCodeServer/src/repl.jl:229; repleval(m::Module, code::Expr, #unused#::String)
╎ ╎ 86 @Base/Base.jl:68; eval
╎ ╎ 86 @Base/boot.jl:370; eval
╎ ╎ ╎ 86 ...ot-9/usr/share/julia/stdlib/v1.9/Profile/src/Profile.jl:27; top-level scope
╎ ╎ ╎ 86 REPL[20]:1; macro expansion
╎ ╎ ╎ 86 @Distributions/src/common.jl:319; logpdf(d::IsoNormal, x::Matrix{Float64})
╎ ╎ ╎ 86 @Base/abstractarray.jl:3263; map
╎ ╎ ╎ 86 @Base/array.jl:711; collect_similar
╎ ╎ ╎ ╎ 86 @Base/array.jl:812; _collect(c::Distributions.EachVariate{1, Matrix{Float64}, Tuple{Base.OneTo{Int64}...
╎ ╎ ╎ ╎ 86 @Base/array.jl:818; collect_to_with_first!
╎ ╎ ╎ ╎ 86 @Base/array.jl:840; collect_to!(dest::Vector{Float64}, itr::Base.Generator{Distributions.EachVariate...
╎ ╎ ╎ ╎ 1 @Base/generator.jl:44; iterate
╎ ╎ ╎ ╎ 1 @Base/abstractarray.jl:1220; iterate
╎ ╎ ╎ ╎ ╎ 1 @Base/range.jl:891; iterate
1╎ ╎ ╎ ╎ ╎ 1 @Base/promotion.jl:499; ==
╎ ╎ ╎ ╎ 85 @Base/generator.jl:47; iterate
1╎ ╎ ╎ ╎ 85 @Base/operators.jl:1108; (::Base.Fix1{typeof(logpdf), IsoNormal})(y::SubArray{Float64, 1, Matrix{Float6...
╎ ╎ ╎ ╎ ╎ 84 @Distributions/src/common.jl:250; logpdf
╎ ╎ ╎ ╎ ╎ 84 @Distributions/src/multivariate/mvnormal.jl:143; _logpdf
1╎ ╎ ╎ ╎ ╎ 1 @Base/float.jl:409; -
╎ ╎ ╎ ╎ ╎ 9 @Distributions/src/multivariate/mvnormal.jl:101; mvnormal_c0
╎ ╎ ╎ ╎ ╎ 9 @Distributions/src/multivariate/mvnormal.jl:263; logdetcov
╎ ╎ ╎ ╎ ╎ 9 @PDMats/src/scalmat.jl:65; logdet
╎ ╎ ╎ ╎ ╎ ╎ 9 @Base/special/log.jl:267; log
1╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/special/log.jl:0; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
1╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/special/log.jl:270; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
1╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/special/log.jl:275; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
╎ ╎ ╎ ╎ ╎ ╎ 6 @Base/special/log.jl:277; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
╎ ╎ ╎ ╎ ╎ ╎ 2 @Base/special/log.jl:196; log_proc2
╎ ╎ ╎ ╎ ╎ ╎ 2 @Base/operators.jl:578; *
2╎ ╎ ╎ ╎ ╎ ╎ 2 @Base/float.jl:410; *
╎ ╎ ╎ ╎ ╎ ╎ 4 @Base/special/log.jl:215; log_proc2
╎ ╎ ╎ ╎ ╎ ╎ 4 @Base/floatfuncs.jl:426; fma
4╎ ╎ ╎ ╎ ╎ ╎ 4 @Base/floatfuncs.jl:421; fma_llvm
╎ ╎ ╎ ╎ ╎ 3 @Distributions/src/multivariate/mvnormal.jl:102; mvnormal_c0
2╎ ╎ ╎ ╎ ╎ 2 @Base/float.jl:408; +
╎ ╎ ╎ ╎ ╎ 1 @Base/promotion.jl:413; /
1╎ ╎ ╎ ╎ ╎ 1 @Base/float.jl:411; /
╎ ╎ ╎ ╎ ╎ 71 @Distributions/src/multivariate/mvnormal.jl:267; sqmahal(d::IsoNormal, x::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Sli...
1╎ ╎ ╎ ╎ ╎ 1 @Base/Base.jl:37; getproperty
╎ ╎ ╎ ╎ ╎ 60 @Base/broadcast.jl:873; materialize
╎ ╎ ╎ ╎ ╎ 58 @Base/broadcast.jl:898; copy
╎ ╎ ╎ ╎ ╎ ╎ 6 @Base/broadcast.jl:926; copyto!
╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:970; copyto!
╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:953; preprocess
╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:956; preprocess_args
╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:957; preprocess_args
╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:954; preprocess
╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:947; broadcast_unalias
╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/abstractarray.jl:1482; unalias
╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/abstractarray.jl:1517; mightalias
╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/abstractarray.jl:1541; dataids
╎ ╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/abstractarray.jl:1242; pointer
1╎ ╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/pointer.jl:65; unsafe_convert
╎ ╎ ╎ ╎ ╎ ╎ 5 @Base/broadcast.jl:973; copyto!
╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/simdloop.jl:72; macro expansion
1╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/int.jl:83; <
╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/simdloop.jl:76; macro expansion
╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/simdloop.jl:54; simd_index
1╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/int.jl:87; +
╎ ╎ ╎ ╎ ╎ ╎ 3 @Base/simdloop.jl:77; macro expansion
╎ ╎ ╎ ╎ ╎ ╎ 3 @Base/broadcast.jl:974; macro expansion
2╎ ╎ ╎ ╎ ╎ ╎ 2 @Base/array.jl:969; setindex!
╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:610; getindex
╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:656; _broadcast_getindex
╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:683; _broadcast_getindex_evalf
1╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/float.jl:409; -
╎ ╎ ╎ ╎ ╎ ╎ 52 @Base/broadcast.jl:211; similar
╎ ╎ ╎ ╎ ╎ ╎ 52 @Base/broadcast.jl:212; similar
╎ ╎ ╎ ╎ ╎ ╎ 52 @Base/abstractarray.jl:883; similar
╎ ╎ ╎ ╎ ╎ ╎ 52 @Base/abstractarray.jl:884; similar
╎ ╎ ╎ ╎ ╎ ╎ 52 @Base/boot.jl:494; Array
╎ ╎ ╎ ╎ ╎ ╎ ╎ 52 @Base/boot.jl:486; Array
51╎ ╎ ╎ ╎ ╎ ╎ ╎ 52 @Base/boot.jl:477; Array
╎ ╎ ╎ ╎ ╎ 2 @Base/broadcast.jl:294; instantiate
╎ ╎ ╎ ╎ ╎ ╎ 2 @Base/broadcast.jl:512; combine_axes
╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/abstractarray.jl:98; axes
1╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/array.jl:149; size
╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:517; broadcast_shape
╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:523; _bcs
1╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/broadcast.jl:529; _bcs1
╎ ╎ ╎ ╎ ╎ 10 @PDMats/src/scalmat.jl:87; invquad
3╎ ╎ ╎ ╎ ╎ 3 @Base/float.jl:411; /
╎ ╎ ╎ ╎ ╎ 7 @Base/reducedim.jl:995; sum
╎ ╎ ╎ ╎ ╎ ╎ 7 @Base/reducedim.jl:995; #sum#808
╎ ╎ ╎ ╎ ╎ ╎ 7 @Base/reducedim.jl:999; _sum
╎ ╎ ╎ ╎ ╎ ╎ 7 @Base/reducedim.jl:999; #_sum#810
╎ ╎ ╎ ╎ ╎ ╎ 7 @Base/reducedim.jl:357; mapreduce
╎ ╎ ╎ ╎ ╎ ╎ 7 @Base/reducedim.jl:357; #mapreduce#800
╎ ╎ ╎ ╎ ╎ ╎ ╎ 7 @Base/reducedim.jl:365; _mapreduce_dim
╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/reduce.jl:433; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
1╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/essentials.jl:13; getindex
╎ ╎ ╎ ╎ ╎ ╎ ╎ 3 @Base/reduce.jl:435; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
╎ ╎ ╎ ╎ ╎ ╎ ╎ 2 @Base/number.jl:189; abs2
2╎ ╎ ╎ ╎ ╎ ╎ ╎ 2 @Base/float.jl:410; *
╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/reduce.jl:27; add_sum
1╎ ╎ ╎ ╎ ╎ ╎ ╎ 1 @Base/float.jl:408; +
╎ ╎ ╎ ╎ ╎ ╎ ╎ 3 @Base/reduce.jl:436; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
3╎ ╎ ╎ ╎ ╎ ╎ ╎ 3 @Base/int.jl:83; <
Total snapshots: 89. Utilization: 100% across all threads and tasks. Use the `groupby` kwarg to break down by thread and/or task.
mod_logpdf
julia> Profile.clear()
julia> @profile (for _ in 1:1000; mod_logpdf(dist, samples); end)
julia> Profile.print()
Overhead ╎ [+additional indent] Count File:Line; Function
=========================================================
╎32 @Base/task.jl:514; (::VSCodeServer.var"#62#63")()
╎ 32 @VSCodeServer/src/eval.jl:34; macro expansion
╎ 32 @Base/essentials.jl:816; invokelatest(::Any)
╎ 32 @Base/essentials.jl:819; #invokelatest#2
╎ 32 @VSCodeServer/src/repl.jl:193; (::VSCodeServer.var"#109#111"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
╎ 32 @Base/logging.jl:626; with_logger
╎ ╎ 32 @Base/logging.jl:514; with_logstate(f::Function, logstate::Any)
╎ ╎ 32 @VSCodeServer/src/repl.jl:192; (::VSCodeServer.var"#110#112"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
╎ ╎ 32 @VSCodeServer/src/repl.jl:229; repleval(m::Module, code::Expr, #unused#::String)
╎ ╎ 32 @Base/Base.jl:68; eval
╎ ╎ 32 @Base/boot.jl:370; eval
╎ ╎ ╎ 32 ...ot-9/usr/share/julia/stdlib/v1.9/Profile/src/Profile.jl:27; top-level scope
1╎ ╎ ╎ 32 REPL[23]:1; macro expansion
╎ ╎ ╎ 2 /Users/sahel/Code/InsideOutSMC.jl/experiments/testing.jl:7; mod_logpdf(dist::IsoNormal, samples::Matrix{Float64})
╎ ╎ ╎ 2 @Base/boot.jl:491; Array
2╎ ╎ ╎ 2 @Base/boot.jl:477; Array
╎ ╎ ╎ 29 /Users/sahel/Code/InsideOutSMC.jl/experiments/testing.jl:8; mod_logpdf(dist::IsoNormal, samples::Matrix{Float64})
╎ ╎ ╎ 29 @Distributions/src/common.jl:424; logpdf!
╎ ╎ ╎ 29 @Distributions/src/multivariate/mvnormal.jl:146; _logpdf!(r::Vector{Float64}, d::IsoNormal, x::Matrix{Float64})
╎ ╎ ╎ ╎ 29 @Distributions/src/multivariate/mvnormal.jl:269; sqmahal!(r::Vector{Float64}, d::IsoNormal, x::Matrix{Float64})
╎ ╎ ╎ ╎ 28 @Base/broadcast.jl:873; materialize
╎ ╎ ╎ ╎ 28 @Base/broadcast.jl:898; copy
╎ ╎ ╎ ╎ 20 @Base/broadcast.jl:926; copyto!
╎ ╎ ╎ ╎ 20 @Base/broadcast.jl:973; copyto!
╎ ╎ ╎ ╎ ╎ 18 @Base/simdloop.jl:77; macro expansion
╎ ╎ ╎ ╎ ╎ 18 @Base/broadcast.jl:974; macro expansion
╎ ╎ ╎ ╎ ╎ 18 @Base/multidimensional.jl:670; setindex!
17╎ ╎ ╎ ╎ ╎ 18 @Base/array.jl:971; setindex!
╎ ╎ ╎ ╎ ╎ 2 @Base/simdloop.jl:78; macro expansion
2╎ ╎ ╎ ╎ ╎ 2 @Base/int.jl:87; +
╎ ╎ ╎ ╎ 8 @Base/broadcast.jl:211; similar
╎ ╎ ╎ ╎ 8 @Base/broadcast.jl:212; similar
╎ ╎ ╎ ╎ ╎ 8 @Base/abstractarray.jl:883; similar
╎ ╎ ╎ ╎ ╎ 8 @Base/abstractarray.jl:884; similar
╎ ╎ ╎ ╎ ╎ 8 @Base/boot.jl:494; Array
╎ ╎ ╎ ╎ ╎ 8 @Base/boot.jl:487; Array
8╎ ╎ ╎ ╎ ╎ 8 @Base/boot.jl:479; Array
╎ ╎ ╎ ╎ 1 @PDMats/src/scalmat.jl:90; invquad!
╎ ╎ ╎ ╎ 1 @PDMats/src/utils.jl:103; colwise_sumsqinv!(r::Vector{Float64}, a::Matrix{Float64}, c::Float64)
1╎ ╎ ╎ ╎ 1 @Base/range.jl:891; iterate
Total snapshots: 32. Utilization: 100% across all threads and tasks. Use the `groupby` kwarg to break down by thread and/or task.
There seem to be a lot more backtraces at materialize in logpdf, but I do not know how to infer why.
Hmm, do you have a flamegraph?
I'm assuming the generated images are what you asked for (I'm currently working on my first project in Julia, so there's a lot to learn).
Ahh, I was suggesting you might want to look at the flamegraphs to see which lines specifically are the ones slowing down logpdf, sorry for not being clear about that. :sweat_smile:
Where is it spending most of its time?
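If a flamegraph is not at hand, a flat listing from the Profile standard library can answer the same question. The sketch below uses a hypothetical workload function as a stand-in for the per-column logpdf path; it is not the Distributions.jl code.

```julia
using Profile

# Hypothetical workload: squared distance to μ, one column at a time.
# The broadcast `x .- μ` allocates a temporary vector for every column.
function workload(X, μ)
    acc = 0.0
    for _ in 1:2000, x in eachcol(X)
        acc += sum(abs2, x .- μ)
    end
    return acc
end

X = randn(5, 1000)
μ = zeros(5)
workload(X, μ)   # warm up so compilation is not profiled

Profile.clear()
@profile workload(X, μ)
# Flat listing sorted by sample count; `mincount` hides rarely hit lines.
Profile.print(format = :flat, sortedby = :count, mincount = 5)
```

Lines with the highest counts are where most samples landed, which is the same information a flamegraph shows as bar width.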
Sorry, my bad XD. This is the profile view plot for logpdf:
Most of the time seems to be taken up by sqmahal. Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)? I can think of the Cholesky factorization of the covariance matrix having to be computed only once in the latter case, for example.
> I can think of the Cholesky factorization of the covariance matrix having to be computed only once in the latter case, for example.
The factorization is only computed once upfront when you construct an MvNormal object.
> Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)?
Yep, that would be it. It's creating way more arrays. Could you make a PR to fix this?
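To make the "way more arrays" point concrete, here is a small sketch for the isotropic case. Both functions are hypothetical stand-ins written for illustration, not the Distributions.jl or PDMats.jl implementations: the first allocates one temporary per column, the second allocates one temporary for the whole matrix.

```julia
# Per-column path: `x .- μ` allocates a fresh vector for every column.
function sqmahal_percolumn(μ::Vector{Float64}, σ2::Float64, X::Matrix{Float64})
    return map(x -> sum(abs2, x .- μ) / σ2, eachcol(X))
end

# Batched path: a single temporary `X .- μ`, then a column-wise reduction.
function sqmahal_batched(μ::Vector{Float64}, σ2::Float64, X::Matrix{Float64})
    return vec(sum(abs2, X .- μ; dims = 1)) ./ σ2
end

μ = zeros(5)
σ2 = 1.0
X = randn(5, 1000)

# Warm up both functions so compilation is not counted below.
sqmahal_percolumn(μ, σ2, X)
sqmahal_batched(μ, σ2, X)

bytes_percolumn = @allocated sqmahal_percolumn(μ, σ2, X)
bytes_batched = @allocated sqmahal_batched(μ, σ2, X)
println("per-column: ", bytes_percolumn, " bytes, batched: ", bytes_batched, " bytes")
```

The exact byte counts depend on the Julia version, so none are quoted here; the comparison between the two numbers is what matters.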
Yes I can. But just a question, is it problematic if logpdf calls the mutating version underneath? I don't know your design principles behind this package, but if we want logpdf to play well with autodiff, for example, we wouldn't want it to perform any in-place operations.
No, ideally we would not mix both paths, also eg for better compatibility with static arrays. Even though probably currently many methods don't work (in an optimized way) with static arrays.
Another reason is that generally it is quite challenging and brittle when one starts to come up with heuristics for the output type.
> No, ideally we would not mix both paths, also eg for better compatibility with static arrays. Even though probably currently many methods don't work (in an optimized way) with static arrays.
Sorry, I don't understand whether you meant it's not a problem for logpdf to call the mutating version, or whether it's better it doesn't. If it is the case that you would prefer a completely non-mutating version, then I do not know how to write a faster implementation.
I meant that generally logpdf should be non-mutating, and in particular it should not make any assumptions about the type of the arrays it is called with and e.g. whether they are mutable or not.
Ok, thanks for the clarification. Then I'm afraid I don't know a solution to this.
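One possible direction, sketched only for the isotropic case N(μ, σ2*I) and not taken from the package, is to express the batched computation with broadcasting and a reduction, so that the function itself never writes into an array explicitly. The names iso_logpdf_batched and iso_logpdf_percolumn are hypothetical, and whether this approach suits static arrays or a given autodiff backend is an open assumption.

```julia
# Hypothetical helper: log-density of N(μ, σ2*I) for every column of X,
# written without explicit in-place writes.
function iso_logpdf_batched(μ::AbstractVector, σ2::Real, X::AbstractMatrix)
    d = length(μ)
    c0 = -(d * log(2π) + d * log(σ2)) / 2          # normalising constant
    sq = vec(sum(abs2, X .- μ; dims = 1)) ./ σ2    # squared Mahalanobis distances
    return c0 .- sq ./ 2
end

# Reference: one column at a time, mirroring the slower per-column path.
iso_logpdf_percolumn(μ, σ2, X) =
    [-(length(μ) * log(2π * σ2) + sum(abs2, x .- μ) / σ2) / 2 for x in eachcol(X)]

μ = randn(5)
X = randn(5, 100)
batched = iso_logpdf_batched(μ, 2.0, X)
reference = iso_logpdf_percolumn(μ, 2.0, X)
```

The batched version allocates a few whole-matrix or whole-vector temporaries instead of one small vector per column.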
> Most of the time seems to be taken up by sqmahal. Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)?
Quick question, is sqmahal the main difference in time spent between logpdf and logpdf!? (You can benchmark both to see which lines make up most of the difference.) If it is, I think it should be possible to correct this by just doing all the calculations at once.
Problem
Consider the following minimal example.
logpdf is around 7x slower than mod_logpdf, even though they both do exactly the same thing.
Possible solution
Add a method that does something like mod_logpdf.