We don't have an immediate plan for this. I'm sure there is potential to make CPU convolutions faster. I'm curious though, why not use the GPU instead?
Has there been any attempt or discussion around porting BNNS convolutions into MLX? It's listed as a TODO. I've looked into it personally, but I'm noticing some mismatches between the BNNS API and MLX. For one, the two have different layout preferences: I believe BNNS prefers NCHW and OIHW, instead of the NHWC and OHWI used in MLX. I assume the latter were chosen because implementing the Metal variant was the priority, and those formats have better performance properties on the GPU.
For me, the answer to "why not use the GPU instead" is that the CPU could be more efficient. I've seen cases in CoreML where CNNs compiled for GPU-only perform the same as, or only marginally better than, CNNs compiled for CPU-only. It could also be that the CPU is more energy efficient, though I never confirmed that.
I don't think we've benchmarked the convs in BNNS, but it would be interesting to see how they perform. There would have to be a copy in and out, since everything in MLX assumes channels-last, so that might hinder perf.
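For what it's worth, here is a minimal sketch of the layout round-trip such a BNNS path would imply, written against the MLX Python API. The `channels_first_conv` callable is a hypothetical stand-in for a BNNS-backed kernel, not an existing MLX or Accelerate API:

```python
# Sketch of the layout round-trip a channels-first (BNNS-style) conv path
# would need, given that MLX keeps activations in NHWC and weights in OHWI.
# `channels_first_conv` is a hypothetical stand-in, not an existing API.
import mlx.core as mx

def conv2d_via_channels_first(x_nhwc, w_ohwi, channels_first_conv):
    # Copy in: NHWC -> NCHW for activations, OHWI -> OIHW for weights.
    x_nchw = mx.transpose(x_nhwc, (0, 3, 1, 2))
    w_oihw = mx.transpose(w_ohwi, (0, 3, 1, 2))
    y_nchw = channels_first_conv(x_nchw, w_oihw)
    # Copy out: back to NHWC for the rest of the graph.
    return mx.transpose(y_nchw, (0, 2, 3, 1))
```

The transposes turn into real copies as soon as the channels-first kernel needs contiguous memory, so whether a BNNS path wins would come down to whether its conv is fast enough to amortize those two extra copies.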
Regarding efficiency, the CPU will likely be faster (with a good implementation) for smaller models, and the GPU for larger models. As for power efficiency, I do not know how they stack up or how that changes with scale; a very good question.
I also noticed the slowness of conv on the CPU when I implemented conv3d in #993. I looked into it, and the operation on the CPU could be substantially faster using explicit_gemm_conv_ND_cpu instead of the naive implementation. However, for conv2d and conv3d, using explicit_gemm_conv_ND_cpu was slower in total, because the necessary reshaping of the inputs (happening here) is very slow. PyTorch's conv CPU implementation is 10x faster on my M2 Pro, and they also use a GEMM-based conv.
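For context, explicit GEMM convolution is essentially im2col followed by a single matrix multiply. Here is a minimal NumPy sketch of the 2D case, as an illustration of the approach rather than the actual explicit_gemm_conv_ND_cpu code:

```python
# Minimal NumPy sketch of explicit-GEMM (im2col) 2D convolution with NHWC
# activations and OHWI weights, stride 1, no padding. Illustration only;
# this is not MLX's explicit_gemm_conv_ND_cpu.
import numpy as np

def explicit_gemm_conv2d(x, w):
    N, iH, iW, C = x.shape      # NHWC input
    O, kH, kW, _ = w.shape      # OHWI weights
    oH, oW = iH - kH + 1, iW - kW + 1

    # im2col: gather every kH x kW x C patch into a row. This gathering step
    # is the "reshaping of the inputs" that dominates the cost for small sizes.
    cols = np.empty((N, oH, oW, kH * kW * C), dtype=x.dtype)
    for i in range(oH):
        for j in range(oW):
            cols[:, i, j, :] = x[:, i:i + kH, j:j + kW, :].reshape(N, -1)

    # One big GEMM: (N*oH*oW, kH*kW*C) @ (kH*kW*C, O).
    out = cols.reshape(-1, kH * kW * C) @ w.reshape(O, -1).T
    return out.reshape(N, oH, oW, O)
```

The matrix multiply itself ends up in an optimized GEMM, so the interesting cost is the patch-gathering copy, which corresponds to the slow reshaping step mentioned above.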
Edit: I just did a quick benchmark on my M2 Pro for explicit_gemm_conv_2D_cpu with the parameters N = 4, iH = 32, iW = 32, C = 32, O = 32, kH = 5, kW = 5.
[Benchmark chart: time spent in copy(in_strided_view, in_strided, CopyType::General) versus cblas_sgemm alone.]
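If anyone wants to reproduce the comparison against PyTorch on the same shapes, here is a rough timing sketch. It assumes PyTorch is installed, and absolute numbers will of course vary by machine and build:

```python
# Rough CPU timing sketch for N=4, 32x32 input, C=32, O=32, 5x5 kernel,
# comparing mlx.core.conv2d on the CPU stream against PyTorch's CPU conv2d.
import time
import mlx.core as mx
import numpy as np
import torch

N, iH, iW, C, O, kH, kW = 4, 32, 32, 32, 32, 5, 5

x_np = np.random.randn(N, iH, iW, C).astype(np.float32)   # NHWC for MLX
w_np = np.random.randn(O, kH, kW, C).astype(np.float32)   # OHWI for MLX

def time_mlx(reps=100):
    x, w = mx.array(x_np), mx.array(w_np)
    mx.eval(mx.conv2d(x, w, stream=mx.cpu))  # warm up
    t0 = time.perf_counter()
    for _ in range(reps):
        mx.eval(mx.conv2d(x, w, stream=mx.cpu))
    return (time.perf_counter() - t0) / reps

def time_torch(reps=100):
    x = torch.from_numpy(x_np).permute(0, 3, 1, 2).contiguous()  # NCHW
    w = torch.from_numpy(w_np).permute(0, 3, 1, 2).contiguous()  # OIHW
    torch.nn.functional.conv2d(x, w)  # warm up
    t0 = time.perf_counter()
    with torch.no_grad():
        for _ in range(reps):
            torch.nn.functional.conv2d(x, w)
    return (time.perf_counter() - t0) / reps

print(f"mlx cpu conv2d:   {time_mlx() * 1e3:.3f} ms")
print(f"torch cpu conv2d: {time_torch() * 1e3:.3f} ms")
```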
Edit: I proposed a fix in #1410.
Hello, I wonder if there are any future plans to optimize Conv2D CPU execution. I guess currently MLX uses a naive implementation?