sarah-quinones / gemm

MIT License

This improves drastically overthreading issue (>48cores) #11

Open Narsil opened 1 year ago

Narsil commented 1 year ago

I'm not sure that this change is optimal by any means.

But it does yield a significant improvement when running a relatively small matmul on a 48-core machine.

Before:

// 48 cores
parallelism-48-f32-nnn-gemm-6×2304×768
                        time:   [2.2215 ms 2.2584 ms 2.2906 ms]
                        change: [-2.7095% -0.4486% +2.0755%] (p = 0.74 > 0.05)
                        No change in performance detected.
Found 1 outliers among 10 measurements (10.00%)
  1 (10.00%) high mild

parallelism-none-f32-nnn-gemm-6×2304×768
                        time:   [745.65 µs 746.78 µs 748.09 µs]
                        change: [-0.6916% -0.4244% -0.1303%] (p = 0.02 < 0.05)
                        Change within noise threshold.

After:

parallelism-48-f32-nnn-gemm-6×2304×768
                        time:   [641.83 µs 651.66 µs 664.90 µs]
                        change: [-71.903% -71.301% -70.685%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 2 outliers among 10 measurements (20.00%)
  1 (10.00%) high mild
  1 (10.00%) high severe

parallelism-none-f32-nnn-gemm-6×2304×768
                        time:   [741.77 µs 744.47 µs 748.39 µs]
                        change: [-0.9774% -0.6209% -0.2174%] (p = 0.01 < 0.05)
                        Change within noise threshold.
Found 1 outliers among 10 measurements (10.00%)
  1 (10.00%) high severe

At least we're not slowing down drastically (but this is not an improvement either)

sarah-quinones commented 1 year ago

This formula seems pretty cryptic. Is there some reasoning behind it?

 let n_threads = std::cmp::max(1, std::cmp::min(max_threads, (total_work - threading_threshold + 1) / threading_threshold));

Narsil commented 1 year ago

threading_threshold is the cutoff you had before, which switched between num_threads=1 and num_threads=all.

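(total_work - threading_threshold + 1) / threading_threshold is meant as roughly total_work / threading_threshold, a heuristic for how many threads it looks OK to share the work across. (Strictly, with integer division and total_work >= threading_threshold, it evaluates to floor((total_work + 1) / threading_threshold) - 1 rather than an exact ceil(total_work / threading_threshold).) min(max_threads, X) is to not use more threads than requested. max(1, X) is to use at least 1.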
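
For reference, here is a minimal sketch of the heuristic above as a standalone function. The name `pick_n_threads`, the `usize` parameter types, and the use of `saturating_sub` are assumptions for illustration, not the code in this PR. The saturating subtraction only guards against unsigned underflow when `total_work` is smaller than `threading_threshold`; `threading_threshold` is assumed to be non-zero.

```rust
use std::cmp::{max, min};

/// Hypothetical helper mirroring the expression discussed in this thread.
/// Assumes `threading_threshold > 0` (a zero threshold would divide by zero).
fn pick_n_threads(total_work: usize, threading_threshold: usize, max_threads: usize) -> usize {
    // Assumption: saturate instead of underflowing when the work is below
    // the threshold, so small problems fall through to a single thread.
    let excess = total_work.saturating_sub(threading_threshold);

    // Integer division truncates, so this grows by one thread for each
    // additional `threading_threshold` worth of work.
    let wanted = (excess + 1) / threading_threshold;

    // Never exceed the requested thread count, and always use at least one.
    max(1, min(max_threads, wanted))
}
```

With a made-up threshold of 1000 and `max_threads = 48`, work of 500 or 1000 gives 1 thread, 2999 gives 2, 10_000 gives 9, and 1_000_000 is capped at 48. The thread count therefore scales with the amount of work instead of jumping straight from 1 to all cores.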