ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Optimization Plans for Conv2D CPU Execution #1130

Closed. a-turker1 closed this issue 2 months ago.

a-turker1 commented 6 months ago

Hello, I wonder if there are any future plans to optimize Conv2D CPU execution. I guess currently MLX uses a naive implementation?

awni commented 5 months ago

We don't have an immediate plan for this. I'm sure there is potential to make CPU convolutions faster. I'm curious though, why not use the GPU instead?

briancpark commented 5 months ago

Has there been any attempt or discussion about porting the BNNS conv into MLX? It's listed as a TODO. I've looked into it personally, but I'm noticing some mismatches between the BNNS API and MLX. For one, the layout preferences differ: I believe BNNS prefers NCHW and OIHW, instead of the NHWC and OHWI used in MLX. I assume the latter were chosen because the Metal variant was the priority, and those formats have better performance properties on the GPU.
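Concretely, the interop would need a copy into and out of the BNNS layouts. A rough numpy sketch of just the layout shuffles (illustrative shapes only, not actual MLX or BNNS code):

```python
import numpy as np

# MLX-side tensors: activations NHWC, weights OHWI (illustrative shapes)
x_nhwc = np.random.randn(4, 32, 32, 32).astype(np.float32)   # (N, H, W, C)
w_ohwi = np.random.randn(32, 5, 5, 32).astype(np.float32)    # (O, kH, kW, I)

# BNNS-side layouts per the comment above: NCHW and OIHW (copy in)
x_nchw = np.ascontiguousarray(x_nhwc.transpose(0, 3, 1, 2))  # (N, C, H, W)
w_oihw = np.ascontiguousarray(w_ohwi.transpose(0, 3, 1, 2))  # (O, I, kH, kW)

# ... the BNNS convolution would run here, producing y_nchw ...
# and the result would be copied back out to channels-last:
# y_nhwc = np.ascontiguousarray(y_nchw.transpose(0, 2, 3, 1))
```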

For me, the answer to "why not use the GPU instead" is that the CPU could be more efficient. I've seen cases in CoreML where CNNs compiled for GPU-only perform the same as, or only marginally better than, CNNs compiled for CPU-only. And it could be that the CPU is more energy efficient (I never confirmed this).

awni commented 5 months ago

I don't think we've benchmarked the convs in BNNS, but it would be interesting to see how they perform. There would have to be a copy in and out, since everything in MLX assumes channels-last, so that might hinder perf.

Regarding efficiency, the CPU will likely be faster (with a good implementation) for smaller models, and the GPU for larger models. As for power efficiency, I don't know how they stack up or how that changes with scale; that's a very good question.
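For anyone who wants to see where that crossover lands on their own machine, a rough sketch of timing the same conv on both devices from the Python API (shapes and iteration count here are arbitrary, not a real benchmark):

```python
import time
import mlx.core as mx

def bench_conv(device, n_iters=50):
    # NHWC activations and OHWI weights (the layouts MLX's conv2d expects)
    x = mx.random.normal((4, 32, 32, 32))
    w = mx.random.normal((32, 5, 5, 32))
    mx.eval(x, w)
    mx.eval(mx.conv2d(x, w, stream=device))   # warm-up
    tic = time.perf_counter()
    ys = [mx.conv2d(x, w, stream=device) for _ in range(n_iters)]
    mx.eval(ys)                               # force the lazy graphs to run
    return (time.perf_counter() - tic) / n_iters

print("cpu:", bench_conv(mx.cpu))
print("gpu:", bench_conv(mx.gpu))
```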

mlaves commented 2 months ago

I also noticed the slowness of conv on the CPU when I implemented conv3d in #993. I looked into it, and the operation on the CPU could be made substantially faster by using explicit_gemm_conv_ND_cpu instead of the naive implementation. However, for conv2d and conv3d, using explicit_gemm_conv_ND_cpu was slower overall, because the necessary reshaping of the inputs (happening here) is very slow. PyTorch's conv CPU implementation is 10x faster on my M2 Pro, and it also uses a GEMM conv.
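For context, explicit GEMM convolution is im2col followed by a single matmul. A rough numpy sketch of the idea (not the MLX code path itself) shows where the cost sits; the patch materialization is the analogue of the strided-view copy mentioned above:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def explicit_gemm_conv2d(x, w):
    """Stride-1, no-padding conv via im2col + GEMM.
    x: (N, H, W, C) activations, w: (O, kH, kW, C) weights (NHWC / OHWI)."""
    N, H, W, C = x.shape
    O, kH, kW, _ = w.shape
    oH, oW = H - kH + 1, W - kW + 1
    # Strided view of all kH x kW patches: (N, oH, oW, C, kH, kW), no copy yet
    patches = sliding_window_view(x, (kH, kW), axis=(1, 2))
    # The reshape below is the expensive part: it materializes (im2col copies)
    # every patch into a contiguous (N*oH*oW, kH*kW*C) matrix.
    cols = patches.transpose(0, 1, 2, 4, 5, 3).reshape(N * oH * oW, kH * kW * C)
    # One GEMM does all the multiply-accumulates: (N*oH*oW, kH*kW*C) @ (kH*kW*C, O)
    out = cols @ w.reshape(O, kH * kW * C).T
    return out.reshape(N, oH, oW, O)
```

The patch matrix is roughly kH*kW times larger than the input, so writing it out can easily dominate unless the GEMM itself is large.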

Edit: I just did a quick benchmark on my M2 Pro for explicit_gemm_conv_2D_cpu with the parameters N = 4, iH = 32, iW = 32, C = 32, O = 32, kH = 5, kW = 5:

  1. 0.10 ms for setting things up (everything prior to copy(in_strided_view, in_strided, CopyType::General))
  2. 24.5 ms for copy(in_strided_view, in_strided, CopyType::General) alone
  3. 0.43 ms for cblas_sgemm
  4. 25.03 ms total runtime
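Plugging those parameters in (assuming stride 1 and no padding) shows how small the implied GEMM is, which is consistent with the copy dominating the total runtime:

```python
# Back-of-the-envelope for the benchmark parameters above
N, iH, iW, C, O, kH, kW = 4, 32, 32, 32, 32, 5, 5
oH, oW = iH - kH + 1, iW - kW + 1     # 28 x 28 outputs (stride 1, no padding)
M, K = N * oH * oW, kH * kW * C       # GEMM is (3136 x 800) @ (800 x 32)
flops = 2 * M * K * O                 # ~161 MFLOPs: cheap for the CPU
copied = M * K * 4                    # ~10 MB written by the strided copy
print(M, K, O, flops, copied)
```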

Edit: I proposed a fix in #1410.