JuliaLang / LinearAlgebra.jl

Julia Linear Algebra standard library

The documentation for the logdet function is a bit misleading #1017

Closed bvdmitri closed 1 year ago

bvdmitri commented 1 year ago

The documentation for the logdet function says that the function may provide increased accuracy and/or speed. In many situations, however, it is significantly slower than the equivalent log(det(...)), especially when working with a plain Matrix. I understand that the point of logdet is increased accuracy, and that is fine, but increased speed is almost never one of its properties; quite the opposite. Consider:

julia> S = Matrix(Diagonal(ones(10_000)));

julia> @btime logdet($S)
  3.829 s (4 allocations: 763.02 MiB)

julia> @btime log(det($S))
  51.544 ms (2 allocations: 78.17 KiB)

That is true for small matrices as well:

julia> @btime logdet(S) setup=(S=[1.0 0.0; 0.0 1.0])
  199.736 ns (2 allocations: 176 bytes)

julia> @btime log(det(S)) setup=(S=[1.0 0.0; 0.0 1.0])
  69.389 ns (1 allocation: 80 bytes)

This pattern repeats for some (not for all, but I also didn't check all of them) of the specialised matrices as well:

julia> @btime logdet(S) setup=(S=Diagonal(ones(1000)))
  4.856 μs (0 allocations: 0 bytes)

julia> @btime log(det(S)) setup=(S=Diagonal(ones(1000)))
  76.674 ns (0 allocations: 0 bytes)

The current documentation reads as if logdet is at least as accurate and fast as log(det(...)), and may even be more accurate and/or faster. In reality it is very often slower than the log(det(...)) equivalent. I think it's worth documenting this behaviour and noting that logdet is not always preferable in terms of execution speed.

oscardssmith commented 1 year ago

Is the problem here just that our logdet implementation is bad or that logdet is inherently slower than det?

bvdmitri commented 1 year ago

It could be, but in general it seems reasonable to me that logdet is slower than log(det(...)). logdet is an inherently more difficult operation, as it ensures that the result does not overflow. In order to achieve this it usually performs some sort of matrix factorization and then computes the determinant in the log-domain. This must be extra work?

Otherwise there would be no need for the note at all, because you could define det(S) = exp(logdet(S)) and get equally accurate (without overflow) results in both cases.

bvdmitri commented 1 year ago

Your comment, however, made me try a random positive-definite matrix instead of just Diagonal(ones(N)). The performance difference is negligible, so it might indeed be a problem in the logdet implementation. In the source code I can see that det does extra checks before it attempts the lu factorization, whereas logdet does the lu without any checks.

julia> L = rand(10_000, 10_000);

julia> S = L' * L;

julia> @btime det($S)
  4.300 s (4 allocations: 763.02 MiB)

julia> @btime logdet($S)
  4.305 s (4 allocations: 763.02 MiB)

stevengj commented 1 year ago

> In order to achieve this it usually performs some sort of matrix factorization and then computes the determinant in the log-domain.

det also generally works by performing a factorization. So, the main difference in principle should be that logdet calls log for each diagonal element of the factorization object and sums them, whereas log(det(A)) multiplies the diagonal elements and calls log once. In principle, the performance penalty of calling log more times should be negligible for large dense matrices, where the cost should be dominated by the cost of factorization.
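The difference described above can be sketched on a small example. This is only an illustration of the two accumulation orders, not the code LinearAlgebra actually runs; the matrix and the variable names are made up for the example, and abs is used so the permutation sign of the LU factorization does not matter.

```julia
using LinearAlgebra

A = [4.0 1.0; 2.0 3.0]      # small example matrix (assumed for illustration), det(A) = 10
F = lu(A)
pivots = diag(F.U)          # diagonal of the U factor

# log(det(A)) style: multiply the diagonal elements, then call log once.
via_product = log(abs(prod(pivots)))

# logdet style: call log on each diagonal element, then sum.
via_sum = sum(x -> log(abs(x)), pivots)

println((via_product, via_sum, logabsdet(A)[1]))
```

Both orders share the same factorization; only the final O(n) pass over the diagonal differs, which is why the extra log calls should not matter for large dense matrices.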

(Note that we may well be missing some specialized methods of logdet, e.g. for Diagonal.)

But I agree that there should be no performance advantage to logdet in principle. The main point is to avoid overflow/underflow. (For many matrices, log(det(A)) will simply give ±Inf.)
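A minimal sketch of the overflow/underflow point, using matrices chosen purely for illustration (a 400×400 diagonal of 10.0 and of 0.1): the determinant itself leaves the Float64 range, while its logarithm is a perfectly ordinary number.

```julia
using LinearAlgebra

n = 400
big_mat   = Matrix(Diagonal(fill(10.0, n)))  # det = 10^400, overflows Float64
small_mat = Matrix(Diagonal(fill(0.1, n)))   # det = 10^-400, underflows to 0.0

overflowed  = log(det(big_mat))    # log of an overflowed determinant
underflowed = log(det(small_mat))  # log of an underflowed determinant

ld_big   = logdet(big_mat)         # stays finite, close to  n * log(10)
ld_small = logdet(small_mat)       # stays finite, close to -n * log(10)

println((overflowed, underflowed, ld_big, ld_small))
```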

aravindh-krishnamoorthy commented 1 year ago

Hello @bvdmitri, as you rightly pointed out, there is no inherent advantage in the log(det(.)) path. The performance difference is due to the triangular-matrix detection (which you called checks) in det in generic.jl:

function det(A::AbstractMatrix{T}) where {T}
    if istriu(A) || istril(A)
        S = promote_type(T, typeof((one(T)*zero(T) + zero(T))/one(T)))
        return convert(S, det(UpperTriangular(A)))
    end
    return det(lu(A; check = false))
end

Now, if we do the same for logabsdet, we get the same performance. Please consider the following hack-y code, which is meant to be typed at the prompt:

function LinearAlgebra.logabsdet(A::AbstractMatrix)
    if istriu(A)
        return logabsdet(UpperTriangular(A))
    elseif istril(A)
        return logabsdet(LowerTriangular(A))
    end
    return logabsdet(lu(A, check=false))
end

Note: Since upper triangular functions are defined for all types, the above function can be simplified further to only call logabsdet(UpperTriangular(A)); see comment https://github.com/JuliaLang/LinearAlgebra.jl/issues/1017.

With this revised method, we have:

julia> @btime log(det(S))
  33.367 ms (4 allocations: 78.20 KiB)
julia> @btime logdet(S)
  33.648 ms (1 allocation: 16 bytes)

This also explains why you don't see a performance difference for full matrices:

julia> A = rand(1000,1000) + 34*I(1000) ;
julia> @btime log(det(A))
  7.015 ms (5 allocations: 7.64 MiB)
julia> @btime logdet(A)
  7.168 ms (4 allocations: 7.64 MiB)

Note: + 34*I(1000) ensures a positive determinant with high probability. If you get a DomainError from log, please just regenerate the matrix A.
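As a side note on the DomainError, logabsdet returns the log of the absolute value of the determinant together with its sign, so it can be used when the sign is not known in advance. A small sketch with a matrix picked only because its determinant is exactly -1:

```julia
using LinearAlgebra

P = [0.0 1.0; 1.0 0.0]   # a row swap of the identity, det(P) = -1

# log(det(P)) would take the log of a negative Float64 and throw a DomainError.
threw_domain_error = try
    log(det(P))
    false
catch err
    err isa DomainError
end

# logabsdet returns (log(abs(det(P))), sign(det(P))) instead of throwing.
log_abs, sgn = logabsdet(P)

println((threw_domain_error, log_abs, sgn))
```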

aravindh-krishnamoorthy commented 1 year ago

Hello @bvdmitri @stevengj @oscardssmith, if you're happy with this fix (triangular matrix detection), then I can create a PR... To me, this seems to be the obvious fix, since logabsdet is already defined for upper and lower triangular matrices.

stevengj commented 1 year ago

Sure, it seems fine to me — the istriu check is so cheap compared to the lu (even for quite small matrices that are nearly triangular), that the overhead shouldn't be an issue.

(Though it's also easy for users to wrap their matrices in UpperTriangular if they know them to be triangular; it's not like upper-triangular matrices occur frequently by chance.)
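For illustration, the user-side wrapping could look like the following sketch. The matrix is an arbitrary example; UpperTriangular only wraps the existing storage, and the determinant routines for the wrapper only need the diagonal, so no factorization is required.

```julia
using LinearAlgebra

# A dense matrix that the caller happens to know is upper triangular (example data).
S = [2.0 1.0 5.0;
     0.0 3.0 7.0;
     0.0 0.0 4.0]

# Wrapping tells dispatch about the structure without copying the data.
wrapped = logdet(UpperTriangular(S))

# Same value as the generic path on the plain Matrix.
plain = logdet(S)

println((wrapped, plain))
```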

However, I would make the code similar to det by combining the two checks into istriu(A) || istril(A) and just wrapping in UpperTriangular (since it only looks at the diagonal elements anyway).
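A rough sketch of what that combined form might look like. To avoid overwriting a LinearAlgebra method, it is written here as a standalone function with a hypothetical name, logabsdet_tri_check, which is not part of the library; the eventual PR may well differ.

```julia
using LinearAlgebra

# Hypothetical helper (name is an assumption, not a LinearAlgebra function):
# mirrors the structure of det in generic.jl.
function logabsdet_tri_check(A::AbstractMatrix)
    if istriu(A) || istril(A)
        # For a triangular matrix only the diagonal matters,
        # so a single UpperTriangular wrapper covers both cases.
        return logabsdet(UpperTriangular(A))
    end
    # General case: fall back to the LU factorization.
    return logabsdet(lu(A; check = false))
end

lower = [2.0 0.0; 3.0 4.0]   # lower triangular, det = 8
full  = [4.0 1.0; 2.0 3.0]   # not triangular,   det = 10

println(logabsdet_tri_check(lower))
println(logabsdet_tri_check(full))
```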