willow-ahrens / finch-tensor

Sparse Tensors in Python and more! Datastructure-Driven Array Programming Language
MIT License

`finch-tensor` performance #24

Open mtsokol opened 6 months ago

mtsokol commented 6 months ago

Hi @willow-ahrens @hameerabbasi,

I managed to run the first simple benchmarks in pydata/sparse with the PyData and Finch backends. I used the asv tool for benchmarking. It runs a "setup" function, which executes the compilation step, then calls the benchmarked function a few times and takes the median.
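In asv itself this pattern is a benchmark class whose `setup` method runs before the timed `time_*` methods. The stand-alone sketch below only mirrors the measurement idea so the one-off compilation cost stays out of the numbers. `bench_median` is a hypothetical helper, not part of asv or finch-tensor.

```python
import statistics
import time


def bench_median(setup, fn, repeats=5):
    """Run `setup` once (e.g. to trigger compilation), then time `fn`.

    `setup` returns a tuple of arguments for `fn`. Only the calls to `fn`
    are timed; the median of `repeats` samples is returned in seconds.
    """
    args = setup()  # one-off cost, excluded from the timings
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# Usage: time sum() over a prepared list, with list construction as "setup".
elapsed = bench_median(lambda: (list(range(10_000)),), sum)
```

The median is less sensitive than the mean to occasional slow samples, such as a garbage-collection pause.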

The first thing to investigate is the difference between eager and lazy mode results. For the function:

import finch

def my_custom_fun(arr1, arr2, arr3):
    temp = finch.multiply(arr1, arr2)
    temp = finch.divide(temp, arr3)
    reduced = finch.sum(temp, axis=(0, 1))
    return finch.add(temp, reduced)
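The last line relies on broadcasting. Summing a 3-D `temp` over `axis=(0, 1)` leaves a 1-D result whose length matches the trailing axis, and `finch.add` broadcasts it back across that axis. The sketch below checks the shape arithmetic using the standard array-API rule (align shapes on the right; each dimension must match or be 1). `broadcast_shape` is a hypothetical helper, not a finch-tensor function.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Array-API broadcasting: align shapes on the right; sizes must match or be 1."""
    result = []
    # Walk dimensions from the last axis backwards, padding shorter shapes with 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"incompatible dimensions: {dims}")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


temp_shape = (100, 100, 100)
reduced_shape = temp_shape[2:]  # sum over axis=(0, 1) keeps only the last axis
print(broadcast_shape(temp_shape, reduced_shape))  # (100, 100, 100)
```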

With a random CSF input 100x100x100, density 0.01, the benchmark results were:

I used the py-spy profiler to inspect lazy mode. For each compiled function call, all of the time was spent on the Julia side (juliacall and Finch):

[Figure: py-spy profile of the lazy-mode benchmark]

I also ran initial benchmarks of element-wise functions (+ and * on two random 100x100x100 CSF tensors, eager mode only):


Comments

willow-ahrens commented 6 months ago

When we benchmark the higher-level Finch functions, two questions need answering:

  1. What is the overhead of calling the function? We could answer this by benchmarking on 1x1 matrices.
  2. What is the performance of the function ignoring the overhead of calling? We could answer this by benchmarking on 1_000_000 x 1_000_000 matrices with 0.001 or 0.01 nonzeros. If this call hangs and does not finish, we might accidentally be generating dense code somewhere, and should investigate that. Refer to the optimization tips for potential pitfalls, and try printing out what finch calls are getting generated.
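Both questions can be answered from one size sweep. Time the same call at several input sizes and fit a straight line: the intercept approximates the fixed calling overhead (question 1), and the slope approximates the cost per element (question 2). The following is a minimal stdlib sketch, assuming runtime is roughly linear in input size. `fit_overhead` is a hypothetical helper, and the data in the usage example is synthetic, not a measurement.

```python
def fit_overhead(sizes, times):
    """Least-squares fit of t(n) = overhead + per_element * n.

    `sizes` are input sizes (e.g. number of nonzeros) and `times` the
    measured runtimes in seconds. Returns (overhead, per_element).
    """
    n = len(sizes)
    if n < 2:
        raise ValueError("need at least two measurements")
    mean_x = sum(sizes) / n
    mean_y = sum(times) / n
    sxx = sum((x - mean_x) ** 2 for x in sizes)
    if sxx == 0:
        raise ValueError("sizes must not all be equal")
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(sizes, times))
    per_element = sxy / sxx
    overhead = mean_y - per_element * mean_x
    return overhead, per_element


# Synthetic example: 0.2 s fixed overhead plus 10 ns per element.
sizes = [1, 10_000, 1_000_000, 100_000_000]
times = [0.2 + 1e-8 * x for x in sizes]
overhead, per_element = fit_overhead(sizes, times)
```

On real timings the large inputs dominate an unweighted fit, so it is worth also plotting time against size on log-log axes. The flat left part of such a plot is the overhead.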

Once we have determined what the calling overhead is and we have ruled out performance gotchas, we can talk about reducing both of them. We can improve scheduling heuristics (improving loop ordering to reduce transposing and improve filtering, reduce the number of arguments we process at once to reduce code bloat, etc.), reduce the runtime of the scheduler (avoid fixpoint rewriting, ensure more static typing, etc.), cache scheduled programs, etc. I'll also link to https://github.com/willow-ahrens/Finch.jl/issues/460, which points out that the calling overhead of scheduling anything appears to be 200 ms.
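Caching scheduled programs amounts to memoizing the scheduler on a structural key. The sketch below is hypothetical: `PlanCache`, the `schedule` callable, and the key layout are illustrative and are not Finch.jl's API. The key design choice is that the key describes the program and the input formats, not the tensor values, so repeated calls with new data reuse the same plan and pay the scheduling cost only once.

```python
class PlanCache:
    """Memoize an expensive scheduler on a hashable description of the program."""

    def __init__(self, schedule):
        self._schedule = schedule  # stand-in for the (slow) scheduler/compiler
        self._plans = {}
        self.misses = 0  # how many times the scheduler actually ran

    def get(self, key):
        if key not in self._plans:
            self.misses += 1
            self._plans[key] = self._schedule(key)
        return self._plans[key]


def fake_schedule(key):
    # Hypothetical placeholder for the expensive scheduling step.
    return f"compiled<{key}>"


cache = PlanCache(fake_schedule)
# Key: operation plus (format, ndim, dtype) of each input; data values excluded.
key = ("mul", ("SparseCOO", 2, "float64"), ("Dense", 2, "float64"))
first = cache.get(key)
second = cache.get(key)  # served from the cache, scheduler not called again
```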

Also, for input formats, I would go for CSF in general.

rgommers commented 6 months ago

If it helps, here is some code that could be reused for benchmark plots as a function of input array size: https://github.com/rgommers/explore-array-computing/blob/master/explore_xnd.ipynb. I quite like such plots, since they show in one glance the function call overhead and the scaling properties.

mtsokol commented 4 months ago

Hi @willow-ahrens,

Here's a comparison between the SDDMM written with the @finch macro that you shared, which executes in 422.570 μs for me, and the lazy API implementation used by the Python wrapper, which executes in 380.194 ms.

Below I also share the debug-mode logs of the executed plan.

using Finch
using BenchmarkTools

LEN = 1000;
DENSITY = 0.00001;

s = fsprand(LEN, LEN, DENSITY);
a = rand(LEN, LEN);
b = rand(LEN, LEN);

@benchmark begin
   s = $s
   a = $a
   b = $b
   c = Tensor(Dense(Sparse(Element(0.0))))
   @finch mode=:fast begin
      c .= 0
      for k = _, j = _, i = _
         c[i, j] += s[i, j] * a[i, k] * b[k, j]
      end
   end
end

result: Time (mean ± σ): 422.570 μs
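For reference, here is what the fused kernel computes, written as a plain Python sketch over a dictionary of nonzeros. The function name and data layout are illustrative, not finch-tensor's. The dot product is evaluated only at the nonzero positions of `s`, so the work scales with `nnz(s) * K` rather than with a full dense matrix product.

```python
def sddmm(s, a, b):
    """Sampled dense-dense matmul: c[i, j] = s[i, j] * sum_k a[i][k] * b[k][j].

    `s` is a dict {(i, j): value} holding only the nonzeros; `a` (n x K)
    and `b` (K x m) are dense lists of lists. Only positions stored in `s`
    are computed, so the work is O(nnz(s) * K).
    """
    K = len(b)
    c = {}
    for (i, j), s_ij in s.items():
        dot = sum(a[i][k] * b[k][j] for k in range(K))
        c[(i, j)] = s_ij * dot
    return c


s = {(0, 1): 2.0}
a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
print(sddmm(s, a, b))  # {(0, 1): 44.0}
```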

using Finch
using BenchmarkTools

LEN = 1000;
DENSITY = 0.00001;

s = lazy(fsprand(LEN, LEN, DENSITY));
a = lazy(Tensor(rand(LEN, LEN)));
b = lazy(Tensor(rand(LEN, LEN)));

c = tensordot(a, b, ((2,), (1,)));
plan = permutedims(broadcast(.*, permutedims(s, (1, 2)), permutedims(c, (1, 2))), (1, 2));

res = compute(plan);

@benchmark begin
   compute(plan);
end

result: Time (mean ± σ): 380.194 ms

Executing:
:(function var"##compute#435"(prgm)
      begin

          V = ((((((((((((prgm.children[1]).children[2]).children[1]).children[1]).children[2]).children[1]).children[2]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{SparseCOOLevel{2, Tuple{Int64, Int64}, Vector{Int64}, Tuple{Vector{Int64}, Vector{Int64}}, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}

          V_2 = ((((((((((((((((prgm.children[1]).children[2]).children[1]).children[1]).children[2]).children[1]).children[3]).children[1]).children[1]).children[1]).children[1]).children[2]).children[3]).children[2]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}

          V_3 = ((((((((((((((((prgm.children[1]).children[2]).children[1]).children[1]).children[2]).children[1]).children[3]).children[1]).children[1]).children[1]).children[1]).children[2]).children[3]).children[3]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}

          A0 = V::Tensor{SparseCOOLevel{2, Tuple{Int64, Int64}, Vector{Int64}, Tuple{Vector{Int64}, Vector{Int64}}, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}

          A2 = V_2::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}

          A4 = V_3::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}

          A6 = Tensor(Dense(Dense(Element{0.0, Float64}())))::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}

          @finch mode = :fast begin
                  A6 .= 0.0
                  for i8 = _
                      for i7 = _
                          for i6 = _
                              A6[i6, i8] << + >>= (*)(A2[i6, i7], A4[i7, i8])
                          end
                      end
                  end
                  return A6
              end
          A9 = Tensor(Dense(Dense(Element{0.0, Float64}())))::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          @finch mode = :fast begin
                  A9 .= 0.0
                  for i16 = _
                      for i15 = _
                          A9[i15, i16] << Finch.FinchNotation.InitWriter{0.0}() >>= (Base.Broadcast.BroadcastFunction(*))(A0[i15, i16], A6[i15, i16])
                      end
                  end
                  return A9
              end
          return (A9,)
      end
  end)

willow-ahrens commented 4 months ago

This might just be the cost of unfused SDDMM without accelerating the dense ops. Consider:

plan = @einsum c[i, j] += s[i, j] * a[i, k] * b[j, k]
Finch.LazyTensor{Float64, 2}(reorder(aggregate(immediate(+), immediate(0.0), mapjoin(immediate(*), relabel(subquery(alias(##A#294), table(immediate(Tensor(SparseCOO{2, Tuple{Int64, Int64}}(Element{0.0, Float64, Int64}([0.4668361195926698, 0.3738177102725496, 0.6274468897160046, 0.09205289146977969, 0.013701297267564638, 0.7743841318738852, 0.7824017428334505, 0.2510431893070184, 0.8659441791313686]), (1000, 1000), [1, 10], ([454, 871, 273, 430, 976, 34, 500, 664, 428], [17, 139, 188, 530, 624, 652, 686, 849, 902], ) ))), field(##i#295), field(##i#296))), field(i), field(j)), relabel(subquery(alias(##A#299), table(immediate(Tensor(Dense{Int64}(Dense{Int64}(Element{0.0, Float64, Int64}([0.5061070271828509, 0.5077277183281763, 0.1695208843167083, 0.969232997007124, 0.5675902948439756, 0.9125291341327263, 0.905000953438578, 0.6400938674316544, 0.7896219582360124, 0.021845036421344943  …  0.43818025034958374, 0.08346418915761966, 0.16075072389802147, 0.6305985096585706, 0.9156726308879686, 0.19945992761863263, 0.4245794957333421, 0.030785182655595378, 0.11072673846944725, 0.686449847771958]), 1000), 1000))), field(##i#300), field(##i#301))), field(i), field(k)), relabel(subquery(alias(##A#302), table(immediate(Tensor(Dense{Int64}(Dense{Int64}(Element{0.0, Float64, Int64}([0.36625072983125007, 0.6598399666935976, 0.2030795750942428, 0.4045212153064739, 0.919315060167964, 0.4828564852471168, 0.8035645454194987, 0.9627669653607265, 0.35104263118224643, 0.335849216307714  …  0.954437248319078, 0.14139334802736392, 0.4905796420786366, 0.9859451809082922, 0.7789293193672848, 0.07271594640593604, 0.3409324773807292, 0.48940065061469296, 0.7890727822067549, 0.8348185755546461]), 1000), 1000))), field(##i#303), field(##i#304))), field(j), field(k))), field(k)), field(i), field(j)), (false, false), 0.0)

julia> @benchmark begin
          compute($plan);
       end
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  129.125 μs …  2.705 ms  ┊ GC (min … max): 0.00% … 94.05%
 Time  (median):     141.416 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   143.293 μs ± 42.957 μs  ┊ GC (mean ± σ):  0.50% ±  1.62%

  ▁▃▃▁   ▁         ▂▇█▇▅▃▂▅▄▃▂▁ ▁▁▁▁▁▁ ▁▁▁                     ▂
  ████▇▆████▇▆▆▆▇▇▇███████████████████▇████▇▇▇█▇▆▅▆▇▇▇▆▆▅▆▆▄▆▅ █
  129 μs        Histogram: log(frequency) by time       166 μs <

 Memory estimate: 17.62 KiB, allocs estimate: 264.

willow-ahrens commented 4 months ago

You can see in the logs that the expression is not fused, so we should fix the bugs that prevent us from fusing the plan.
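A rough operation count shows why the unfused plan is so much slower. It counts multiply-adds only and ignores memory traffic. In the logged plan, the first @finch block computes the full dense product A6 before the sparse mask A0 is applied. The parameters below come from the benchmark. The `nnz` value is the expected count, since `fsprand` draws the pattern randomly; the run printed above happened to have 9 nonzeros.

```python
LEN = 1000
DENSITY = 0.00001

nnz = round(DENSITY * LEN * LEN)  # expected nonzeros in s
unfused = LEN ** 3                # dense A6 = A2 * A4, computed before masking
fused = nnz * LEN                 # one length-LEN dot product per nonzero of s

print(nnz, unfused, fused, unfused // fused)  # 10 1000000000 10000 100000
```

This is an order-of-magnitude argument about arithmetic, not a runtime prediction.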

willow-ahrens commented 4 months ago

A side note: there are two ways to broadcast in Julia. You can do

a .* b

or you can do

broadcast(*, a, b)

however, when you do

broadcast(.*, a, b)

you're broadcasting the broadcast multiplication operator, and I'm not sure what finch will think of that.
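The toy Python analogy below illustrates why the choice of operator can matter. It is not Finch's actual logic. A sparse elementwise kernel can skip zeros only if it recognizes the operator (here, by identity) and knows that `0 * x == 0`. A wrapper that computes the same values may look opaque, forcing a visit to every stored index. In Julia, `.*` used as a value is `Base.Broadcast.BroadcastFunction(*)`, which is the wrapper that shows up in the generated code above. `elementwise` and `wrapped_mul` are hypothetical names.

```python
import operator


def elementwise(op, x, y):
    """Toy sparse elementwise op on dicts {index: value} with fill value 0.

    If `op` is recognized as multiplication, only indices stored in both
    inputs are visited, because 0 * anything == 0. Any other callable is
    treated as opaque, so every index in the union is evaluated. (A fully
    general kernel would also have to check op(0, 0) for the fill value.)
    Returns the nonzero results and the number of calls to `op`.
    """
    if op is operator.mul:
        keys = x.keys() & y.keys()
    else:
        keys = x.keys() | y.keys()
    calls = 0
    out = {}
    for k in keys:
        calls += 1
        v = op(x.get(k, 0), y.get(k, 0))
        if v != 0:
            out[k] = v
    return out, calls


def wrapped_mul(a, b):
    # Same values as operator.mul, but a different object the kernel cannot identify.
    return a * b


x = {0: 2, 1: 3}
y = {1: 4, 2: 5}
fast, fast_calls = elementwise(operator.mul, x, y)  # visits only index 1
slow, slow_calls = elementwise(wrapped_mul, x, y)   # visits indices 0, 1, 2
```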