Rifur13 closed this 1 month ago
Wdyt? This is for CPU only. The GPU code should be very similar to this so I want to get some feedback before I continue.
Main changes:
@Rifur13 this looks cool! Do you intend to add the GPU kernel here? Also this will just be for 1D grouped convolutions, correct?
Also it would be great if you could run some benchmarks:
Yep I intend to add the GPU kernel as well. And yes, this PR will focus on 1D convolutions only.
Benchmarks coming soon!
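For context on what the op computes: grouped 1D convolution splits the channel dimension into `groups` independent slices, each convolved with its own set of filters. Here is a minimal pure-Python sketch of the semantics (channels-last layout as in the benchmark tables below; unit stride and no padding are simplifying assumptions, and `conv1d_grouped` is a hypothetical reference implementation, not MLX code):

```python
def conv1d_grouped(x, w, groups):
    """Naive grouped 1D convolution, channels-last.

    x: input of shape (iH, C) as nested lists
    w: weights of shape (O, wH, C // groups)
    Returns output of shape (iH - wH + 1, O).
    Unit stride and no padding are simplifying assumptions.
    """
    iH, C = len(x), len(x[0])
    O, wH = len(w), len(w[0])
    Cg = C // groups  # input channels per group
    Og = O // groups  # output channels per group
    oH = iH - wH + 1
    out = [[0.0] * O for _ in range(oH)]
    for o in range(O):
        g = o // Og  # group this output channel belongs to
        for t in range(oH):
            acc = 0.0
            for k in range(wH):
                for c in range(Cg):
                    acc += x[t + k][g * Cg + c] * w[o][k][c]
            out[t][o] = acc
    return out
```

Each output channel `o` reads only the `C // groups` input channels of its own group, which is why the weight tensor carries `C // groups` channels rather than `C`.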
Performance doesn’t look great; it scales worse with more groups.
| (N, iH, C) | (O, wH, C) | dtype | stride | pads | groups | diff% |
|---|---|---|---|---|---|---|
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 1 | +179.77% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 2 | +59.62% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 4 | +33.96% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 8 | +0.71% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 8 | +15.49% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 16 | -32.81% |
| (4, 32, 32) | (32, 5, 32) | float32 | 1 | 2 | 32 | -62.36% |
| (4, 32, 256) | (512, 5, 256) | float32 | 1 | 2 | 2 | +41.59% |
| (4, 32, 256) | (512, 5, 256) | float32 | 1 | 2 | 128 | -88.60% |
| (4, 32, 256) | (512, 5, 256) | float32 | 1 | 2 | 256 | -93.96% |
What we really need is a specialized `steel_matmul` that splits up the inputs into groups and dispatches the kernels in parallel.
It might take me a while to understand all the gemm kernel code. I’m not sure how much time I’ll have, so if someone really needs this they can take up the work.
It would be good to have some working version in the meantime to unblock people (like me).
I’ll take another look actually. If I ignore the split k specialization this seems very doable.
Just curious, what is the last column measuring? It's a difference from what to what exactly? CPU -> GPU?
No, it's actually mlx vs pytorch. They should scale similarly, so I use these numbers to measure performance.
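For reference, here is a hedged sketch of how such a diff% could be computed from wall-clock timings (`bench` and `diff_pct` are hypothetical helpers, and the sign convention, positive meaning mlx is faster, is my reading of the tables):

```python
import time

def bench(fn, warmup=5, iters=50):
    """Average wall-clock seconds per call, after a few warmup runs."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

def diff_pct(t_mlx, t_torch):
    # Positive => mlx faster than pytorch on this case; negative => slower.
    # (Assumed interpretation of the diff% column above.)
    return (t_torch - t_mlx) / t_mlx * 100.0
```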
Also small update: I'm trying to parallelize the `groups` for loop by sending each kernel to a different command buffer. So I will create `groups` streams, `groups` command queues, etc... Working through some errors right now, but lmk if that makes sense.
Actually, I would not do that. That is going to introduce a lot of overhead and subvert how we do job submission for the GPU.
The best approach is to have a single kernel to do all the groups and handle that extra dimension in the thread grid or something like that. But I realize that might be a lot more work.
A less good option that you could try is to use a concurrent command encoder. If you rebase on main, you will get some functionality to make that much easier.
Here is a very simple example of how we do that in concatenate now: https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/primitives.cpp#L556-L564
Thanks for guiding me in the right direction! Numbers look very good now and it’s ready for review.
| N | iH | C | O | wH | C | dtype | stride | pads | groups | diff% |
|---|---|---|---|---|---|---|---|---|---|---|
| 4 | 32 | 32 | 32 | 5 | 32 | float32 | 1 | 2 | 1 | +189.00% |
| 4 | 32 | 32 | 32 | 5 | 32 | float32 | 1 | 2 | 2 | +176.95% |
| 4 | 32 | 32 | 32 | 5 | 32 | float32 | 1 | 2 | 4 | +185.48% |
| 4 | 32 | 32 | 32 | 5 | 32 | float32 | 1 | 2 | 8 | +183.16% |
| 4 | 32 | 32 | 32 | 5 | 32 | float32 | 1 | 2 | 8 | +181.10% |
| 4 | 32 | 32 | 32 | 5 | 32 | float32 | 1 | 2 | 16 | +145.79% |
| 4 | 32 | 32 | 32 | 5 | 32 | float32 | 1 | 2 | 32 | +102.98% |
| 4 | 32 | 256 | 512 | 5 | 256 | float32 | 1 | 2 | 2 | +110.27% |
| 4 | 32 | 256 | 512 | 5 | 256 | float32 | 1 | 2 | 128 | +50.08% |
| 4 | 32 | 256 | 512 | 5 | 256 | float32 | 1 | 2 | 256 | +28.68% |
Very nice result!! Will review soon.
@Rifur13 did you do any benchmarking for the CPU version? It's not a super high priority to make it fast, but we also don't want to make it worse than it was.
There’s an extra copy, so in theory it should be slower, but I didn’t see a noticeable difference in my tests. The code path for convolutions with groups = 1 is unchanged, so performance there is identical to before.
I refactored the code to remove this copy and I think it also looks a lot cleaner. It’s easier to understand the code for groups vs without groups.
Any notes or concerns?
Not really on my side. I think we can merge this, results are very nice and code looks good!! @jagrit06 or @angeloskath do either of you care to take a quick look?
@jagrit06 Good catch, I’ll add a comment for the jvp.
The existing gemm kernel uses the 3rd grid dim as the batch size. Are you suggesting we repurpose batches as groups? Readability would take a hit imo.
I think it’s possible if we set:
params->batch_ndim = 1
params->batch_stride_a = K
params->batch_stride_b = N * K
params->batch_stride_d = N
Exactly as you suggest, we can set the batch strides to let the `tid.z` handle that.
I don't think this is a bad enough readability hit to justify the overhead of compiling and packing a whole new set of gemm kernels that are basically the same as the ones we already have.
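To illustrate why those stride settings are enough, here is a toy pure-Python model of the pointer arithmetic: each group's submatrix is reached purely by adding a per-batch offset into a flat buffer, so the existing batched gemm path can serve grouped matmul unchanged. The flat row-major layouts and the `gemm_at` helper are illustrative assumptions, not the actual steel kernel code:

```python
def gemm_at(A, lda, a_off, B, ldb, b_off, D, ldd, d_off, M, N, K):
    """Plain GEMM over flat row-major buffers with explicit offsets and leading dims."""
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k in range(K):
                acc += A[a_off + i * lda + k] * B[b_off + k * ldb + j]
            D[d_off + i * ldd + j] = acc

def grouped_gemm(A, B, groups, M, N, K):
    """Grouped matmul via batch offsets mirroring the settings above:
    a_off = g * K, b_off = g * N * K, d_off = g * N.

    Assumed layouts: A is flat row-major (M, groups*K) (e.g. an im2col matrix),
    B is `groups` contiguous (K, N) weight blocks, D is flat row-major (M, groups*N).
    """
    D = [0.0] * (M * groups * N)
    for g in range(groups):  # on the GPU this loop would be the batch grid dim
        gemm_at(A, groups * K, g * K,
                B, N, g * N * K,
                D, groups * N, g * N,
                M, N, K)
    return D
```

In this model the `for g` loop plays the role of the third grid dimension (`tid.z`), which is exactly the batch dimension the existing kernel already iterates over.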
Thanks!
Done! Thanks for all the suggestions.
Ready for a final review.
@Rifur13 the conv 1d test failed. Do you mind checking it?
Tests should pass now. Tricky one..
It's failing metal validation. You should be able to reproduce locally with:
METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python ..
Fixed! Probably a good idea to add these test options in the docs somewhere
@Rifur13 sorry, the delay in merging this caused a conflict. If you can fix it we can merge asap. I also don't mind fixing the conflict sometime tomorrow.
Rebased. Should be fixed now
Proposed changes
Adding groups to 1D convolutions. Resolves #237.
Checklist
Put an `x` in the boxes that apply.

- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes