colehaus opened this issue 2 months ago
Actually, I think those peak usages may be misleading; the problem may be something else. Even with an explicitly split model that has a demonstrably small gradient pytree, we get very similar behavior:
```python
@jax.value_and_grad
@jax.jit
@ft.partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
def split_lora_grad_fn(
    malleable: PartOf[eqx.nn.Linear[Dim1, Dim2, Float]],
    frozen: PartOf[eqx.nn.Linear[Dim1, Dim2, Float]],
    input_: ndarray[Dim1, Float],
) -> ndarray[Float]:
    model = quax.quaxify(eqx.combine(malleable, frozen))
    return jnp.square(jnp.mean(model.__call__(input_)) - 1)


@ft.partial(jax.jit, donate_argnums=0)
def split_lora_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    loraed = jtu.tree_map_with_path(lambda path, _: path[-2:] != (jtu.GetAttrKey("weight"), jtu.GetAttrKey("_w")), model)  # type: ignore
    malleable, frozen = eqx.partition(model, loraed)
    del loraed, model
    _, grads = split_lora_grad_fn(malleable, frozen, input_)
    print("grad size", human_bytes(tree_size(grads)))
    lr = 1e-3
    malleable = jax.tree.map(lambda o, c: o - lr * c, malleable, grads)  # type: ignore
    return eqx.combine(malleable, frozen)


def split_lora_prim_main():
    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))
    model = quax.examples.lora.loraify(model, rank=64, scale=0.1, key=jax.random.PRNGKey(1))
    # ir = split_lora_prim_step.lower(model, np.random.rand(dim1).astype(np.float32)).compiler_ir()
    # ir.dump()
    print("model size", human_bytes(tree_size(model)))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
    for _ in range(1):
        model = split_lora_prim_step(model, np.random.rand(dim1).astype(np.float32))
        print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
```
```
model size 272.25 MB
peak usage 272.25 MB
grad size 16.25 MB
peak usage 628.51 MB
```
And they OOM in similar ways.
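For what it's worth, one way I've been trying to attribute the extra ~356 MB is XLA's compiled memory analysis, which reports how much scratch ("temp") memory the compiled step wants beyond its arguments and outputs. A rough sketch, assuming the same `model`/`dim1` as in `split_lora_prim_main` (whether anything is reported, and under which field names, seems to depend on backend and jaxlib version):

```python
# Sketch: inspect the compiled step's planned memory use.
# memory_analysis() may return None on some backends, and its fields can differ
# across jaxlib versions, so treat this as a debugging aid rather than ground truth.
lowered = split_lora_prim_step.lower(model, np.random.rand(dim1).astype(np.float32))
stats = lowered.compile().memory_analysis()
print(stats)
if stats is not None:
    print("temp", human_bytes(stats.temp_size_in_bytes))
```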
Sorry to be noisy, but I think I'm thoroughly confused now. A 34GB model with a 34GB gradient pytree somehow fits on a 48GB GPU? (We finally OOM at 35 layers and 35GB.)
```python
from __future__ import annotations

import functools as ft
from typing import Any, Generic, TypeVar

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import quax
from numpy import ndarray

jax.config.update("jax_threefry_partitionable", True)

Dim1 = TypeVar("Dim1", bound=int)
Float = TypeVar("Float", bound=float)


def tree_size(tree: Any) -> int:
    return sum(x.nbytes for x in jtu.tree_leaves(tree) if eqx.is_array(x))


def human_bytes(size: float, decimal_places: int = 2) -> str:
    unit = "B"
    for unit in ["B", "KB", "MB", "GB", "TB"]:  # noqa: B007
        if size < 1024.0:  # noqa: PLR2004
            break
        size /= 1024.0
    formatted_num = f"{size:.{decimal_places}f}".rstrip("0").rstrip(".")
    return f"{formatted_num:>4} {unit}"


class Model(eqx.Module, Generic[Dim1, Float]):
    layers: tuple[eqx.nn.Linear[Dim1, Dim1, Float], ...]

    def __init__(self, dim: Dim1, *, num_layers: int, dtype: type[Float], key: jax.Array):
        self.layers = tuple(eqx.nn.Linear(dim, dim, dtype=dtype, key=key_) for key_ in jax.random.split(key, num_layers))

    def __call__(self, x: ndarray[Dim1, Float]) -> ndarray[Dim1, Float]:
        for layer in self.layers:
            x = layer(x)
        return x


@jax.value_and_grad
@jax.jit
def grad_fn(model: Model[Dim1, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    return jnp.square(jnp.mean(model.__call__(input_)) - 1)


@ft.partial(jax.jit, donate_argnums=0)
def step(
    model: Model[Dim1, Float],
    input_: ndarray[Dim1, Float],
) -> Model[Dim1, Float]:
    _, grads = grad_fn(model, input_)
    print("grad size", human_bytes(tree_size(grads)))
    lr = 1e-3
    return jax.tree.map(lambda o, c: o - lr * c, model, grads)  # type: ignore


dim1 = 16_384


def print_live_buffer_total():
    print(human_bytes(sum([x.nbytes for x in jax.live_arrays()])))


def main():
    model = Model(dim1, num_layers=34, dtype=np.float32, key=jax.random.PRNGKey(0))
    print("model size", human_bytes(tree_size(model)))
    print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
    for _ in range(2):
        model = step(model, np.random.rand(dim1).astype(np.float32))
        print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
```
```
model size 34 GB
peak usage 34 GB
grad size 34 GB
peak usage 35.1 GB
```
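For my own sanity, here's a toy of the part I think I do understand: with `donate_argnums=0`, XLA is allowed to reuse the donated model buffers for the updated model, so the `o - lr * c` output shouldn't need a second full-sized allocation. This is a hypothetical standalone snippet (run on a GPU, reusing `human_bytes` from above), not the repro itself:

```python
import functools as ft

import jax
import jax.numpy as jnp
import numpy as np


@ft.partial(jax.jit, donate_argnums=0)
def donated_update(x, g):
    # With donation, the output may alias x's buffer, so the update itself
    # shouldn't add another ~1 GB to the peak.
    return x - 1e-3 * g


n = 16_384
x = jnp.asarray(np.random.rand(n, n).astype(np.float32))  # ~1 GB
g = jnp.asarray(np.random.rand(n, n).astype(np.float32))  # ~1 GB
x = donated_update(x, g)
x.block_until_ready()
print("peak usage", human_bytes(jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]))
```

What I still don't follow is the gradient side: if the full 34 GB gradient pytree were ever materialized at once, donation alone wouldn't save us, so presumably XLA isn't materializing it all at once inside the fused step.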
Description
I think this may be as much a feature request as a bug report. When training a LoRA (this problem would arise elsewhere too), the underlying full weight matrices are `stop_gradient`ed, so their grads are all 0. But in a LoRA scenario, this means that almost all of the gradient pytree for a model is devoted to zeros. If that gradient pytree is "consumed" (i.e. applied to the model) within a single JIT block, it seems like it should be possible in principle for JAX/XLA to avoid the memory bloat of allocating all those zero arrays. In fact, it seems like this optimization does happen sometimes but not always (see below). Is there a way to ensure this optimization always happens? Is this reasonable to hope for?

See https://github.com/patrick-kidger/quax/issues/28 for more context.
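As a tiny standalone illustration of the zero-allocation point (hypothetical shapes, nothing to do with the real model): the gradient `jax.grad` returns for a `stop_gradient`ed leaf is a concrete, full-sized array of zeros rather than any kind of symbolic zero.

```python
import jax
import jax.numpy as jnp


def loss(params, x):
    # `frozen` stands in for the full stop_gradient'ed weight matrix,
    # `bias` for the small trainable part.
    w = jax.lax.stop_gradient(params["frozen"])
    return jnp.mean(w @ x + params["bias"])


params = {
    "frozen": jnp.ones((4096, 4096)),  # big and frozen
    "bias": jnp.zeros((4096,)),  # small and trainable
}
grads = jax.grad(loss)(params, jnp.ones((4096,)))
print(grads["frozen"].shape, grads["frozen"].nbytes)  # (4096, 4096): 64 MB of concrete zeros
print(grads["bias"].shape, grads["bias"].nbytes)  # (4096,): 16 KB
```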
System info (python version, jaxlib version, accelerator, etc.)