ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Fix dispatch threads for a few kernels #1594

Closed (awni closed 4 days ago)

awni commented 4 days ago

As far as I know, it's pretty much always a performance bug to have a group_dim.* >= grid_dim.*

For example, the following benchmark is ~3x faster with this simple fix.

On M1 Max:

Pre: 10.227 ms
Post: 3.306 ms

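import mlx.core as mx

# Input: batch 1, length 1,000,000, 1 channel (N, L, C_in)
x = mx.random.uniform(shape=(1, 1_000_000, 1))
# Weight: 1 output channel, kernel size 2, 1 input channel (C_out, K, C_in)
w = mx.random.uniform(shape=(1, 2, 1))

def fun():
    # Five independent conv1d calls per timed iteration
    return [mx.conv1d(x, w) for _ in range(5)]

# timeit is the author's timing helper; it is not defined in this snippet
timeit(fun)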
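
The `timeit` helper used in the benchmark is not shown in the PR. A minimal stand-in might look like the sketch below. It is an assumption, not the author's code: the `evaluate`, `warmup`, and `iters` parameters are hypothetical names added here for illustration.

```python
import time


def timeit(fn, evaluate=lambda out: out, warmup=5, iters=100):
    """Print and return the mean wall-clock time of fn() in milliseconds.

    `evaluate` is a hypothetical hook that forces the result of fn() to be
    computed; the default simply passes the result through.
    """
    # Warm-up runs are excluded from the measurement.
    for _ in range(warmup):
        evaluate(fn())

    tic = time.perf_counter()
    for _ in range(iters):
        evaluate(fn())
    toc = time.perf_counter()

    ms = 1e3 * (toc - tic) / iters
    print(f"{ms:.3f} ms")
    return ms
```

Because MLX evaluates lazily, calling `fun()` on its own only builds the computation graph. With this sketch the benchmark would be run as `timeit(fun, evaluate=mx.eval)`, so that the five `conv1d` outputs are actually computed inside the timed loop.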