patrick-kidger / quax

Multiple dispatch over abstract array types in JAX.
Apache License 2.0
100 stars 2 forks source link

LoRA that doesn't require memory for zero gradients of the underlying matrices #28

Open colehaus opened 2 weeks ago

colehaus commented 2 weeks ago

I think one of the main motives for LoRA is to reduce memory consumption—certainly that's my motive. I'm already using gradient checkpointing and Adafactor, so the main thing I want from LoRA is to reduce the size of the gradient pytree itself. However, unless I'm quite confused, in a trivial setup like:

class DummyModel(eqx.Module, Generic[Dim1, Dim2, Float]):
    """Minimal Equinox model consisting of a single ``eqx.nn.Linear`` layer."""

    # The only submodule: a linear map from Dim1 inputs to Dim2 outputs.
    tmp: eqx.nn.Linear[Dim1, Dim2, Float]

    def __init__(self, dim1: Dim1, dim2: Dim2, dtype: type[Float], key: jax.Array) -> None:
        linear = eqx.nn.Linear(dim1, dim2, dtype=dtype, key=key)
        self.tmp = linear

    # NOTE(review): the parameter name shadows the `ndarray` type used in the
    # annotations; kept as-is so keyword callers are unaffected.
    def __call__(self, ndarray: ndarray[Dim1, Float]) -> ndarray[Dim2, Float]:
        layer = self.tmp
        return layer(ndarray)

@eqx.filter_value_and_grad
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def grad_fn(m: DummyModel[Dim1, Dim2, Float], y: ndarray[Dim1, Float]) -> ndarray[Float]:
    """Squared difference between the mean model output and 0.

    Wrapped by ``eqx.filter_value_and_grad``, so calling it returns
    ``(loss, grads)`` with gradients taken w.r.t. ``m``.
    """
    target = 0
    quaxified_model = quax.quaxify(m)
    prediction = quaxified_model(y)
    return jnp.square(jnp.mean(prediction) - target)

def main():
    """Build a LoRA-adapted ``DummyModel`` and return ``grad_fn``'s (loss, grads) on a random input."""
    base_model = DummyModel[Dim1T, Dim2T, np.float32](4096, 4096, np.float32, jax.random.PRNGKey(0))
    lora_model = loraify(base_model, rank=64, scale=0.1, key=jax.random.PRNGKey(1))
    inputs = np.random.rand(4096)
    return grad_fn(lora_model, inputs)

the returned grads include a full `Dim1 × Dim2` array of zeros for `_w`. Almost all the values in the gradient pytree are zero (for typical LoRAs), and this is wasted memory.

I thought perhaps I could get around this by replacing jax.lax.stop_gradient in LoraArray with something like:

@jax.custom_jvp
def symbolic_stop_gradient(x: "A") -> "A":
    """Return ``x`` unchanged while declaring its derivative to be zero.

    The forward pass is the identity; the JVP rule below supplies a zero tangent,
    so this behaves like ``jax.lax.stop_gradient``.
    """
    return x


@symbolic_stop_gradient.defjvp
def symbolic_stop_gradient_jvp(
    primals: "tuple[ndarray[*Shape, Float]]",
    tangents: "tuple[ndarray[*Shape, Float]]",
):
    """JVP rule: pass the primal through and emit an all-zero tangent.

    JAX requires the tangent output to have the same pytree structure as the
    primal output. Returning a quax ``Zero`` (a custom pytree node) is rejected
    with "Custom JVP rule ... must produce primal and tangent outputs with equal
    container (pytree) structures", so a plain zero array is returned instead.

    NOTE: this makes the function equivalent to ``jax.lax.stop_gradient``; any
    memory saving relies on XLA eliminating the zero array. To avoid
    materialising a zero gradient for a frozen array at all, keep that array out
    of the differentiated arguments (e.g. split the model with ``eqx.partition``
    and differentiate only the trainable part).
    """
    (primal,) = primals
    del tangents  # the output tangent deliberately does not depend on the input tangent
    return primal, jnp.zeros_like(primal)

but that produces the following error:

TypeError: Custom JVP rule symbolic_stop_gradient_jvp for function symbolic_stop_gradient must produce primal and tangent outputs with equal container (pytree) structures, but got PyTreeDef(*) and PyTreeDef(CustomNode(Zero[(), ('_shape', '_dtype'), ((4096, 4096), dtype('float32'))], [])) respectively.

Is there a reasonable way to use quax to implement LoRA in a way that doesn't allocate tons of space for zeros?

(I guess it's mildly possible that JAX optimizes out this allocation behind the scenes if the gradient pytree is "consumed" inside the same JIT where the gradients are produced, but I assume it's not quite that clever.)

Thanks.

patrick-kidger commented 2 weeks ago

Actually, I think JAX is exactly that clever :)

Optimizing x+0 to just x is a simple optimization that XLA should perform for us.

That said I'd be happy to adjust Quax to avoid ever emitting the +0 in the first place, but I'm not immediately sure how.

colehaus commented 2 weeks ago

Hmm. That optimization does not seem to be happening in my test case with LoRA. Both full training and LoRA have a peak memory usage that's basically double the model size, but the optimization does seem to fire when we take gradients of a trivial constant function.

from __future__ import annotations

import functools as ft
from typing import Any, TypeVar, TypeVarTuple

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import quax
from numpy import ndarray

BatchLen = TypeVar("BatchLen", bound=int)
Dim1 = TypeVar("Dim1", bound=int)
Dim2 = TypeVar("Dim2", bound=int)
Dim3 = TypeVar("Dim3", bound=int)
Rank = TypeVar("Rank", bound=int)
Float = TypeVar("Float", bound=float)
Shape = TypeVarTuple("Shape")
A = TypeVar("A")
Opt = TypeVar("Opt")

def tree_size(tree: Any) -> int:
    """Return the total ``nbytes`` of every array leaf in ``tree`` (non-array leaves are skipped)."""
    array_leaves = filter(eqx.is_array, jax.tree_util.tree_leaves(tree))
    return sum(leaf.nbytes for leaf in array_leaves)

def human_bytes(size: float, decimal_places: int = 2) -> str:
    """Format a byte count as a short human-readable string.

    The value is divided by 1024 until it is below 1024 or the largest unit
    ("TB") is reached. It is then rendered with at most ``decimal_places``
    fractional digits, with trailing zeros and a dangling decimal point removed,
    and right-aligned to a width of four characters.

    Args:
        size: Number of bytes.
        decimal_places: Maximum number of digits after the decimal point.

    Returns:
        A string such as ``" 1.5 KB"`` or ``"272.25 MB"``.
    """
    units = ("B", "KB", "MB", "GB", "TB")
    unit = units[0]
    for unit in units:
        # Stop at the last unit instead of dividing past it; otherwise values of
        # 1024 TB and above would be reported 1024x too small.
        if size < 1024.0 or unit == units[-1]:  # noqa: PLR2004
            break
        size /= 1024.0

    formatted_num = f"{size:.{decimal_places}f}"
    # Only trim zeros that follow a decimal point: with decimal_places=0 the
    # integer string "100" must not be stripped down to "1".
    if "." in formatted_num:
        formatted_num = formatted_num.rstrip("0").rstrip(".")
    return f"{formatted_num:>4} {unit}"

@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def full_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    """Loss ``(mean(model(input_)) - 1) ** 2`` for a plain (non-quaxified) model.

    Wrapped by ``eqx.filter_value_and_grad``, so calling it returns
    ``(loss, grads)`` with gradients w.r.t. ``model``.
    """
    output = model(input_)
    error = jnp.mean(output) - 1
    return jnp.square(error)

@eqx.filter_jit(donate="all")
def full_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    """Apply one plain SGD step (learning rate 1e-3) to ``model`` using ``full_grad_fn``.

    Jitted with ``donate="all"``, so the input buffers may be reused for the output.
    """
    _loss, gradients = full_grad_fn(model, input_)
    learning_rate = 1e-3

    def sgd_update(param, grad):
        return param - learning_rate * grad

    return jax.tree.map(sgd_update, model, gradients)  # type: ignore

@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def lora_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    """Loss ``(mean(model(input_)) - 1) ** 2`` for a model that may contain quax array values.

    ``quax.quaxify`` lets the model's quax arrays (e.g. LoRA weights) dispatch
    correctly. Calling the decorated function returns ``(loss, grads)``.
    """
    quaxified = quax.quaxify(model)
    output = quaxified(input_)
    return jnp.square(jnp.mean(output) - 1)

@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def no_op_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    """Constant zero loss (in ``input_``'s dtype); used as a memory baseline.

    The loss ignores ``model``, so its gradients w.r.t. ``model`` are all zero.
    """
    del model  # unused: the loss does not depend on the model
    return jnp.array(0, dtype=input_.dtype)

@eqx.filter_jit(donate="all")
def lora_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    """Apply one plain SGD step (learning rate 1e-3) to ``model`` using ``lora_grad_fn``.

    Every differentiated leaf is updated, including any frozen LoRA base matrix
    (whose gradient is zero). Jitted with ``donate="all"``.
    """
    _loss, gradients = lora_grad_fn(model, input_)
    learning_rate = 1e-3

    def sgd_update(param, grad):
        return param - learning_rate * grad

    return jax.tree.map(sgd_update, model, gradients)  # type: ignore

@eqx.filter_jit(donate="all")
def no_op_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    """Apply an SGD step (learning rate 1e-3) driven by the constant ``no_op_grad_fn`` loss.

    Serves as a memory baseline; jitted with ``donate="all"``.
    """
    _loss, gradients = no_op_grad_fn(model, input_)
    learning_rate = 1e-3

    def sgd_update(param, grad):
        return param - learning_rate * grad

    return jax.tree.map(sgd_update, model, gradients)  # type: ignore

# Problem size for the memory benchmarks below: Linear(dim1 -> dim2).
dim1 = 2**16  # 65536
dim2 = 2**10  # 1024

def print_live_buffer_total():
    """Print the combined ``nbytes`` of all live JAX arrays, formatted with ``human_bytes``."""
    total_bytes = sum(arr.nbytes for arr in jax.live_arrays())
    print(human_bytes(total_bytes))

def full_prim_main():
    """Run 40 full-model SGD steps and report the model size and device peak memory."""

    def report_peak():
        # NOTE(review): memory_stats() may not be available on every backend — confirm.
        stats = jax.local_devices()[0].memory_stats()
        print("peak usage", human_bytes(stats["peak_bytes_in_use"]))

    # OOMs on 75_000 but not 70_000 (presumably values of dim1 — confirm)
    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))

    print("model size", human_bytes(tree_size(model)))
    report_peak()
    for _ in range(40):
        inputs = np.random.rand(dim1).astype(np.float32)
        model = full_prim_step(model, inputs)
    report_peak()

def lora_prim_main():
    """Run 40 SGD steps on a LoRA-wrapped Linear and report live, model and peak memory."""

    def report_peak():
        # NOTE(review): memory_stats() may not be available on every backend — confirm.
        stats = jax.local_devices()[0].memory_stats()
        print("peak usage", human_bytes(stats["peak_bytes_in_use"]))

    model = quax.examples.lora.loraify(
        eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0)),
        rank=64,
        scale=0.1,
        key=jax.random.PRNGKey(1),
    )

    print_live_buffer_total()
    print("model size", human_bytes(tree_size(model)))
    report_peak()
    for _ in range(40):
        inputs = np.random.rand(dim1).astype(np.float32)
        model = lora_prim_step(model, inputs)
    report_peak()

def no_op_main():
    """Run 40 constant-loss SGD steps (memory baseline) and report model size and peak memory."""

    def report_peak():
        # NOTE(review): memory_stats() may not be available on every backend — confirm.
        stats = jax.local_devices()[0].memory_stats()
        print("peak usage", human_bytes(stats["peak_bytes_in_use"]))

    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))

    print("model size", human_bytes(tree_size(model)))
    report_peak()
    for _ in range(40):
        inputs = np.random.rand(dim1).astype(np.float32)
        model = no_op_step(model, inputs)
    report_peak()

[image omitted]

(If this counts as a JAX bug and/or is out of scope, I'm happy to move it over to the JAX repo.)

patrick-kidger commented 2 weeks ago

Hmm, that's unfortunate if so.

Quax is still a fairly experimental library, so I'd be happy to take suggestions on how we might adjust the internals to work around this.

For example, this could be accomplished by calling `eqx.partition`/`eqx.combine` on either side of the grad. Maybe there's a way to make that easier to enable.

colehaus commented 2 weeks ago

Yeah, I'll think about the problem.

I already tried a version of the partition/combine approach for a different problem (not LoRA but a whole chunk of the model frozen) and the memory usage didn't work out there as hoped. I'll see if I can reproduce that problem, but, if not, maybe something in that region is the right thing to aim for.

colehaus commented 2 weeks ago

(I opened an issue on this optimization at https://github.com/google/jax/issues/23316.)

colehaus commented 2 weeks ago

Actually, I think those peak usages may be misleading. The problem may be something else. Even with an explicitly split model we get very similar behavior:

@jax.value_and_grad
@jax.jit
@ft.partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
def split_lora_grad_fn(
    malleable: PartOf[eqx.nn.Linear[Dim1, Dim2, Float]],
    frozen: PartOf[eqx.nn.Linear[Dim1, Dim2, Float]],
    input_: ndarray[Dim1, Float],
) -> ndarray[Float]:
    """Loss ``(mean(model(input_)) - 1) ** 2`` for a model split into trainable and frozen parts.

    ``jax.value_and_grad`` differentiates only w.r.t. its first argument, so
    gradients are produced for ``malleable`` alone. ``frozen`` is recombined
    with it inside the function and is never differentiated.
    """
    recombined = eqx.combine(malleable, frozen)
    output = quax.quaxify(recombined)(input_)
    return jnp.square(jnp.mean(output) - 1)

@ft.partial(jax.jit, donate_argnums=0)
def split_lora_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    """Apply one SGD step (lr 1e-3) that differentiates only the non-frozen part of ``model``.

    Every leaf whose key path ends in ``.weight._w`` (the LoRA base matrix) is
    placed in ``frozen``. All other leaves go into ``malleable`` and are the only
    ones that receive gradients. The two halves are recombined before returning.
    """
    # Use jax.tree_util directly: the `jtu` alias is not imported by this file.
    frozen_suffix = (jax.tree_util.GetAttrKey("weight"), jax.tree_util.GetAttrKey("_w"))
    # Filter spec for eqx.partition: True = trainable (malleable), False = frozen.
    is_trainable = jax.tree_util.tree_map_with_path(
        lambda path, _: path[-2:] != frozen_suffix, model
    )  # type: ignore
    malleable, frozen = eqx.partition(model, is_trainable)
    del is_trainable, model
    _, grads = split_lora_grad_fn(malleable, frozen, input_)
    # Under jax.jit this print runs at trace time only; `grads` are tracers here,
    # but their sizes are static, so the reported size is still meaningful.
    print("grad size", human_bytes(tree_size(grads)))
    lr = 1e-3
    malleable = jax.tree.map(lambda o, c: o - lr * c, malleable, grads)  # type: ignore
    return eqx.combine(malleable, frozen)

def split_lora_prim_main():
    """Run one split-LoRA SGD step and report the model size and device peak memory."""
    model = quax.examples.lora.loraify(
        eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0)),
        rank=64,
        scale=0.1,
        key=jax.random.PRNGKey(1),
    )

    # To inspect the lowered program instead:
    #   ir = split_lora_prim_step.lower(model, np.random.rand(dim1).astype(np.float32)).compiler_ir()
    #   ir.dump()

    print("model size", human_bytes(tree_size(model)))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
    inputs = np.random.rand(dim1).astype(np.float32)
    model = split_lora_prim_step(model, inputs)
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
model size 272.25 MB
peak usage 272.25 MB
grad size 16.25 MB
peak usage 628.51 MB

And they OOM in the same way.