ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[Feature] memory release action in stream (to reduce memory usage) #1481

Open kaeru-shigure opened 1 week ago

kaeru-shigure commented 1 week ago

I want something like this:

import mlx.core as mx
import mlx.nn as nn

mx.metal.set_cache_limit(0)

def dload(x, *args, **kwargs):
    la = mx.load(*args, **kwargs)

    @mx.custom_function
    def dload_op(x):
        print(f"forward_load({args[0]})")
        return x

    @dload_op.vjp
    def dload_op_vjp(primals, cotangents, outputs):
        # desired: la -> allocator::free (the weights are no longer
        # needed once the backward pass has moved past this layer)
        print(f"backward_unload({args[0]})")
        return cotangents

    @mx.custom_function
    def drelease(x):
        # desired: la -> allocator::free (the forward pass is done with the weights)
        print(f"forward_unload({args[0]})")
        return x

    @drelease.vjp
    def drelease_vjp(primals, cotangents, outputs):
        print(f"backward_load({args[0]})")
        # desired: la = mx.load(*args, **kwargs)
        # (la -> allocator::malloc, reloading the weights for the backward pass)
        return cotangents

    return dload_op(x), la, drelease

def linear_down(x):
    x, la, release = dload(x, "down.safetensors")
    x = x @ la["w"] + la["b"]
    return release(x)  # same value as x, but should run allocator::free during eval

def linear_up(x):
    x, la, release = dload(x, "up.safetensors")
    x = x @ la["w"] + la["b"]
    return release(x)  # same value as x, but should run allocator::free during eval

def proc(x):
    for _ in range(10):
        x = linear_down(x)
        x = linear_up(x)
    return x.mean()

# init
mx.save_safetensors("down.safetensors", {"w": mx.random.normal([1, 1, 48, 64]), "b": mx.random.normal([1, 1, 1, 64])})
mx.save_safetensors("up.safetensors", {"w": mx.random.normal([1, 1, 64, 48]), "b": mx.random.normal([1, 1, 1, 48])})
mx.random.seed(3)
x = mx.random.normal([1, 3, 32, 48])

# run
r = mx.value_and_grad(proc)(x)
mx.eval(*r)
print(mx.metal.get_peak_memory())
# peak memory should only include one copy of la["w"] and la["b"]
# (plus activations), not one copy per layer
Expected output:

forward_load(down.safetensors)
forward_unload(down.safetensors)
forward_load(up.safetensors)
forward_unload(up.safetensors)
..
backward_load(up.safetensors)
backward_unload(up.safetensors)
backward_load(down.safetensors)
backward_unload(down.safetensors)
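
For the forward pass, a rough approximation is already possible by breaking the lazy graph and evaluating eagerly per layer, so each weight dict can go out of scope (and, with the cache limit at 0, actually be freed) before the next one is loaded. A minimal sketch, assuming eager per-layer evaluation is acceptable:

def linear_eager(x, path):
    la = mx.load(path)         # allocator::malloc happens here
    x = x @ la["w"] + la["b"]
    mx.eval(x)                 # force the matmul now, while la is alive
    return x                   # la goes out of scope, so its buffers are freed

This does nothing for the backward pass of value_and_grad, though, which is exactly where the proposed dload/drelease hooks would matter.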
kaeru-shigure commented 1 week ago

The dynamically loaded weights don't need to be learnable, because LoRA is used on top of them, so no gradients for the base weights are required.
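
In other words, the loaded base weights can be wrapped in mx.stop_gradient so no cotangents are ever computed for them. A minimal sketch, where a and b are hypothetical trainable LoRA factors:

def lora_linear(x, la, a, b, scale=1.0):
    # the loaded base weights are frozen: stop_gradient keeps
    # gradients from flowing into la
    w = mx.stop_gradient(la["w"])
    bias = mx.stop_gradient(la["b"])
    # only the small LoRA factors a and b receive gradients
    return x @ w + bias + scale * ((x @ a) @ b)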

awni commented 1 week ago

It looks like you already implemented what you want? Is it not working / broken?

kaeru-shigure commented 1 week ago

No, look closely... the example only prints those comments; nothing is actually freed. There is currently no way to free memory in-stream.
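
To make the gap concrete: dropping the Python reference does not release a buffer mid-stream, because the lazy graph keeps the array alive until it is evaluated. A small repro of that behavior:

import mlx.core as mx

mx.metal.set_cache_limit(0)

w = mx.random.normal([4096, 4096])
mx.eval(w)

x = mx.random.normal([1, 4096])
y = x @ w
del w       # drops only the Python reference
# the buffer behind w is still retained by the lazy graph of y;
# it can be released no earlier than when eval reaches the matmul
mx.eval(y)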