ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Bfs width limit #1568

Closed · awni closed this 2 weeks ago

awni commented 2 weeks ago

Add a width limit to the BFS. Made it configurable with an env var MLX_BFS_MAX_WIDTH to facilitate benchmarking and keep it flexible.

The default is temporarily 10, but that needs to be validated with benchmarks:

A simple benchmark like the following, which is a worst case for BFS, uses far less RAM with the width limit: the BFS schedules all 1000 int64 allocations (about 4.2 MB each, roughly 4.2 GB total) before any of the casts to int8, so without a limit every int64 buffer is live at once:

import mlx.core as mx

# Allocate 1000 int64 arrays (4096 * 128 * 8 bytes ≈ 4.2 MB each),
# then cast each one down to int8.
arrs = [mx.zeros((4096, 128), mx.int64) for _ in range(1000)]
arrs = [x.astype(mx.int8) for x in arrs]
mx.eval(arrs)
print(mx.metal.get_peak_memory() / 1e6)  # peak memory in MB

Width Peak RAM
inf (no limit) 4,251 MB
1 689 MB
10 729 MB
50 792 MB
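
For reference, a small sweep driver one could use to reproduce these numbers. This is just a sketch: it assumes the snippet above is saved as bench.py, and that MLX_BFS_MAX_WIDTH is read once when the process starts.

import os
import subprocess
import sys

# Hypothetical sweep driver: run the benchmark above (saved as bench.py)
# once per width so each child process reports its own peak memory.
for width in ["1", "10", "50"]:
    print(f"MLX_BFS_MAX_WIDTH={width}:", flush=True)
    subprocess.run(
        [sys.executable, "bench.py"],
        env={**os.environ, "MLX_BFS_MAX_WIDTH": width},  # assumption: read at startup
        check=True,
    )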

Benchmarks M3 Max

LLM inference: Llama 3.2 1B 4-bit, 512 tokens

Width toks/sec RAM
1 357.2 0.724 GB
5 373.2 0.723 GB
10 374.0 0.724 GB
inf 375.0 0.724 GB

Transformer Training

Width it/sec RAM (GB)
1 3.32 5.204
10 3.47 5.222
50 3.48 5.48
inf 3.48 5.48

CIFAR ResNet

Width im/sec RAM
1 3152 1.468 GB
10 3087 1.515 GB
20 3168 1.671 GB
50 3887 1.895 GB
inf 4023 2.264 GB
Quantizing Llama 8B to 4-bit

Width Time (s) RAM (GB)
1 2.92 9.37
10 2.83 9.59
20 2.89 10.20
50 2.90 10.22
inf 3.14 18.17
awni commented 2 weeks ago

I think we should land this. It basically gets 99% of the advantage of BFS while protecting against the extreme edge cases. Based on experimentation, I think the default should be somewhere between 10 and 50. There will always be a trade-off between memory use and parallelism, but that range seems pretty stable and gets most of the best of both DFS and BFS.
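
To make the trade-off concrete, here is a toy scheduler sketch, not the MLX implementation (schedule, nodes, and deps are made-up names for illustration): ready nodes are issued breadth-first while the ready set is at or under the width limit, and depth-first once it grows past it.

from collections import deque

def schedule(nodes, deps, max_width):
    # Toy sketch, not MLX's code. Topologically order `nodes`, issuing
    # ready nodes FIFO (BFS-like: more parallelism, more live buffers)
    # while the ready set is small, and LIFO (DFS-like: fewer live
    # buffers) once it exceeds max_width. deps[n] lists the inputs of n.
    consumers = {n: [] for n in nodes}
    remaining = {}
    for n in nodes:
        remaining[n] = len(deps[n])
        for d in deps[n]:
            consumers[d].append(n)

    ready = deque(n for n in nodes if remaining[n] == 0)
    order = []
    while ready:
        n = ready.popleft() if len(ready) <= max_width else ready.pop()
        order.append(n)
        for c in consumers[n]:
            remaining[c] -= 1
            if remaining[c] == 0:
                ready.append(c)
    return order

On the worst-case graph above, a small max_width makes this run each cast right after its allocation whenever the ready set is over the limit, so only a handful of int64 buffers are live at once.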

awni commented 2 weeks ago

> It might be nice to re-implement this in our graph_utils to retain the property of printing/exporting a graph in the order of evaluation.

It's a good idea. I want to think on it a bit though, since naively it's a lot of code duplication (which is OK when it's simple, but I'd prefer to avoid pasting this complex bit of code in other places, and graph_utils is so nice and simple right now). For exporting to dot it won't matter much, since dot will determine the order in the image. But for printing it would be nice to have the same order.

angeloskath commented 2 weeks ago

Yeah, I agree 100%. We shouldn't copy/paste. It is a good opportunity to benchmark and explore rewriting this code in a location where performance is much less critical. As far as dot goes, I kinda like that one can trace the evaluation order based on the names of the nodes, which is true in dot images as well. I remember I used it a lot when reasoning about checkpointing.

awni commented 2 weeks ago

I'm going with 20 as the default; it seems to be a sweet spot for most workloads. The ResNet is a bit of an outlier because of the suboptimal gradient computation for the weights. I think fixing that with a single Conv primitive will make a smaller BFS width limit more favorable for it as well.

awni commented 2 weeks ago

> As far as dot goes, I kinda like that one can trace the evaluation order based on the names of the nodes, which is true in dot images as well. I remember I used it a lot when reasoning about checkpointing.

Very good point. Will aim to fix that in a follow-up.

angeloskath commented 2 weeks ago

> Will aim to fix that in a follow-up.

Yeah, no worries, I didn't mean to dump it on you :-) Just thinking out loud.

> I'm going with 20 as the default.

Exactly what you wrote. Given also the 10 ops per buffer, it is unlikely that we'll get significant parallelization speedups with more than 20 primitives anyway.