ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

CTC Loss using MLX #968

Closed djphoenix closed 1 month ago

djphoenix commented 1 month ago

I am trying to implement CTC loss with MLX. I now have a working prototype whose loss and gradient values match PyTorch, but it is extremely slow for me: roughly 15x slower than torch.ctc_loss with T=128, and 45x slower with T=1024.

My experimental script can be found here: https://gist.github.com/djphoenix/da473afc228e73bb8bf4d6eebaf20ae3

# python check.py
Log-probs shape (time X batch X channels): 128x64x32
Builtin time: 0.003s, value=6.04638
MLX time: 0.045s (1610.97%) value=6.04638
Loss match True
Grad match True

# python check.py
Log-probs shape (time X batch X channels): 1024x64x32
Builtin time: 0.174s, value=5.36670
MLX time: 7.689s (4413.85%) value=5.36670
Loss match True
Grad match True
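For anyone reproducing numbers like these: MLX evaluates lazily, so a timing loop has to force the result (for example with mx.eval(...) or .item()) before stopping the clock, otherwise it mostly measures graph construction. Below is a minimal, framework-agnostic harness written under that assumption. It is not the gist's code; `compare` and its result keys are hypothetical names. It takes a warm-up run, keeps the best of several timings, reports the candidate's time as a percentage of the reference's (in the same style as the percentages above), and checks that both return the same loss value.

```python
import math
import time


def compare(reference_fn, candidate_fn, repeats=10, rel_tol=1e-5):
    """Time two zero-argument callables and check they return the same float.

    Each callable must force any lazy computation before returning; with MLX
    that means evaluating the result (e.g. mx.eval(loss) or loss.item()).
    """
    def best_time(fn):
        fn()  # warm-up run (kernel compilation, caches, allocator)
        best = math.inf
        for _ in range(repeats):
            start = time.perf_counter()
            fn()
            best = min(best, time.perf_counter() - start)
        return best

    t_ref = best_time(reference_fn)
    t_cand = best_time(candidate_fn)
    match = math.isclose(reference_fn(), candidate_fn(), rel_tol=rel_tol)
    return {
        "reference_s": t_ref,
        "candidate_s": t_cand,
        # 1600% means the candidate takes 16x as long as the reference.
        "relative_pct": 100.0 * t_cand / t_ref if t_ref > 0 else math.inf,
        "match": match,
    }
```

Taking the minimum over repeats rather than the mean reduces noise from other processes; for very long runs (like the T=1024 case) a single timed run after warm-up is usually enough.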


awni commented 1 month ago

The way you implemented it is cool. But it is definitely expected to be much slower than a handwritten kernel (even a CPU one), mostly due to the overhead of all the individual small ops. Even with mx.compile it will be a lot slower.
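To make the "many small ops" point concrete, here is a plain-Python sketch of the standard CTC forward (alpha) recursion in log space, written the way an array-framework prototype typically expresses it: each time step combines three shifted copies of the state vector with logaddexp and then adds the emission log-probabilities. Only the state axis vectorizes; step t depends on step t-1, so a framework version issues a handful of tiny ops per step, T times over (and again for the backward pass). This is not the code from the gist; the function names are illustrative, it handles one unbatched sequence, and it assumes blank index 0. It can also serve as a slow but easy-to-check reference for the loss value.

```python
import math

NEG_INF = float("-inf")


def logaddexp(a, b):
    # Numerically stable log(exp(a) + exp(b)); -inf means probability 0.
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))


def extend_labels(labels, blank=0):
    # "a b" -> "_ a _ b _": blanks around every label, length 2L+1.
    ext = [blank]
    for lab in labels:
        ext += [lab, blank]
    return ext


def shift(vec, k):
    # Shift right by k states, filling vacated slots with log(0) = -inf.
    return ([NEG_INF] * k + list(vec))[:len(vec)]


def ctc_nll(log_probs, labels, blank=0):
    """Negative log-likelihood of `labels` for one sequence.

    log_probs: T rows, each a list of C per-class log-probabilities.
    """
    ext = extend_labels(labels, blank)
    S = len(ext)
    # Skipping s-2 -> s is allowed only onto a non-blank that differs from ext[s-2].
    can_skip = [s >= 2 and ext[s] != blank and ext[s] != ext[s - 2] for s in range(S)]

    # t = 0: a path starts in the leading blank or in the first label.
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]

    # Sequential over time: each step needs the previous step's alpha.
    for t in range(1, len(log_probs)):
        step1 = shift(alpha, 1)
        step2 = [a if ok else NEG_INF for a, ok in zip(shift(alpha, 2), can_skip)]
        alpha = [
            logaddexp(logaddexp(stay, one), two) + log_probs[t][lab]
            for stay, one, two, lab in zip(alpha, step1, step2, ext)
        ]

    # Valid paths end in the last label or in the trailing blank.
    total = alpha[-1] if S == 1 else logaddexp(alpha[-1], alpha[-2])
    return -total
```

A fused kernel keeps alpha in registers or threadgroup memory for the whole time loop instead of materializing every shifted copy as a separate array op, which is where most of the speedup comes from.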

You might need a custom extension to make it much faster.

djphoenix commented 1 month ago

@awni that's what I've been thinking about for the last few hours. The key question is: since the entire nn module is currently built at the Python level, where should I start to extend mlx.nn in C?

awni commented 1 month ago

Adding extensions to MLX is not the simplest process right now; I am hoping to merge a PR that at least fixes that. Once that's done, you can check the docs on it: https://ml-explore.github.io/mlx/build/html/dev/extensions.html

Writing a proper kernel would require some familiarity with Metal.

djphoenix commented 1 month ago

Metal and C++ are no problem for me. What I'm not sure about is whether it should be part of MLX or a separate project. I can prepare a PR implementing it as part of the MLX project if you agree.

awni commented 1 month ago

I think CTC loss is common enough (and also difficult enough to implement) that we could eventually put it in MLX in the fast namespace and have it be a loss in nn.losses. But it's a big project and might take a while to incorporate. You can start by implementing it as a separate extension and then PR it once it's good, or PR directly. Up to you.
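If the loss does end up in nn.losses, matching PyTorch's values also means matching its reduction. PyTorch documents ctc_loss with reduction='mean' as dividing each sequence's loss by its target length and then averaging over the batch, while other MLX losses take a reduction argument of "none", "mean", or "sum". The helper below is a hypothetical sketch (reduce_ctc_losses is not an MLX or PyTorch API) that applies that convention to a list of per-sequence negative log-likelihoods; clamping the divisor to 1 for empty targets is an assumption made here to avoid dividing by zero.

```python
def reduce_ctc_losses(losses, target_lengths, reduction="mean"):
    """Combine per-sequence CTC losses following PyTorch's documented convention.

    losses: per-sequence negative log-likelihoods.
    target_lengths: number of labels in each target sequence.
    """
    if len(losses) != len(target_lengths):
        raise ValueError("need exactly one target length per loss")
    if reduction == "none":
        return list(losses)
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        # Normalize by target length first (clamped to >= 1), then average.
        per_seq = [loss / max(n, 1) for loss, n in zip(losses, target_lengths)]
        return sum(per_seq) / len(per_seq)
    raise ValueError(f"unknown reduction: {reduction!r}")
```

For example, losses [4.0, 9.0] with target lengths [2, 3] give 2.5 under "mean" (the average of 2.0 and 3.0) but 13.0 under "sum", so a comparison against torch.ctc_loss is only meaningful when both sides use the same reduction.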

djphoenix commented 1 month ago

Thanks for the reply, I will try.

djphoenix commented 1 month ago

@awni asking here because related.

I now have a CPU-only implementation that is about 6-7x slower than the torch implementation (the slowdown is almost constant across time and batch sizes).

I see that the torch implementation uses multithreading to process batch elements in parallel; does MLX have something similar? It may significantly speed up CPU execution.

Anyway, now it's GPU time :)

awni commented 1 month ago

does MLX have something similar? It may significantly speed up CPU execution.

Not really. For now, single-threaded is OK. We need to figure out our story for multi-threading within an op on the CPU. Something like OpenMP would probably work well for that case.
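Per-batch parallelism works for CTC because the recursion for one batch element never reads another element's state, so the batch can be split across workers and the results gathered in order. The sketch below shows only that partitioning, using a hypothetical helper name (map_over_batch). Caveat: in CPython, threads will not speed up a pure-Python recursion because of the GIL; inside a C++ CPU kernel the same split would map onto native threads, for example an OpenMP `parallel for` over the batch index as suggested above.

```python
from concurrent.futures import ThreadPoolExecutor


def map_over_batch(per_sequence_fn, batch, max_workers=None):
    """Apply per_sequence_fn to every batch element on a thread pool.

    Batch elements are independent, so each one's forward/backward pass can
    run on its own worker without synchronization.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map preserves input order, so result[i] belongs to batch[i].
        return list(pool.map(per_sequence_fn, batch))
```

In a real kernel each worker would write its loss (and its slice of the gradient) into a preallocated output at its own batch index, which keeps the per-element work free of locks.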

djphoenix commented 1 month ago

FWIW, I'm using your latest extensions PR (with nanobind), and it works well for me.

djphoenix commented 1 month ago

It seems I have a working version of the CPU+GPU kernel. I will publish it in a separate repo after some final polishing.

awni commented 1 month ago

Very cool!!!

djphoenix commented 1 month ago

It's finally published on GitHub and PyPI. If you have time, feel free to review it.

awni commented 1 month ago

That looks really nicely done! I'd love to see some benchmarks in your repo. Also it would be very nice to have a little README in there explaining install + usage.

Do you mind if I link to it in our community-led projects? https://github.com/ml-explore/mlx/discussions/654

djphoenix commented 1 month ago

Sure, a public listing would be nice. And yes, benchmarks and more extensive documentation are on the TODO list.