ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[Performance] mlx.core.conv_general is really slow #1409

Open valfrom opened 1 week ago

valfrom commented 1 week ago

Describe the bug
mlx.core.conv_general is significantly slower than its PyTorch analog: anywhere from 10x to 150x slower, depending on the shapes.

To Reproduce
Run the code snippet below.

import mlx.core as mx
import time
import torch

def mlx_sample():
    x = mx.random.normal([8, 16, 128, 128, 32], dtype=mx.float32)
    weight = mx.random.normal([4, 1, 1, 1, 32], dtype=mx.float32)
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]
    start = time.time()
    n = 10
    for _ in range(n):
        out = mx.conv_general(x, weight, stride, padding, dilation, stream=mx.gpu)
        mx.eval(out)

    print(f'MLX time: {(time.time() - start) * 1000 / n:0.2f}ms')

def torch_sample():
    x = torch.randn([8, 32, 16, 128, 128], dtype=torch.float32, device='mps')
    weight = torch.randn([4, 32, 1, 1, 1], dtype=torch.float32, device='mps')
    bias = torch.randn([4], dtype=torch.float32, device='mps')
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]
    start = time.time()
    n = 10
    for _ in range(n):
        out = torch.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0, 0], 1)
        out.max()

    print(f'MPS time: {(time.time() - start) * 1000 / n:0.2f}ms')

def main():
    mlx_sample()
    torch_sample()

if __name__ == '__main__':
    main()

Output:
MLX time: 20.65ms
MPS time: 0.93ms

Expected behavior
At least comparable speed to PyTorch.


awni commented 1 week ago

Indeed, we are aware that there are performance cliffs in our convolutions, see e.g. #1313

Thanks for the benchmark though! We will make sure to include the 3D stuff in our optimizations.

FYI, there are a few problems with your benchmark: there is no warmup pass, and on the PyTorch side `out.max()` does not force the MPS stream to finish, so you aren't timing the full computation.

Here is an improved version:

def mlx_sample():
    x = mx.random.normal([8, 16, 128, 128, 32], dtype=mx.float32)
    weight = mx.random.normal([4, 1, 1, 1, 32], dtype=mx.float32)
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]

    # Warmup
    for _ in range(5):
        out = mx.conv_general(x, weight, stride, padding, dilation, stream=mx.gpu)
        mx.eval(out)

    start = time.time()
    n = 10
    for _ in range(n):
        out = mx.conv_general(x, weight, stride, padding, dilation, stream=mx.gpu)
        mx.eval(out)
    print(f'MLX time: {(time.time() - start) * 1000 / n:0.2f}ms')

def torch_sample():
    x = torch.randn([8, 32, 16, 128, 128], dtype=torch.float32, device='mps')
    weight = torch.randn([4, 32, 1, 1, 1], dtype=torch.float32, device='mps')
    bias = torch.randn([4], dtype=torch.float32, device='mps')
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]
    # Warmup
    for _ in range(5):
        out = torch.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0, 0], 1)
        torch.mps.synchronize()

    start = time.time()
    n = 10
    for _ in range(n):
        out = torch.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0, 0], 1)
        torch.mps.synchronize()

    print(f'MPS time: {(time.time() - start) * 1000 / n:0.2f}ms')

Finally, for a 1x1x1 convolution I'd encourage you to use a Linear layer / matmul; that will be way faster for now.

I ran the revised benchmark on an M2 Ultra:

MLX time: 5.26ms
MPS time: 1.38ms

MLX with matmul instead:

MLX time: 0.90ms

That looks like this:

    x = mx.random.normal([8, 16 * 128 * 128, 32], dtype=mx.float32)
    weight = mx.random.normal([4, 32], dtype=mx.float32)
    out = x @ weight.T
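
For reference, a fuller sketch of that replacement might look like the following, assuming the channels-last layout that mx.conv_general uses in the benchmark above (conv111_as_matmul is a hypothetical helper name, not part of MLX; the same reshape/matmul sequence works identically on mx.array and np.ndarray):

```python
import numpy as np

def conv111_as_matmul(x, weight):
    # x: [N, D, H, W, C_in] (channels-last, as mx.conv_general expects)
    # weight: [C_out, 1, 1, 1, C_in]
    n, d, h, w, c_in = x.shape
    c_out = weight.shape[0]
    wm = weight.reshape(c_out, c_in)           # drop the 1x1x1 spatial dims
    out = x.reshape(n, d * h * w, c_in) @ wm.T  # one big matmul
    return out.reshape(n, d, h, w, c_out)       # back to [N, D, H, W, C_out]
```

The final reshape restores the spatial dimensions, so the result matches what the 1x1x1 convolution would have produced.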

valfrom commented 1 week ago

Thanks a lot

valfrom commented 5 days ago

After some digging: the main issue is the channel size of the input and the first dimension of the weight. PyTorch implements this convolution in native code using MPSGraphConvolution3DOpDescriptor.

P.S. A convolution with the following parameters can't be computed at all; it fails with the error "libc++abi: terminating due to uncaught exception of type std::runtime_error: [METAL] Command buffer execution failed: Internal Error (0000000e:Internal Error)". I guess it's a kernel timeout or something like that.

    x = mx.random.normal([8, 142, 16, 64, 64], dtype=mx.float32)
    weight = mx.random.normal([22, 142, 7, 7, 7], dtype=mx.float32)
    bias = mx.random.normal([22], dtype=mx.float32)
    stride = [1, 1, 1]
    padding = [3, 3, 3]
    dilation = [1, 1, 1]

Is there a way to get a Metal buffer from an array? Then I'd probably be able to use the native function to compute the 3D convolution.

awni commented 4 days ago

You have a couple of options:

More info in the docs on converting to other frameworks.