tinygrad / tinygrad


Match torch speed on M1 for sum reduction #1167

Closed danhipke closed 1 year ago

danhipke commented 1 year ago

Creating an issue (in case others have suggestions) with some of the investigation so far for the bounty to match Torch speed on MPS=1 python3 test/test_speed_v_torch.py TestSpeed.test_sum. This is my first time writing Metal kernels, so I may be missing something obvious :)

One option raised in Discord is to use 2 reduce kernels (equivalent to a.sum(axis=0).sum()). This matches Torch's performance. Based on some investigation, it looks like PyTorch uses 1 kernel, leveraging MPSGraph's reductionSumWithTensor method; I couldn't find how that method is implemented.
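For clarity, a small NumPy sketch (my own illustration, not tinygrad code) of why the two-kernel split computes the same scalar: the first kernel reduces over rows to 4096 partials, the second reduces the partials. The input is a deterministic float64 array so both reductions are exact.

```python
import numpy as np

# Toy 4096x4096 input (float64 integers, so both reductions are exact).
a = np.arange(4096 * 4096, dtype=np.float64).reshape(4096, 4096)

one_kernel = a.sum()               # what a single fused reduce computes
two_kernels = a.sum(axis=0).sum()  # kernel 1: reduce over rows -> 4096 partials
                                   # kernel 2: reduce the partials to a scalar
assert one_kernel == two_kernels
```

The point of the split is that each kernel launch exposes enough parallelism (4096 independent column sums, then one small reduce) instead of serializing the whole reduction in one thread.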

I've done some investigation into doing this in one kernel. Loop unrolling provides a 2.5-3x speedup for larger inputs. Example:

#include <metal_stdlib>
using namespace metal;
kernel void r_256_65536(device float* data0, const device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
  threadgroup float temp[256];
  { int lidx0 = lid.x;  /* 256 threads per group */
    float acc0_0 = 0.0f;
    // Unrolled by 4: stride the loop so the four loads cover disjoint elements
    for (int ridx1 = 0; ridx1 < 16384; ridx1 += 4) {
      float val1_0 = data1[((lidx0*16384)+ridx1)];
      float val2_0 = data1[((lidx0*16384)+ridx1+1)];
      float val3_0 = data1[((lidx0*16384)+ridx1+2)];
      float val4_0 = data1[((lidx0*16384)+ridx1+3)];
      acc0_0 = val1_0 + acc0_0;
      acc0_0 = val2_0 + acc0_0;
      acc0_0 = val3_0 + acc0_0;
      acc0_0 = val4_0 + acc0_0;
    } /* reduce */
    temp[lidx0] = acc0_0;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (lidx0 == 0) {
      float accm1_0 = 0.0f;
      for (int tidx0 = 0; tidx0 <= 255; ++tidx0) {
        float valm1_0 = temp[tidx0];
        accm1_0 = (valm1_0 + accm1_0);
      } /* late_reduce */
      data0[0] = accm1_0;
    }
  } /* local */
} /* global */
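As a sanity check on the kernel's structure, here's a host-side Python model (my own sketch, not tinygrad output) of the intended reduction: 256 "threads" each accumulate a contiguous 16384-element slice, four values per unrolled step, then thread 0 sums the 256 threadgroup partials.

```python
import numpy as np

# Exact-integer input so the simulated result can be checked exactly.
data1 = np.arange(256 * 16384, dtype=np.float64)

temp = np.empty(256)                  # models the threadgroup float temp[256]
for lidx0 in range(256):              # one loop body per simulated thread
    acc = 0.0
    for ridx1 in range(0, 16384, 4):  # the unrolled-by-4 reduce loop
        base = lidx0 * 16384 + ridx1
        acc += data1[base] + data1[base + 1] + data1[base + 2] + data1[base + 3]
    temp[lidx0] = acc                 # threadgroup memory write before the barrier

# late_reduce: thread 0 after threadgroup_barrier
out = temp.sum()
assert out == data1.sum()
```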

Perf:

Current kernel:

% MPS=1 python3 test/test_speed_v_torch.py TestSpeed.test_sum
 sum                             2048x 2048    0.56 ms (    7.50 GFLOPS    29.99 GB/s) in torch,    1.02 ms (    4.10 GFLOPS    16.42 GB/s) in tinygrad,    1.83x slower       4.19 MOPS    16.78 MB
 sum                             4096x 4096    0.98 ms (   17.21 GFLOPS    68.83 GB/s) in torch,    7.44 ms (    2.26 GFLOPS     9.02 GB/s) in tinygrad,    7.63x slower      16.78 MOPS    67.11 MB
 sum                             8192x 8192    2.19 ms (   30.59 GFLOPS   122.38 GB/s) in torch,   28.91 ms (    2.32 GFLOPS     9.28 GB/s) in tinygrad,   13.18x slower      67.11 MOPS   268.44 MB
.
----------------------------------------------------------------------
Ran 1 test in 1.209s

Loop unrolled kernel:

% MPS=1 python3 test/test_speed_v_torch.py TestSpeed.test_sum

 sum                             2048x 2048    0.55 ms (    7.58 GFLOPS    30.31 GB/s) in torch,    1.04 ms (    4.02 GFLOPS    16.08 GB/s) in tinygrad,    1.88x slower       4.19 MOPS    16.78 MB
 sum                             4096x 4096    0.99 ms (   17.03 GFLOPS    68.11 GB/s) in torch,    2.82 ms (    5.95 GFLOPS    23.82 GB/s) in tinygrad,    2.86x slower      16.78 MOPS    67.11 MB
 sum                             8192x 8192    2.19 ms (   30.68 GFLOPS   122.73 GB/s) in torch,   11.14 ms (    6.02 GFLOPS    24.10 GB/s) in tinygrad,    5.09x slower      67.11 MOPS   268.44 MB
.
----------------------------------------------------------------------
Ran 1 test in 1.042s

Still investigating, as unrolling doesn't really matter for smaller tensors - this has been a pretty useful reference with some other approaches to try. Apple's docs also list a reduce example using SIMD functions if we want to get more sophisticated.
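For reference, a rough Python model of the SIMD-group approach (simd_sum / simd_shuffle_down are real Metal functions; this sketch of mine only illustrates the access pattern): each 32-lane SIMD group reduces with a shuffle-down tree, so only one threadgroup-memory slot per SIMD group is needed instead of one per thread.

```python
import numpy as np

def simd_sum(lane_vals):
    # Tree reduction as simd_shuffle_down would do it: 5 halving steps for 32 lanes.
    vals = np.array(lane_vals, dtype=np.float64)
    offset = len(vals) // 2
    while offset:
        vals[:offset] += vals[offset:2 * offset]  # lane i += lane i + offset
        offset //= 2
    return vals[0]

data = np.arange(256, dtype=np.float64)
# 8 SIMD groups of 32 lanes each contribute one partial instead of 32.
partials = [simd_sum(data[g * 32:(g + 1) * 32]) for g in range(8)]
assert sum(partials) == data.sum()
```

On the GPU the tree steps run in lockstep within a SIMD group with no barrier, which is why this tends to beat the serial late_reduce loop over temp[].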

geohot commented 1 year ago

How do you know PyTorch has only one kernel? What's the launch params? Can you disassemble it? Check out: https://github.com/dougallj/applegpu

danhipke commented 1 year ago

Thanks for the pointer. I was using the MPS profiler in PyTorch (available in nightly), but it looks like it only outputs information at the graph level (there is only 1 reduce graph), not the actual kernels. Will look to see if I can disassemble the kernel.

% MPS=1 PYTORCH_MPS_LOG_PROFILE_INFO=31 python3 test_speed_v_torch.py TestSpeed.test_sum

BlitCopySync: CPU:Float[4096, 4096] --> MPS:Float[4096, 4096] (len=64.00 MB)
BlitCopySync: MPS:Float[4096, 4096] --> CPU:Float[4096, 4096] (len=64.00 MB)
aten::add_out_mps::f32[4096,4096]:i64[Scalar]:f32[4096,4096] (id=G1, run=1)
BlitCopySync: MPS:Float[4096, 4096] --> CPU:Float[4096, 4096] (len=64.00 MB)
aten::sum_out_mps:0,1::f32[4096,4096]:0:4::f32[Scalar]: (id=G2, run=1)
BlitCopySync: MPS:Float[] --> CPU:Float[] (len=4 bytes)
 sum                             4096x 4096    4.50 ms (    3.73 GFLOPS    14.91 GB/s) in torch,    8.52 ms (    1.97 GFLOPS     7.88 GB/s) in tinygrad,    1.89x slower      16.78 MOPS    67.11 MB
.
----------------------------------------------------------------------
Ran 1 test in 0.188s

OK

---------------------------------------------------------------------------- MPS Operations Profiling: 2 graphs, 0 kernels -----------------------------------------------------------------------------
  ID  | #Runs | Mean KRNL(ms) | Mean GPU(ms) | Total GPU(ms) | Operation Name
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
  G1      1         21.946          2.600          2.600       add_out_mps::f32[4096,4096]:i64[Scalar]:f32[4096,4096]
  G2      1         0.716           0.319          0.319       sum_out_mps:0,1::f32[4096,4096]:0:4::f32[Scalar]:
There are no CPU Fallbacks logged for profiling

----------------------------------------------- MPS Copy Profiling: 4 total copies (192.00 MB), 1 scalar copies ------------------------------------------------
    Kind    |  Total#  |   Total Size    | Total KRNL(ms) | Total GPU(ms) | Scalars | Scalars GPU | Blocking | memcpy
----------------------------------------------------------------------------------------------------------------------------------------------------------------
 CPU to MPS       1          64.00 MB          51.775           0.618          0         0.00 %         1         0
 MPS to CPU       3         128.00 MB          23.473           2.977          1         0.12 %         3         0