Open valfrom opened 1 week ago
Indeed, we are aware that there are performance cliffs in our convolutions, see e.g. #1313
Thanks for the benchmark though! We will make sure to include the 3D stuff in our optimizations.
FYI, there are a few problems with your benchmark:
Here is an improved version:
import time

import mlx.core as mx
import torch


def mlx_sample():
    x = mx.random.normal([8, 16, 128, 128, 32], dtype=mx.float32)
    weight = mx.random.normal([4, 1, 1, 1, 32], dtype=mx.float32)
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]

    # Warmup
    for _ in range(5):
        out = mx.conv_general(x, weight, stride, padding, dilation, stream=mx.gpu)
        # MLX is lazy: eval forces the convolution to actually run
        mx.eval(out)

    start = time.time()
    n = 10
    for _ in range(n):
        out = mx.conv_general(x, weight, stride, padding, dilation, stream=mx.gpu)
        mx.eval(out)
    print(f'MLX time: {(time.time() - start) * 1000 / n:0.2f}ms')


def torch_sample():
    x = torch.randn([8, 32, 16, 128, 128], dtype=torch.float32, device='mps')
    weight = torch.randn([4, 32, 1, 1, 1], dtype=torch.float32, device='mps')
    bias = torch.randn([4], dtype=torch.float32, device='mps')
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]

    # Warmup
    for _ in range(5):
        out = torch.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0, 0], 1)
    # Wait for all queued MPS work to finish before starting the timer
    torch.mps.synchronize()

    start = time.time()
    n = 10
    for _ in range(n):
        out = torch.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0, 0], 1)
    torch.mps.synchronize()
    print(f'MPS time: {(time.time() - start) * 1000 / n:0.2f}ms')
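The warmup-then-time pattern shared by both functions above can be factored into a small helper. This is only a minimal sketch and not part of MLX or PyTorch: the names benchmark, fn, and sync are hypothetical. The assumption is that fn launches the work and sync blocks until it has finished, so the timer measures completed work rather than just the launch.

```python
import time


def benchmark(fn, sync=lambda result: None, warmup=5, iters=10):
    """Return the mean wall-clock time of fn() in milliseconds.

    fn   -- hypothetical callable that launches the work and returns its result
    sync -- hypothetical callable that blocks until that result is computed
    """
    # Warmup runs absorb one-off costs (kernel compilation, allocations).
    for _ in range(warmup):
        sync(fn())

    start = time.perf_counter()
    for _ in range(iters):
        sync(fn())
    return (time.perf_counter() - start) * 1000 / iters


# Stand-in workload so the sketch runs without a GPU framework installed.
calls = []
mean_ms = benchmark(lambda: calls.append(sum(range(1000))))
print(f"mean time: {mean_ms:0.4f}ms over {len(calls) - 5} timed runs")
```

With MLX this might be called with sync=mx.eval, and with PyTorch on MPS with sync=lambda _: torch.mps.synchronize().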
Finally, for a 1x1x1 convolution, I'd encourage you to use a Linear layer / matmul. That will be way faster for now. I ran the revised benchmark on an M2 Ultra:
MLX time: 5.26ms
MPS time: 1.38ms
MLX with matmul instead:
MLX time: 0.90ms
That looks like this:
x = mx.random.normal([8, 16 * 128 * 128, 32], dtype=mx.float32)
weight = mx.random.normal([4, 32], dtype=mx.float32)
out = x @ weight.T
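To see why the matmul gives the same result, here is a small check written with NumPy instead of MLX so that it runs anywhere. It is a sketch under stated assumptions: the shapes are shrunk for illustration, the layout is channels-last as in the benchmark, and the einsum reference is a stand-in for the convolution rather than a call to conv_general.

```python
import numpy as np

rng = np.random.default_rng(0)

# Small channels-last input (N, D, H, W, C_in) and a 1x1x1 kernel
# (C_out, 1, 1, 1, C_in), mirroring the layout in the benchmark.
n, d, h, w, c_in, c_out = 2, 3, 4, 5, 6, 4
x = rng.standard_normal((n, d, h, w, c_in)).astype(np.float32)
weight = rng.standard_normal((c_out, 1, 1, 1, c_in)).astype(np.float32)

# Reference: a 1x1x1 convolution is a per-voxel dot product over channels.
conv_out = np.einsum("ndhwc,oc->ndhwo", x, weight[:, 0, 0, 0, :])

# Matmul form: flatten the spatial dimensions, multiply, and restore them.
flat = x.reshape(n, d * h * w, c_in) @ weight.reshape(c_out, c_in).T
matmul_out = flat.reshape(n, d, h, w, c_out)

print(np.allclose(conv_out, matmul_out, atol=1e-5))
```

The reshape at the end is only needed if the caller wants the spatial dimensions back; the snippet in the comment above leaves them flattened.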
Thanks a lot
After some digging, the main issue is the channel size of the input and the first dimension of the weight. PyTorch implements the convolution in native code using MPSGraphConvolution3DOpDescriptor.
P.S. A convolution with the following parameters can't be computed at all. It fails with the error "libc++abi: terminating due to uncaught exception of type std::runtime_error: [METAL] Command buffer execution failed: Internal Error (0000000e:Internal Error)". I guess it is a kernel timeout or something similar.
x = mx.random.normal([8, 142, 16, 64, 64], dtype=mx.float32)
weight = mx.random.normal([22, 142, 7, 7, 7], dtype=mx.float32)
bias = mx.random.normal([22], dtype=mx.float32)
stride = [1, 1, 1]
padding = [3, 3, 3]
dilation = [1, 1, 1]
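For a sense of scale, the arithmetic cost of this configuration can be estimated in plain Python. This is a rough sketch that assumes the shapes are read in PyTorch's layout, (N, C_in, D, H, W) for the input and (C_out, C_in, kD, kH, kW) for the weight, as they are written above; conv_out_size is a hypothetical helper name, and the count ignores the bias.

```python
def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Standard output-length formula for one convolution dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Shapes from the report, read as (N, C_in, D, H, W) and (C_out, C_in, kD, kH, kW).
batch, c_in, spatial = 8, 142, (16, 64, 64)
c_out, kernel = 22, (7, 7, 7)
stride, padding, dilation = (1, 1, 1), (3, 3, 3), (1, 1, 1)

out_spatial = tuple(
    conv_out_size(s, k, st, p, dl)
    for s, k, st, p, dl in zip(spatial, kernel, stride, padding, dilation)
)

out_elements = batch * c_out
kernel_volume = 1
for o, k in zip(out_spatial, kernel):
    out_elements *= o
    kernel_volume *= k

# Each output element needs c_in * kernel_volume multiply-accumulates.
macs = out_elements * c_in * kernel_volume
print(out_spatial, f"{macs:.3e} multiply-accumulates")
```

Under these assumptions the call needs roughly 5.6e11 multiply-accumulates, about 2,000 times the arithmetic of the 1x1x1 benchmark. That is consistent with the timeout guess but does not confirm it.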
Is there a way to get a Metal buffer from an array? Then I'll probably be able to use the native function to compute the 3D convolution.
You have a couple of options:
memoryview(a)
a.__dlpack__()
More info in the docs on converting to other frameworks.
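Both options are general Python protocols, so a consumer looks much the same whichever library produced the array. The sketch below uses a NumPy array as a stand-in producer purely so that it runs without MLX installed; the suggestion in the thread is to apply the same calls to an mx.array. Whether either route gives access to the underlying Metal buffer is not established in this thread.

```python
import numpy as np

# Stand-in producer: a NumPy array is used here only so the sketch runs
# anywhere; the thread's suggestion is to do the same with an mx.array.
a = np.arange(6, dtype=np.float32).reshape(2, 3)

# Option 1: the Python buffer protocol exposes shape, format, and raw bytes.
view = memoryview(a)
print(view.shape, view.format, view.nbytes)

# Option 2: DLPack. A consumer's from_dlpack calls a.__dlpack__() internally.
b = np.from_dlpack(a)
print(b.shape, b.dtype)
```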
Describe the bug: The mlx.core.conv_general method is significantly slower than its PyTorch counterpart, anywhere from 10x to 150x slower.
To Reproduce: Run the attached code.
Output:
MLX time: 20.65ms
MPS time: 0.93ms
Expected behavior: At least the same speed as in PyTorch.