ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/

GPU Memory Management? #742

Open igm503 opened 6 months ago

igm503 commented 6 months ago

In PyTorch, if I delete a reference to a torch.Tensor that has been assigned to the mps device and then call torch.mps.empty_cache(), the memory allocated for that tensor will be freed.
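For reference, a minimal sketch of the PyTorch pattern I mean (assuming an Apple-silicon machine with the MPS backend available; the exact numbers will vary):

import torch

# Allocate a large tensor on the MPS device and check the allocated memory.
x = torch.randn(10000, 10000, device="mps")
print(torch.mps.current_allocated_memory() / 1e9, "GB allocated")

# Drop the reference and release cached blocks back to the system.
del x
torch.mps.empty_cache()
print(torch.mps.current_allocated_memory() / 1e9, "GB allocated")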

In MLX, however, if I delete a reference to an mx.array that has been evaluated, there's no way to get the memory back. If I am very careful to delete references, I can ensure that the memory reserved by the Python process doesn't exceed the maximum instantaneous memory requirement of any single eval() call, but in some cases it's difficult to delete the references, or deleting them doesn't seem to make a difference to the MLX memory allocator.

Consider the script below:

import mlx.core as mx

fn = lambda x, y: x @ y

a = mx.random.normal(shape=(10000, 10000))
b = mx.random.normal(shape=(10000, 10000))

mx.eval(fn(a, b))

# Memory: 1.13 GB

c = mx.random.normal(shape=(10000, 10000))
d = mx.random.normal(shape=(10000, 10000))

mx.eval(fn(c, d))

# Memory: 1.88 GB

a = None
b = None
c = None
d = None

e = mx.random.normal(shape=(10000, 10000))
f = mx.random.normal(shape=(10000, 10000))

mx.eval(fn(e, f))

# Memory: 1.88 GB

g = mx.random.normal(shape=(10000, 10000))
h = mx.random.normal(shape=(10000, 10000))

mx.eval(fn(g, h))

# Memory: 1.88 GB

Setting a, b, c, and d to None prevents the e, f, g, and h arrays from increasing the memory usage, but it doesn't decrease the memory usage either. I assume this is because the memory that was used by a, b, c, and d is reused in turn by e, f, g, and h, while the memory allocator for the Python process holds on to the chunk of memory throughout.

This seems like it's fine for model training and inference loops, where references are repeatedly going to be recycled, but it causes problems elsewhere.

Consider Tristan Bilot's benchmarking repo, for example: to benchmark ~12 MLX ops with a few input sizes each, he has to create a separate process for each op and shape combination and then garbage collect once the process has joined, just to keep memory usage from blowing up past 25 GB, even though the single largest memory allocation required in the benchmark, from what I've tested, is about 7 GB.
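Roughly the per-process pattern he uses (a hypothetical sketch, not the actual benchmark code):

import gc
import multiprocessing as mp

import mlx.core as mx

def bench_matmul(shape):
    # Run a single op in its own process so all of MLX's memory goes back
    # to the system when the process exits.
    a = mx.random.normal(shape=shape)
    b = mx.random.normal(shape=shape)
    mx.eval(a @ b)

if __name__ == "__main__":
    for shape in [(1000, 1000), (5000, 5000), (10000, 10000)]:
        p = mp.Process(target=bench_matmul, args=(shape,))
        p.start()
        p.join()
        gc.collect()  # collect in the parent once the child has joined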

Is MLX designed in such a way that would make it possible to have something like torch.mps.empty_cache()? Also, am I missing anything important conceptually here?

awni commented 6 months ago

Your description sounds conceptually correct to me. I'm curious: how did you measure the memory use in your program (e.g. # Memory: 1.88 GB)?

MLX has a memory buffer cache because device memory allocations are expensive. So MLX will not return "freed" arrays to the system immediately. Rather, they get held in the buffer cache and possibly reused, which is why clearing the reference to an array doesn't always reduce the memory use reported by the system.

If the cache gets too big, MLX will garbage collect and give free buffers back to the device.
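To make the idea concrete, here is a toy sketch of that caching behavior in plain Python (purely illustrative, not MLX's actual allocator):

class ToyBufferCache:
    # Freed buffers are kept around, keyed by size, and handed back out on
    # later allocations instead of being returned to the system right away.
    def __init__(self):
        self._free = {}  # size -> list of cached buffers

    def malloc(self, size):
        cached = self._free.get(size)
        if cached:
            return cached.pop()  # reuse a previously "freed" buffer
        return bytearray(size)   # otherwise allocate from the system

    def free(self, buf):
        # Keep the buffer for possible reuse rather than releasing it.
        self._free.setdefault(len(buf), []).append(buf)

    def clear(self):
        # What an empty_cache-style call would do: drop the cached buffers
        # so the system can actually reclaim the memory.
        self._free.clear()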

Is MLX designed in such a way that would make it possible to have something like torch.mps.empty_cache()

An empty_cache would not be so difficult to add from an implementation standpoint, but it kind of breaks the abstraction. For the most part, users aren't supposed to worry about memory management. So I'm also curious: what do you want it for? If you explain the issue you are having, maybe there is a more principled fix we can consider.

igm503 commented 6 months ago

To get the memory figures, I put a time.sleep(10) call where the comments are and noted the usage reported at those points by the macOS Activity Monitor.
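Something like this would also capture it programmatically (a sketch assuming psutil is installed; it reports the process RSS, which won't exactly match Activity Monitor's numbers):

import mlx.core as mx
import psutil

proc = psutil.Process()

def report(label):
    # Resident set size of this Python process, in GB.
    print(f"{label}: {proc.memory_info().rss / 1e9:.2f} GB")

a = mx.random.normal(shape=(10000, 10000))
b = mx.random.normal(shape=(10000, 10000))
mx.eval(a @ b)
report("after first matmul")

a = b = None
c = mx.random.normal(shape=(10000, 10000))
d = mx.random.normal(shape=(10000, 10000))
mx.eval(c @ d)
report("after second matmul")  # roughly flat if the cached buffers were reused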

Hmmm, do you know what triggers garbage collection (or where I could look in the repo to find out more)?

Re: the use case, that's a fair question! I was initially thinking of scenarios where multiple models with large memory requirements were running, but I suppose that as long as MLX's garbage collector is doing a reasonably good job, there shouldn't be an issue there.

Thinking about it some more, I suppose the core of the issue I mentioned with the benchmarking repo is that MLX doesn't seem to play nice with PyTorch--when trying to run PyTorch ops on the MPS backend and MLX ops that have large inputs, PyTorch will raise an exception because it isn't able to access the memory holed up by MLX. The general concern, then, is that if someone were trying to do inference with two models, one on torch and the other on mlx, each of which fits within memory on its own, they might nevertheless be unable to, because MLX might hang on to more memory than it absolutely needs for one reason or another. In a case like that, having a cache-release function would be nice.
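A contrived repro of the kind of thing I mean (whether the MPS allocation actually fails depends on how much unified memory the machine has, so the sizes would need to be scaled accordingly):

import mlx.core as mx
import torch

# Have MLX allocate and evaluate a lot of memory, then drop the reference.
# The buffers likely stay in MLX's cache rather than going back to the system.
big = mx.random.normal(shape=(20000, 20000))
mx.eval(big)
big = None

# A large allocation on the PyTorch MPS backend may now fail, even though
# nothing in Python is still using the MLX memory.
try:
    t = torch.randn(20000, 20000, device="mps")
except RuntimeError as e:
    print("MPS allocation failed:", e)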

What do you think?

awni commented 6 months ago

Hmmm, do you know what triggers garbage collection (or where I could look in the repo to find out more)?

Here

Thinking about it some more, I suppose the core of the issue I mentioned with the benchmarking repo is that MLX doesn't seem to play nice with PyTorch--when trying to run PyTorch ops on the MPS backend and MLX ops that have large inputs, PyTorch will raise an exception because it isn't able to access the memory holed up by MLX.

Makes total sense. It would be nice if we could rely on the system allocator to manage this, but it is too slow, so we will likely need some level of caching. Taking that as a given, it may be unavoidable to add some methods for changing the caching/GC limits, clearing the cache, etc. Generally speaking, we also want to improve our caching strategy and the way we use memory, so that may help a bit as well.

igm503 commented 6 months ago

Thanks for the link!

So this might be somewhere on the to-do list, albeit pretty far down, if it is?

I've also noticed that evaluating arrays with new shapes or data types seems to always cause MLX to allocate new memory, even if references to previous arrays that required much more memory have been removed. Is that because of contiguity requirements or something like that, e.g. to avoid whatever time it would take to repurpose the already allocated memory for the new shape or data type? I realize now, though, that your earlier comment implies memory is allocated as a unit for each array, which I take to mean it couldn't in most cases be repurposed for a new shape or data type anyway.
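The kind of check I have in mind (a sketch using psutil for the measurement; I'm not asserting what the allocator will do here, just how I'd observe it):

import gc

import mlx.core as mx
import psutil

proc = psutil.Process()

def rss_gb():
    # Resident set size of this Python process, in GB.
    return proc.memory_info().rss / 1e9

# Hold ~400 MB of float32, then drop the reference.
a = mx.ones((10000, 10000), dtype=mx.float32)
mx.eval(a)
print(f"float32 held: {rss_gb():.2f} GB")
a = None
gc.collect()

# Same shape but float16 (~200 MB). If this doesn't reuse the cached float32
# buffer, the process memory grows instead of staying flat.
b = mx.ones((10000, 10000), dtype=mx.float16)
mx.eval(b)
print(f"float16 held: {rss_gb():.2f} GB")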

Anyways, this is interesting, and I don't know very much about memory management, so I'm gonna take a look around!

awni commented 6 months ago

So this might be somewhere on the to-do list, albeit pretty far down, if it is?

I would not say it's far down the list. This has been a priority and is starting to become a top priority.

CC a few people that have been looking into / dealing with memory related issues in MLX (@vj-krish @bpkeene @davidkoski).

I think everyone has run into the related problem, which is that MLX can use way too much memory for token generation. That is related to the cache and to how we preallocate memory ahead of jobs getting scheduled on the device.

Maybe we can use this thread to discuss plans / collaborate / make sure we don't step on each other's toes w.r.t. future improvements.

davidkoski commented 6 months ago

FWIW, the footprint command-line tool is a good way to measure instantaneous memory use (understanding memory use in a modern OS is tricky, but this tool is a good one for getting a single number). There are lots of things that take up memory; if you are only interested in these buffers, you can look for IOAccelerator.