awni closed this 4 days ago.
As far as I know, it's pretty much always a performance bug to have `groupdim.* >= grid_dim.*`, i.e. a threadgroup dimension that is at least as large as the grid along the same axis.
For example, the following benchmark is ~3x faster with this simple fix. On an M1 Max:

- Pre: 10.227 ms
- Post: 3.306 ms
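```python
import timeit
import mlx.core as mx

x = mx.random.uniform(shape=(1, 1_000_000, 1))  # input: (N, L, C_in)
w = mx.random.uniform(shape=(1, 2, 1))          # weight: (C_out, K, C_in)

def fun():
    return [mx.conv1d(x, w) for _ in range(5)]

# The PR's own timeit helper isn't shown; stdlib timeit stands in for it here.
# MLX is lazy, so mx.eval is needed to actually run the convolutions.
mx.eval(fun())  # warm-up
n = 100
print(f"{1e3 * timeit.timeit(lambda: mx.eval(fun()), number=n) / n:.3f} ms")
```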
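The diff itself isn't reproduced here. As a rough sketch of the idea only, assuming the fix amounts to clamping each threadgroup dimension so it never exceeds the grid along the same axis, it could look like the following. `clamp_group_dims`, the `(32, 8, 4)` default group shape, and the mapping of the conv1d output shape `(1, 999_999, 1)` onto grid axes are all hypothetical, chosen for illustration and not taken from MLX.

```python
# Hypothetical helper, not MLX's actual code: clamp each threadgroup
# dimension so it never exceeds the grid size along that axis.
def clamp_group_dims(group_dims, grid_dims):
    return tuple(min(g, n) for g, n in zip(group_dims, grid_dims))

# Hypothetical numbers: a grid shaped like the benchmark's conv1d output
# (1, 999_999, 1) and a made-up default threadgroup shape of (32, 8, 4).
grid = (1, 999_999, 1)
group = (32, 8, 4)
print(clamp_group_dims(group, grid))  # (1, 8, 1)
```

Without the clamp, the x and z axes of each threadgroup would be 32 and 4 wide while the grid has only one element along each, so most of those threads would have nothing to do.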