ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Add groups to 2-D convolutions #1129

Closed: Rifur13 closed this pull request 1 week ago

Rifur13 commented 2 weeks ago

Proposed changes

Added `groups` support to 2-D convolutions for some kernel specializations.
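For readers unfamiliar with grouped convolutions, here is a minimal pure-Python reference sketch of what `groups` means for a channels-last 2-D convolution. It is not MLX's kernel. It assumes the standard grouped-convolution convention that the weight's last axis holds `C // groups` input channels, so output channel `o` only reads the input channels of its own group. The function name `conv2d_grouped` is hypothetical.

```python
def conv2d_grouped(x, w, stride=(1, 1), padding=(0, 0), groups=1):
    """Naive grouped 2-D convolution (cross-correlation), for illustration only.

    x: nested lists shaped (N, H, W, C)            -- channels-last input
    w: nested lists shaped (O, kH, kW, C // groups) -- assumed weight layout
    Returns nested lists shaped (N, H_out, W_out, O).
    """
    N, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    O, kH, kW, Cg = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    assert C % groups == 0 and O % groups == 0, "channels must be divisible by groups"
    assert Cg == C // groups, "each weight sees only C // groups input channels"
    sh, sw = stride
    ph, pw = padding
    H_out = (H + 2 * ph - kH) // sh + 1
    W_out = (W + 2 * pw - kW) // sw + 1
    O_per_group = O // groups
    out = [[[[0.0] * O for _ in range(W_out)] for _ in range(H_out)] for _ in range(N)]
    for n in range(N):
        for oh in range(H_out):
            for ow in range(W_out):
                for o in range(O):
                    g = o // O_per_group  # group this output channel belongs to
                    c0 = g * Cg           # first input channel of that group
                    acc = 0.0
                    for i in range(kH):
                        ih = oh * sh - ph + i
                        if not 0 <= ih < H:
                            continue      # outside the input: zero padding
                        for j in range(kW):
                            iw = ow * sw - pw + j
                            if not 0 <= iw < W:
                                continue
                            for c in range(Cg):
                                acc += x[n][ih][iw][c0 + c] * w[o][i][j][c]
                    out[n][oh][ow][o] = acc
    return out
```

With `groups=2`, two input channels and a 1x1 kernel, each output channel sees only its own input channel: `conv2d_grouped([[[[3.0, 5.0]]]], [[[[2.0]]], [[[10.0]]]], groups=2)` gives `[[[[6.0, 50.0]]]]`.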

Also fixed 1-D grouped convolutions with different kernel strides, and added more tests.
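The PR does not describe the 1-D bug itself. As a hedged illustration only, this sketch shows how a strided grouped 1-D convolution (no padding) indexes its input: the stride moves the start of each window, while the kernel taps inside a window stay contiguous. The name `conv1d_grouped` and the `(O, K, C // groups)` weight layout are assumptions for the example.

```python
def conv1d_grouped(x, w, stride=1, groups=1):
    """Naive grouped 1-D convolution without padding, for illustration only.

    x: nested lists shaped (N, L, C)            -- channels-last input
    w: nested lists shaped (O, K, C // groups)  -- assumed weight layout
    Returns nested lists shaped (N, L_out, O).
    """
    N, L, C = len(x), len(x[0]), len(x[0][0])
    O, K, Cg = len(w), len(w[0]), len(w[0][0])
    assert C % groups == 0 and O % groups == 0 and Cg == C // groups
    L_out = (L - K) // stride + 1
    O_per_group = O // groups
    out = [[[0.0] * O for _ in range(L_out)] for _ in range(N)]
    for n in range(N):
        for t in range(L_out):
            start = t * stride  # stride scales the window start, not the tap offsets
            for o in range(O):
                c0 = (o // O_per_group) * Cg  # first input channel of o's group
                out[n][t][o] = sum(
                    x[n][start + k][c0 + c] * w[o][k][c]
                    for k in range(K)
                    for c in range(Cg)
                )
    return out
```

For example, with input channels `[1, 2, 3, 4, 5]` and `[10, 20, 30, 40, 50]`, kernels `[1, 1]` and `[1, -1]`, `groups=2` and `stride=2`, the output positions are `[3.0, -10.0]` and `[7.0, -10.0]`.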

This can close out #100.

Performance looks pretty good:

| Input (N, H, W, C) | Weight (O, kH, kW, C) | dtype | stride | pads | groups | diff% |
|---|---|---|---|---|---|---|
| (4, 64, 64, 256) | (256, 5, 5, 256) | float32 | (1, 1) | (2, 2) | 1 | +25.78% |
| (4, 64, 64, 256) | (256, 5, 5, 256) | float32 | (1, 1) | (2, 2) | 2 | +36.72% |
| (4, 64, 64, 256) | (256, 5, 5, 256) | float32 | (1, 1) | (2, 2) | 16 | -23.80% |
| (4, 64, 64, 256) | (256, 5, 5, 256) | float32 | (1, 1) | (2, 2) | 64 | +92.72% |
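The table reports relative timings, not absolute work. As a rough guide, this sketch counts the arithmetic per benchmark row, assuming the standard convention that each output channel reads only `C // groups` input channels (the table header prints the full `C`, so how the benchmark shaped the weight for `groups > 1` is an assumption here). It counts one multiply-add as 2 FLOPs and ignores bias; `conv2d_flops` is a hypothetical helper.

```python
def conv2d_flops(N, H, W, C, O, kH, kW, stride, pads, groups):
    """Approximate FLOPs of a grouped 2-D convolution (2 per multiply-add)."""
    sh, sw = stride
    ph, pw = pads
    H_out = (H + 2 * ph - kH) // sh + 1
    W_out = (W + 2 * pw - kW) // sw + 1
    # Each output element sums over kH * kW * (C // groups) products.
    macs = N * H_out * W_out * O * kH * kW * (C // groups)
    return 2 * macs


# Rows of the benchmark table: only `groups` changes.
for g in (1, 2, 16, 64):
    gflops = conv2d_flops(4, 64, 64, 256, 256, 5, 5, (1, 1), (2, 2), g) / 1e9
    print(f"groups={g:>2}: {gflops:.2f} GFLOP")
# groups= 1: 53.69 GFLOP
# groups= 2: 26.84 GFLOP
# groups=16: 3.36 GFLOP
# groups=64: 0.84 GFLOP
```

Under this assumption, the work per call shrinks linearly with `groups`, so the rows are not directly comparable in absolute time; the `diff%` column is only meaningful against its own baseline.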
