The way you implemented it is cool. But it is definitely expected that it will be much slower than a handwritten kernel (even a CPU one), mostly due to overhead for all the individual small ops. Even with mx.compile it will be a lot slower.
You might need a custom extension to make it much faster.
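As a rough illustration of that point, here is a minimal sketch of compiling a chain of small elementwise ops with mx.compile; toy_loss is just a stand-in for a loss built from many tiny ops, not the actual CTC implementation:

```python
import mlx.core as mx

# Toy stand-in for a loss built out of many small ops (not the actual CTC code):
# each iteration adds a couple of tiny elementwise operations to the graph.
def toy_loss(x):
    for _ in range(100):
        x = mx.logaddexp(x, x - 1.0)
    return x.sum()

# mx.compile can fuse the elementwise chain into fewer kernels, but the Python-side
# graph construction and the sequential structure of a real CTC recursion remain.
compiled_toy_loss = mx.compile(toy_loss)

x = mx.random.normal((64, 32))
print(toy_loss(x), compiled_toy_loss(x))
```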
@awni that is what I have been thinking about for the last few hours. The key question is: since the entire nn module is now built purely at the Python level, where should I start if I want to extend mlx.nn in C++?
Adding extensions to MLX is not the simplest process right now; I am hoping to merge a PR to at least fix that. Once that's done, you can check the docs on it: https://ml-explore.github.io/mlx/build/html/dev/extensions.html
To do a proper kernel would require some familiarity with Metal.
Metal and C++ are no problem for me. What I can't figure out is whether it should be a part of MLX or a separate project. I can prepare a PR implementing it as part of the MLX project if you agree.
I think the CTC loss is common enough (and also difficult enough to implement) that we could eventually put it in MLX in the fast namespace and have it be a loss in nn.losses. But it's a big project and might take a while to incorporate. You can start by implementing it as a separate extension and then PR it once it's good, or PR directly... up to you.
Thanks for the reply, I will try.
@awni asking here because it's related.
Now I have a CPU-only implementation that is about 6-7x slower than the torch implementation (an almost constant slowdown across time steps and batch sizes).
I see that the torch implementation uses multithreading for per-batch parallel execution; does MLX have something similar? It may significantly speed up CPU execution.
Anyway, now it is GPU time :)
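One quick way to check how much of that gap comes from threading is to pin torch to a single CPU thread and re-time it; a small sketch of the torch side only, assuming the MLX prototype is timed the same way:

```python
import time
import torch
import torch.nn.functional as F

# Pin torch to a single CPU thread to see how much of its speed comes from
# intra-op (per-batch) parallelism rather than from the kernel itself.
torch.set_num_threads(1)

T, N, C, S = 128, 32, 64, 20
log_probs = torch.randn(T, N, C).log_softmax(-1)
targets = torch.randint(1, C, (N, S))              # labels exclude blank=0
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)

start = time.perf_counter()
for _ in range(20):
    F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0)
print(f"{(time.perf_counter() - start) / 20 * 1e3:.2f} ms/iter "
      f"with {torch.get_num_threads()} thread(s)")
```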
does MLX have something similar? It may significantly speed up CPU execution.
Not really. For now single-threaded is ok. We still need to figure out our story for multi-threading within an op on the CPU. Something like OpenMP would probably work well for that case.
FWIW I use your latest PR for extensions (with nanobind); it works well for me.
Seems like I have a working version of the CPU+GPU kernel. I will publish it in a separate repo after some final polishing.
Very cool!!!
That looks really nicely done! I'd love to see some benchmarks in your repo. Also it would be very nice to have a little README in there explaining install + usage.
Do you mind if I link to it in our community lead projects? https://github.com/ml-explore/mlx/discussions/654
Sure, a public listing would be nice. And, sure, benchmarks and more extensive documentation are on my TODO list.
I am trying to implement CTC loss with MLX. I now have a working prototype that gets loss and grad values matching pytorch, but it is extremely slow for me. Specifically, it is 15x slower than torch.ctc_loss with T=128, and 45x slower with T=1024. My experimental script can be found here: https://gist.github.com/djphoenix/da473afc228e73bb8bf4d6eebaf20ae3
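For context, the core of such a prototype is the standard CTC alpha recursion; below is a minimal single-example sketch in MLX (not the code from the gist) that illustrates why the per-time-step loop of small ops adds up:

```python
import mlx.core as mx

def ctc_forward_single(log_probs, target, blank=0):
    """Negative log-likelihood of `target` under CTC for one example.

    log_probs: (T, C) log-softmax scores per time step.
    target:    (S,) label indices, with no blanks.
    This is the textbook alpha recursion; the Python loop over T is exactly
    where the per-op overhead discussed above piles up.
    """
    T = log_probs.shape[0]
    S = target.shape[0]
    neg_inf = float("-inf")

    # Extended label sequence with blanks interleaved: [b, y1, b, y2, ..., yS, b].
    ext = mx.stack([mx.full((S,), blank, dtype=target.dtype), target], axis=1).reshape(-1)
    ext = mx.concatenate([ext, mx.array([blank], dtype=target.dtype)])
    L = 2 * S + 1

    emit = mx.take(log_probs, ext, axis=1)  # (T, L) emission scores per extended label

    # A position s may also receive probability from s-2 when it is a non-blank
    # label different from the label two positions back.
    can_skip = mx.concatenate([
        mx.array([False, False]),
        mx.logical_and(ext[2:] != blank, ext[2:] != ext[:-2]),
    ])

    # alpha_0: only the leading blank and the first label are reachable.
    alpha = mx.concatenate([emit[0, :2], mx.full((L - 2,), neg_inf)])

    for t in range(1, T):
        from_prev = mx.concatenate([mx.array([neg_inf]), alpha[:-1]])
        from_skip = mx.concatenate([mx.array([neg_inf, neg_inf]), alpha[:-2]])
        merged = mx.logaddexp(alpha, from_prev)
        alpha = mx.where(can_skip, mx.logaddexp(merged, from_skip), merged) + emit[t]

    # A valid path may end on the last label or on the trailing blank.
    return -mx.logaddexp(alpha[-1], alpha[-2])

# Example usage with random scores: T=8 frames, C=4 classes (0 is blank), target "1 2 3".
log_probs = mx.log(mx.softmax(mx.random.normal((8, 4)), axis=-1))
print(ctc_forward_single(log_probs, mx.array([1, 2, 3])))
```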
Desktop: Mac14,6 (MNWA3T/A)