The memory bandwidth is 100 GB/s, and it is unified memory shared by the CPU and GPU. 0.8 TFlops might be fast enough; it's hard to reach 100% of peak with matmul.
I'll be working on a small module to sweep tiling + workgroup parameters in the near future (they have to be tweaked manually in the meantime) to automate environment-specific tuning. Those parameters haven't been tuned for an M2 Air.
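As a rough illustration of what that sweep module could look like, here is a minimal grid-search sketch in Python. `run_gemm_benchmark` is a hypothetical hook (not part of the current codebase) that would compile the WGSL matmul with the given tiling/workgroup parameters, time it, and return achieved TFlops; the parameter ranges are illustrative only.

import itertools

def sweep(run_gemm_benchmark):
    """Grid-search tiling/workgroup parameters for the matmul kernel.

    `run_gemm_benchmark(tile_m, tile_n, tile_k, wg_x, wg_y)` is a hypothetical
    hook: compile the kernel with those parameters, time it, and return
    achieved TFlops (or raise if the configuration is invalid).
    """
    best_tflops, best_cfg = 0.0, None
    for cfg in itertools.product(
        [32, 64, 128],   # tile_m: output rows per workgroup
        [32, 64, 128],   # tile_n: output cols per workgroup
        [8, 16, 32],     # tile_k: K-slab staged through shared memory
        [8, 16],         # workgroup size x
        [8, 16],         # workgroup size y
    ):
        try:
            tflops = run_gemm_benchmark(*cfg)
        except Exception:
            continue     # skip configs that exceed shared-memory / thread limits
        if tflops > best_tflops:
            best_tflops, best_cfg = tflops, cfg
    return best_tflops, best_cfg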
That said, 887/3600 ≈ 25% of theoretical maximum is around the same ratio I'm getting with my M1 (~2.5 TFlops / 10.4 TFlops theoretical max), so it's not far off from what we're seeing. It would be good to see what we get for numpy / PyTorch on Apple silicon so we have some sense of what's a reasonable performance to expect and how much more headroom we have to improve.
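For reference, a quick check of the two ratios quoted above:

# Achieved vs. theoretical peak, using the numbers quoted above.
m2_air = 0.887 / 3.6    # 887 GFlops measured vs 3.6 TFlops peak
m1     = 2.5 / 10.4     # ~2.5 TFlops measured vs 10.4 TFlops peak
print(f"M2 Air: {m2_air:.1%}, M1: {m1:.1%}")  # roughly 24.6% and 24.0%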
Using PyTorch MPS, an Apple M2 can get 2.9 TFlops on the same GEMM size.
Good to know - we have more room for improvement on that example then. One reference implementation we should check is the TensorFlow.js WGSL implementation, which is relatively optimized.
It is indeed compute bound, so it may be possible to reach more than 2.9 TFlops. Thx! https://jax.readthedocs.io/en/latest/pallas/tpu/matmul.html
def matmul_flops(m, k, n):
    # one multiply + one add per (m, k, n) triple
    return 2 * m * k * n

def matmul_membw(m, k, n):
    # bytes moved for A (m x k), B (k x n), and C (m x n) in float32 (4 bytes)
    return (m * k + k * n + m * n) * 4

def matmul_flops_intensity(m, k, n):
    flops = matmul_flops(m, k, n)
    membw = matmul_membw(m, k, n)
    return flops / membw

# 3.6 TFlops
m2_flops = 3.6e12
# 100 GB/s
m2_membw = 100e9
# flops / byte
m2_op_intensity = m2_flops / m2_membw

print(f"m2_op_intensity: {m2_op_intensity} flops/byte")
# m2_op_intensity: 36.0 flops/byte
print(f"matmul_op_intensity: {matmul_flops_intensity(4096, 4096, 8192)} flops/byte")
# matmul_op_intensity: 819.2 flops/byte

# m2_op_intensity (36.0 flops/byte) is less than matmul_op_intensity (819.2 flops/byte).
# It is compute bound!
Where is the matmul kernel of MPS? This one is the naive kernel (it's not fast): https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/mps/operations/LinearAlgebra.mm#L32-L81
I get 2.9 TFlops using this Torch code:
#!/usr/bin/env python3
import os, sys
import argparse
import torch
import time

X = torch.arange(4096 * 4096, dtype=torch.float32).view([4096, 4096]).to('mps')
Y = torch.arange(4096 * 4096, dtype=torch.float32).view([4096, 4096]).to('mps')

def wait():
    # make sure all queued MPS work has finished before taking a timestamp
    torch.mps.synchronize()
    return time.perf_counter()

# warm-up
torch.matmul(X, Y)
torch.matmul(X, Y)
torch.matmul(X, Y)

t0 = wait()
for i in range(10):
    torch.matmul(X, Y)
t1 = wait()
cost = (t1 - t0) / 10
print('TFlops:', 4096 * 4096 * 4096 * 2 / cost * 1e-12)
@ghostplant Thx! It is float32.
I checked the tests with arange and randn. I will look into why the latter version is downgraded.
The test with arange gets 5.7 TFlops on an M2 Pro.
The test with randn gets 1.9 TFlops on an M2 Pro. (Updated after the fix below: it also gets 5.7 TFlops.)
$ cat test_arange.py
#!/usr/bin/env python3
import os, sys
import argparse
import torch
import time

X = torch.arange(4096 * 4096, dtype=torch.float32).view([4096, 4096]).to('mps')
Y = torch.arange(4096 * 8192, dtype=torch.float32).view([4096, 8192]).to('mps')

def wait():
    torch.mps.synchronize()
    return time.perf_counter()

# warm-up
torch.matmul(X, Y)
torch.matmul(X, Y)
torch.matmul(X, Y)

t0 = wait()
for i in range(10):
    torch.matmul(X, Y)
t1 = wait()
cost = (t1 - t0) / 10
print('TFlops:', 4096 * 4096 * 8192 * 2 / cost * 1e-12)

$ python test_arange.py
TFlops: 5.726274617567384
$ cat test_randn.py
#!/usr/bin/env python3
import os, sys
import argparse
import torch
import time

X = torch.randn((4096, 4096), requires_grad=False, dtype=torch.float32).to("mps")
Y = torch.randn((4096, 8192), requires_grad=False, dtype=torch.float32).to("mps")
Z = torch.randn((4096, 8192), requires_grad=False, dtype=torch.float32).to("mps")

def wait():
    torch.mps.synchronize()
    return time.perf_counter()

with torch.no_grad():
    # warm-up
    torch.matmul(X, Y)
    torch.matmul(X, Y)
    torch.matmul(X, Y)
    t0 = wait()
    niter = 30
    for i in range(niter):
        Z = torch.matmul(X, Y)
    t1 = wait()
    cost = (t1 - t0) / niter
    print('TFlops:', 4096 * 4096 * 8192 * 2 / cost * 1e-12)

$ python test_randn.py
TFlops: 5.760410971906142
They should have the same perf. I found you should divide by 10 instead of 30 in the second source code.
@ghostplant Thank you! I have updated the results above. Both are the same.
Data dependence of matmul flops has been documented:
https://www.thonking.ai/p/strangely-matrix-multiplications
This might be the first time I've seen the behavior discussed for Apple silicon though (in the 30-iteration case).
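A quick way to check whether that data dependence shows up on MPS is to run the same timing loop with different input distributions. This is only a sketch, reusing the same warm-up/synchronize approach as the scripts above; the distributions chosen (zeros, arange, randn) are illustrative.

#!/usr/bin/env python3
import time
import torch

def bench(make, m=4096, k=4096, n=8192, niter=30):
    X = make((m, k)).to("mps")
    Y = make((k, n)).to("mps")
    for _ in range(3):               # warm-up, untimed
        torch.matmul(X, Y)
    torch.mps.synchronize()
    t0 = time.perf_counter()
    for _ in range(niter):
        torch.matmul(X, Y)
    torch.mps.synchronize()
    cost = (time.perf_counter() - t0) / niter
    return 2 * m * k * n / cost * 1e-12  # TFlops

print("zeros :", bench(lambda s: torch.zeros(s, dtype=torch.float32)))
print("arange:", bench(lambda s: torch.arange(s[0] * s[1], dtype=torch.float32).view(s)))
print("randn :", bench(lambda s: torch.randn(s, dtype=torch.float32)))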
We'll tackle + track this under @junjihashimoto's ongoing work on task 14 here: https://github.com/orgs/AnswerDotAI/projects/5/
Closing for now, but feel free to post follow-ups here or discuss further on Discord: https://discord.gg/zmJVhXsC7f
PyTorch's matmul is implemented with MPSGraph's matrixMultiplicationWithPrimaryTensor:
https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/matrixmultiplication(primary:secondary:name:)?changes=_1_7&language=objc
https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/mps/operations/LinearAlgebra.mm#L120
It is not OSS.
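Since the MPSGraph kernel itself is closed source, about the closest we can get from Python is confirming which aten op PyTorch dispatches for the matmul. A small sketch with the standard torch profiler; note it records the CPU-side op launches, not the Metal kernel itself.

import torch
from torch.profiler import profile, ProfilerActivity

X = torch.randn(1024, 1024, device="mps")
Y = torch.randn(1024, 1024, device="mps")
torch.mps.synchronize()

# Records the aten ops launched from the CPU side; the underlying
# MPSGraph matrixMultiplication call is not visible here.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    Z = torch.matmul(X, Y)
    torch.mps.synchronize()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))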
Machine: MacBook Air M2
Theoretical GPU TFlops: 3.6 TFlops
Actual TFlops: