ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Add groups to Conv1d #948

Closed · Rifur13 closed this 1 month ago

Rifur13 commented 2 months ago

Proposed changes

Adding groups to 1D convolutions. Resolves #237.
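
For context, here is a minimal sketch of what this enables at the Python level. It is a hypothetical example (not from the PR), assuming MLX's channels-last layout: the input is (N, L, C_in) and the grouped weight is (C_out, K, C_in // groups).

```python
import mlx.core as mx

# Hypothetical usage of the groups argument this PR adds.
x = mx.random.normal(shape=(4, 32, 32))       # (N, L, C_in)
w = mx.random.normal(shape=(32, 5, 32 // 4))  # (C_out, K, C_in // groups)

y = mx.conv1d(x, w, stride=1, padding=2, groups=4)
print(y.shape)  # (4, 32, 32), i.e. (N, L_out, C_out)
```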

Rifur13 commented 2 months ago

Wdyt? This is for CPU only. The GPU code should be very similar to this so I want to get some feedback before I continue.

Main changes:

awni commented 2 months ago

@Rifur13 this looks cool! Do you intend to add the GPU kernel here? Also this will just be for 1D grouped convolutions, correct?

Also, it would be great if you could run some benchmarks:

Rifur13 commented 1 month ago

Yep I intend to add the GPU kernel as well. And yes, this PR will focus on 1D convolutions only.

Benchmarks coming soon!

Rifur13 commented 1 month ago

Performance doesn't look great; it scales worse as the number of groups grows.



| (N, iH, C) | (O, wH, C) | dtype | stride | pads | groups | diff% |
|---|---|---|---|---|---|---|
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 1 | +179.77% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 2 | +59.62% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 4 | +33.96% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 8 | +0.71% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 8 | +15.49% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 16 | -32.81% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 32 | -62.36% |
| (4, 32, 256) | (512, 5, 256) | float32 | 1 | 2 | 2 | +41.59% |
| (4, 32, 256) | (512, 5, 256) | float32 | 1 | 2 | 128 | -88.60% |
| (4, 32, 256) | (512, 5, 256) | float32 | 1 | 2 | 256 | -93.96% |

What we really need is a specialized steel_matmul that splits the inputs into groups and dispatches the kernels in parallel.
It might take me a while to understand all the gemm kernel code, and I'm not sure how much time I'll have, so if someone really needs this they can take up the work.

It would be good to have some working version in the meantime to unblock people (like me).
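
For context, a grouped convolution decomposes into one ordinary convolution (and hence one GEMM) per group: each group convolves its own slice of input channels with its own slice of filters. Below is a minimal op-level sketch of that decomposition, illustrative only and not the kernel code in this PR.

```python
import mlx.core as mx

def grouped_conv1d_reference(x, w, groups, stride=1, padding=0):
    # x: (N, L, C_in), w: (C_out, K, C_in // groups)
    # Slice the input channels and the filters per group, run a regular
    # conv1d on each pair, and concatenate the results along channels.
    ci = x.shape[-1] // groups
    co = w.shape[0] // groups
    outs = []
    for g in range(groups):
        x_g = x[..., g * ci:(g + 1) * ci]
        w_g = w[g * co:(g + 1) * co]
        outs.append(mx.conv1d(x_g, w_g, stride=stride, padding=padding))
    return mx.concatenate(outs, axis=-1)
```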

Rifur13 commented 1 month ago

I’ll take another look actually. If I ignore the split-k specialization, this seems very doable.

awni commented 1 month ago

Just curious, what is the last column measuring? It's a difference from what to what, exactly? CPU -> GPU?

Rifur13 commented 1 month ago

No, it's actually MLX vs PyTorch. They should scale similarly, so I use these numbers to measure performance.
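
For reference, a comparison like this could be timed with a small harness along the following lines. This is a hypothetical sketch, not the benchmark script used for the tables above, and the exact formula behind diff% is not stated in the thread.

```python
import time
import mlx.core as mx
import torch

def bench(fn, iters=100, warmup=5):
    # Simple wall-clock timing with a short warmup.
    for _ in range(warmup):
        fn()
    tic = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - tic) / iters

N, L, C, O, K, groups = 4, 32, 32, 32, 5, 2

x_mx = mx.random.normal(shape=(N, L, C))             # MLX is channels-last
w_mx = mx.random.normal(shape=(O, K, C // groups))
x_pt = torch.randn(N, C, L)                          # PyTorch is channels-first
w_pt = torch.randn(O, C // groups, K)

t_mx = bench(lambda: mx.eval(mx.conv1d(x_mx, w_mx, padding=2, groups=groups)))
t_pt = bench(lambda: torch.nn.functional.conv1d(x_pt, w_pt, padding=2, groups=groups))
print(f"mlx: {t_mx * 1e6:.1f} us, torch: {t_pt * 1e6:.1f} us, ratio: {t_pt / t_mx:.2f}x")
```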

Also, a small update: I'm trying to parallelize the groups for-loop by sending each kernel to a different command buffer, so I will create `groups` streams, `groups` command queues, and so on (one per group). I'm working through some errors right now, but let me know if that makes sense.

awni commented 1 month ago

> Also, a small update: I'm trying to parallelize the groups for-loop by sending each kernel to a different command buffer, so I will create `groups` streams, `groups` command queues, and so on (one per group). I'm working through some errors right now, but let me know if that makes sense.

Actually, I would not do that. That is going to introduce a lot of overhead and subvert how we do job submission for the GPU.

The best approach is to have a single kernel to do all the groups and handle that extra dimension in the thread grid or something like that. But I realize that might be a lot more work.

A less good option that you could try is to use a concurrent command encoder. If you rebase on main, you will get some functionality to make that much easier.

awni commented 1 month ago

Here is a very simple example of how we do that in concatenate now: https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/primitives.cpp#L556-L564

Rifur13 commented 1 month ago

Thanks for guiding me in the right direction! The numbers look very good now and it's ready for review.

| (N, iH, C) | (O, wH, C) | dtype | stride | pads | groups | diff% |
|---|---|---|---|---|---|---|
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 1 | +189.00% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 2 | +176.95% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 4 | +185.48% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 8 | +183.16% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 8 | +181.10% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 16 | +145.79% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 32 | +102.98% |
| (4, 32, 256) | (512, 5, 256) | float32 | 1 | 2 | 2 | +110.27% |
| (4, 32, 256) | (512, 5, 256) | float32 | 1 | 2 | 128 | +50.08% |
| (4, 32, 256) | (512, 5, 256) | float32 | 1 | 2 | 256 | +28.68% |

awni commented 1 month ago

Very nice result!! Will review soon.

awni commented 1 month ago

@Rifur13 did you do any benchmarking for the CPU version? It's not a super high priority to make it fast, but we also don't want to make it worse than it was.

Rifur13 commented 1 month ago

There’s an extra copy, so in theory it should be slower, but I didn’t see a noticeable difference in my tests. The code path for groups = 1 convolutions is unchanged, so its performance is identical to before.

I refactored the code to remove this copy and I think it also looks a lot cleaner. It’s easier to follow the code paths for grouped vs. ungrouped convolutions.

Rifur13 commented 1 month ago

Any notes or concerns?

awni commented 1 month ago

Not really on my side. I think we can merge this, results are very nice and code looks good!! @jagrit06 or @angeloskath do either of you care to take a quick look?

Rifur13 commented 1 month ago

@jagrit06 Good catch, I’ll add a comment for the jvp.

The existing gemm kernel uses the 3rd grid dim as the batch size. Are you suggesting we repurpose batches as groups? Readability would take a hit imo.

I think it’s possible if we set:

params->batch_ndim = 1
params->batch_stride_a = K
params->batch_stride_b = N * K
params->batch_stride_d = N

jagrit06 commented 1 month ago

> @jagrit06 Good catch, I’ll add a comment for the jvp.
>
> The existing gemm kernel uses the 3rd grid dim as the batch size. Are you suggesting we repurpose batches as groups? Readability would take a hit imo.
>
> I think it’s possible if we set:
>
> params->batch_ndim = 1
> params->batch_stride_a = K
> params->batch_stride_b = N * K
> params->batch_stride_d = N

Exactly as you suggest, we can set the batch strides to let the tid.z handle that. I don't think the readability hit is bad enough to justify the overhead of compiling and packing a whole new set of gemm kernels that are basically the same as the ones we already have.

Thanks!
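
To make the stride trick concrete, here is a small NumPy illustration (not the Metal kernel) of how those three batch strides let the third grid dimension walk over the groups while every operand stays in a single shared buffer. The sizes are made up; per group the GEMM is (M x K) @ (K x N).

```python
import numpy as np

M, K, N, groups = 6, 4, 5, 3

# The per-group operands sit side by side in shared buffers:
#   A packs groups along its columns: group g starts K elements into a row
#     -> batch_stride_a = K
#   B stores one contiguous K x N block per group
#     -> batch_stride_b = N * K
#   D packs groups along its columns: group g starts N elements into a row
#     -> batch_stride_d = N
A = np.random.randn(M, groups * K)
B = np.random.randn(groups, K, N)
D = np.zeros((M, groups * N))

for g in range(groups):  # in the kernel this index is tid.z
    A_g = A[:, g * K:(g + 1) * K]
    D[:, g * N:(g + 1) * N] = A_g @ B[g]
```

The leading dimensions of A and D still span all groups; only the per-group base offsets advance, which is exactly what the batch strides encode, so the existing batched gemm kernels can be reused as-is.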

Rifur13 commented 1 month ago

Done! Thanks for all the suggestions.

Ready for a final review.

awni commented 1 month ago

@Rifur13 the conv 1d test failed. Do you mind checking it?

Rifur13 commented 1 month ago

Tests should pass now. Tricky one..

awni commented 1 month ago

It's failing Metal validation. You should be able to reproduce locally with:

METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python ..

Rifur13 commented 1 month ago

Fixed! It's probably a good idea to add these test options to the docs somewhere.

awni commented 1 month ago

@Rifur13 sorry for the delay in merging; it caused a conflict. If you can fix it we can merge asap. I also don't mind fixing the conflict sometime tomorrow.

Rifur13 commented 1 month ago

Rebased. It should be fixed now.