Closed: awni closed this 2 weeks ago
I think we should land this. It gets essentially 99% of the advantage of BFS while protecting against the extreme edge cases. Based on experimentation, I think the default should be somewhere between 10 and 50. There will always be a trade-off between memory use and parallelism, but that range seems pretty stable and gets most of the best of both DFS and BFS.
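The idea of interpolating between DFS and BFS with a width cap can be sketched in a few lines. This is a hypothetical illustration, not the actual MLX implementation (the function name `bfs_with_width_limit` and the DAG representation are invented for the example): expand breadth-first while the frontier is below the cap, and fall back to depth-first once it exceeds the cap, which bounds how many branches are live at once.

```python
def bfs_with_width_limit(roots, children, max_width=20):
    """Return a traversal order over a DAG given as {node: [children]}.

    With a huge max_width this degenerates to plain BFS; with
    max_width=1 it behaves like DFS. In between, the frontier (a proxy
    for live buffers) is kept from growing past the cap.
    """
    order, seen = [], set(roots)
    frontier = list(roots)
    while frontier:
        wide = len(frontier) > max_width
        node = frontier.pop(0)
        order.append(node)
        new = [c for c in children.get(node, []) if c not in seen]
        seen.update(new)
        # Below the cap, children go to the back (BFS); above it, to the
        # front (DFS), which keeps the live frontier from growing further.
        frontier = (new + frontier) if wide else (frontier + new)
    return order
```

On a diamond graph `a -> {b, c} -> d`, a generous cap yields the BFS order `a, b, c, d`, while `max_width=1` yields the DFS order `a, b, d, c`.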
It might be nice to re-implement this in our graph_utils to retain the property of printing/exporting a graph in the order of evaluation.
It's a good idea. I want to think on it a bit, though: naively it means a lot of code duplication, which is okay when the code is simple, but I'd prefer to avoid pasting this complex bit of code in other places (and graph_utils is so nice and simple right now). For exporting to dot it won't matter much, since dot will determine the order in the image. But for printing it would be nice to have the same order.
Yeah I agree 100%. We shouldn't copy/paste. It is a good opportunity to benchmark and explore rewriting this code in a location where performance is much less critical. As far as dot goes, I kinda like that one can trace the evaluation order based on the names of the nodes which is true in dot images as well. I remember I used it a lot when reasoning about checkpointing.
I'm going with 20 as the default; it seems to be a sweet spot for most workloads. The ResNet is a bit of an outlier because of the suboptimal gradient computation for the weights. I think fixing that with a single Conv primitive will make a smaller BFS width limit more favorable for it as well.
> As far as dot goes, I kinda like that one can trace the evaluation order based on the names of the nodes which is true in dot images as well. I remember I used it a lot when reasoning about checkpointing.
Very good point. Will aim to fix that in a follow up.
> Will aim to fix that in a follow up.
Yeah no worries, I didn't mean to dump it on you :-) Just thinking out-loud.
> I'm going with 20 as the default.
Exactly what you wrote. Given also the 10 ops per buffer, it's unlikely that we'll get significant parallelization speedups with more than 20 primitives anyway.
Add a width limit to the BFS. Made it configurable with an env var `MLX_BFS_MAX_WIDTH` to facilitate benchmarking and keep it flexible. The default is temporarily 10, but that needs to be validated on benchmarks:
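As a sketch (the real implementation lives in MLX's core, not in user code), reading such an env var with a fallback default might look like:

```python
import os

# Hypothetical sketch: read the width cap from the environment, falling
# back to a built-in default (10 is the provisional default mentioned above).
max_width = int(os.environ.get("MLX_BFS_MAX_WIDTH", 10))
```

Setting `MLX_BFS_MAX_WIDTH=1` recovers DFS-like behavior, while a large value approaches unbounded BFS, which makes it easy to sweep the trade-off in benchmarks.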
A simple benchmark that is worst case for BFS uses way less RAM with the limit in place:

- before the width limit: 4,251 MB
- width 1: 689 MB
- width 10: 729 MB
- width 50: 792 MB
Benchmarks (M3 Max):

- LLM inference: Llama 3.2 1B 4-bit, 512 tokens
- Transformer training
- CIFAR ResNet