JuliaSparse / SparseArrays.jl

SparseArrays.jl is a Julia stdlib

`findmin(A; dims=1)` is much slower than a manual loop over the columns. #510

Closed pratyai closed 9 months ago

pratyai commented 9 months ago

Some demo:

using SparseArrays

# setups
n = 10000
halfhalf = a -> SparseMatrixCSC(a.m, a.n, a.colptr, a.rowval, a.nzval .- 0.5);
symmetrize = a -> (a + a')/2;
A = halfhalf(symmetrize(sprand(n, n, 0.1)));  # nnz(A) == 19004033
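function manualmincol(a)
   # one entry per column, so size by a.n (a.m only works for square matrices)
   local mincols = zeros(Int, a.n)
   local minvals = zeros(Float64, a.n)
   for c in 1:a.n
       # stored entries of column c live at positions rb:(re-1)
       local rb, re = a.colptr[c], a.colptr[c+1]
       if rb == re
           continue  # no stored entries in this column
       end
       local minval::Float64, row_::Int64 = Inf, 0
       for r in rb:(re-1)
           local val, row = a.nzval[r], a.rowval[r]
           if val < minval
               minval, row_ = val, row
           end
       end
       mincols[c], minvals[c] = row_, minval
   end
   return mincols, minvals
end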

# test
@time bob = manualmincol(A)[1];
# output: 0.138503 seconds (5 allocations: 156.375 KiB)

@time bob2 = findmin(A, dims=1)[2];
# output: 2.414430 seconds (9 allocations: 391.094 KiB)

# check that they are the same
bob2 = Vector(first.(Tuple.(bob2))'[:,1]);
bob2 == bob
# output: true

I believe the reason is that with dims=... it goes through a non-specialised version of findmin():

julia> @which findmin(A, dims=1)
kwcall(::NamedTuple, ::typeof(findmin), A::AbstractArray)
     @ Base reducedim.jl:1130

Does it make sense to "intercept" the dims argument and provide a faster implementation? (I understand that performance will differ between row-wise and column-wise aggregation, but currently both are much slower than spelling out the loops.)

dkarrasch commented 9 months ago

You may need to overload the three-arg function, _findmin(f, A, dims), but otherwise: yes, please, make a PR.
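
For illustration, here is a rough sketch of what a column-wise fast path could look like. It is a standalone function, not the actual overload: the name `colwise_findmin` is hypothetical, and hooking it into `_findmin(f, A, dims)` is left out. Unlike the demo loop above, it also accounts for implicit (unstored) zeros, so it should match `findmin(A, dims=1)` for real-valued matrices. It assumes there are no NaN entries and that row indices within each column are sorted, as is usual for `SparseMatrixCSC`.

```julia
using SparseArrays

# Hypothetical helper (not part of SparseArrays.jl): column-wise minimum of a
# real-valued SparseMatrixCSC, returning the same shapes as findmin(A, dims=1).
# Assumes no NaN entries and sorted row indices within each column.
function colwise_findmin(A::SparseMatrixCSC{Tv}) where {Tv<:Real}
    m, n = size(A)
    m == 0 && throw(ArgumentError("matrix must have at least one row"))
    rows, vals = rowvals(A), nonzeros(A)
    minvals = Matrix{Tv}(undef, 1, n)
    minidx = Matrix{CartesianIndex{2}}(undef, 1, n)
    for c in 1:n
        best, bestrow = zero(Tv), 0   # bestrow == 0: no stored entry seen yet
        gap, next = 0, 1              # gap: first row without a stored entry
        for k in nzrange(A, c)
            r, v = rows[k], vals[k]
            if gap == 0 && r != next
                gap = next            # rows next:(r-1) hold implicit zeros
            end
            next = r + 1
            if bestrow == 0 || v < best
                best, bestrow = v, r  # strict < keeps the first minimum
            end
        end
        if gap == 0 && next <= m
            gap = next                # implicit zeros after the last stored row
        end
        # An implicit zero wins if it is smaller than the stored minimum,
        # or ties with it at an earlier row.
        if gap != 0 && (bestrow == 0 || zero(Tv) < best ||
                        (zero(Tv) == best && gap < bestrow))
            best, bestrow = zero(Tv), gap
        end
        minvals[1, c] = best
        minidx[1, c] = CartesianIndex(bestrow, c)
    end
    return minvals, minidx
end
```

Each column is handled in a single pass over its stored entries, so the total work is proportional to `nnz(A) + size(A, 2)` rather than to the number of elements in the dense matrix. A row-wise variant (`dims=2`) would need per-row bookkeeping and is not shown here.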