jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.42k stars 2.79k forks source link

Full zero arrays allocated for "consumed" gradients of `stop_gradient` parameters #23316

Open colehaus opened 2 months ago

colehaus commented 2 months ago

Description

I think this may be as much a feature request as a bug report. When training a LoRA (though the problem would arise elsewhere too), the underlying full weight matrices are wrapped in `stop_gradient`, so their gradients are all zeros. In a LoRA setting this means that almost all of a model's gradient pytree consists of zeros. If that gradient pytree is "consumed" (i.e. applied to the model) within a single JIT block, it seems like it should be possible in principle for JAX/XLA to avoid the memory bloat of allocating all those zero arrays. In fact, this optimization seems to happen sometimes but not always (see below). Is there a way to ensure it always happens? Is this reasonable to hope for?

See https://github.com/patrick-kidger/quax/issues/28 for more context.

from __future__ import annotations

import functools as ft
from typing import Any, TypeVar, TypeVarTuple

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import quax
from numpy import ndarray

BatchLen = TypeVar("BatchLen", bound=int)
Dim1 = TypeVar("Dim1", bound=int)
Dim2 = TypeVar("Dim2", bound=int)
Dim3 = TypeVar("Dim3", bound=int)
Rank = TypeVar("Rank", bound=int)
Float = TypeVar("Float", bound=float)
Shape = TypeVarTuple("Shape")
A = TypeVar("A")
Opt = TypeVar("Opt")

def tree_size(tree: Any) -> int:
    """Return the total byte count of all array leaves in ``tree``.

    Non-array leaves (as judged by ``eqx.is_array``) contribute nothing.
    """
    total = 0
    for leaf in jax.tree_util.tree_leaves(tree):
        if eqx.is_array(leaf):
            total += leaf.nbytes
    return total

def human_bytes(size: float, decimal_places: int = 2) -> str:
    """Format a byte count as a short human-readable string such as ``" 1.5 KB"``.

    Args:
        size: Number of bytes.
        decimal_places: Maximum digits after the decimal point. Trailing zeros
            after the point (and a dangling ".") are trimmed.

    Returns:
        The number right-aligned to width 4, a space, then the unit. Values of
        1024 TB or more stay expressed in TB rather than being mislabelled.
    """
    units = ("B", "KB", "MB", "GB", "TB")
    unit = units[0]
    for unit in units:
        # Stop once the value fits the current unit, or when no larger unit exists
        # (otherwise the last division would silently turn TB into PB).
        if size < 1024.0 or unit == units[-1]:  # noqa: PLR2004
            break
        size /= 1024.0

    formatted_num = f"{size:.{decimal_places}f}"
    # Only trim zeros that follow a decimal point: with decimal_places=0 there is
    # no point, and stripping would turn "100" into "1" and "0" into "".
    if "." in formatted_num:
        formatted_num = formatted_num.rstrip("0").rstrip(".")
    return f"{formatted_num:>4} {unit}"

@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def full_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    """Loss for the plain linear layer: ``(mean(model(input_)) - 1) ** 2``.

    Decorators apply bottom-up:
    - the body is checkpointed with ``eqx.filter_checkpoint``, using the
      ``dots_with_no_batch_dims_saveable`` policy;
    - it is compiled with ``eqx.filter_jit``;
    - it is wrapped by ``eqx.filter_value_and_grad``, so callers receive
      ``(loss, grads)``.
    """
    output = model(input_)
    distance_from_one = jnp.mean(output) - 1
    return jnp.square(distance_from_one)

@eqx.filter_jit(donate="all")
def full_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    _, grads = full_grad_fn(model, input_)
    lr = 1e-3
    return jax.tree.map(lambda o, c: o - lr * c, model, grads)  # type: ignore

@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def lora_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    """Same loss as ``full_grad_fn``, evaluated through ``quax.quaxify``.

    ``model`` is expected to hold quax (LoRA) array types. Quaxifying the model
    lets those custom arrays dispatch their own rules during the forward pass.
    """
    quaxified = quax.quaxify(model)
    output = quaxified(input_)
    return jnp.square(jnp.mean(output) - 1)

@eqx.filter_value_and_grad
@eqx.filter_jit
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def no_op_grad_fn(model: eqx.nn.Linear[Dim1, Dim2, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    """Baseline "loss" that is a constant scalar zero in ``input_``'s dtype.

    It never reads the model, so the gradients with respect to ``model``
    are presumably all zeros. This is useful for isolating the memory cost
    of the gradient pytree itself.
    """
    return jnp.zeros((), dtype=input_.dtype)

@eqx.filter_jit(donate="all")
def lora_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    _, grads = lora_grad_fn(model, input_)
    lr = 1e-3
    return jax.tree.map(lambda o, c: o - lr * c, model, grads)  # type: ignore

@eqx.filter_jit(donate="all")
def no_op_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    _, grads = no_op_grad_fn(model, input_)
    lr = 1e-3
    return jax.tree.map(lambda o, c: o - lr * c, model, grads)  # type: ignore

# Sizes passed as eqx.nn.Linear(dim1, dim2, ...) by the *_main drivers below.
dim1 = 2**16  # 65_536
dim2 = 2**10  # 1_024

def print_live_buffer_total():
    """Print the combined size of every array that ``jax.live_arrays()`` reports."""
    live_bytes = sum(arr.nbytes for arr in jax.live_arrays())
    print(human_bytes(live_bytes))

def full_prim_main():
    """Train a plain dim1 x dim2 linear layer for 40 steps, reporting peak device memory.

    Expects device 0 to expose ``memory_stats()`` with ``"peak_bytes_in_use"``
    (as on the CUDA setup in this report).
    """
    # OOMs on 75_000 but not 70_000  (presumably refers to dim1 — confirm)

    def report_peak():
        stats = jax.local_devices()[0].memory_stats()
        print("peak usage", human_bytes(stats["peak_bytes_in_use"]))

    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))

    print("model size", human_bytes(tree_size(model)))
    report_peak()
    for _step in range(40):
        sample = np.random.rand(dim1).astype(np.float32)
        model = full_prim_step(model, sample)
    report_peak()

def lora_prim_main():
    """Train a LoRA-wrapped linear layer (rank 64) for 40 steps, reporting memory.

    Prints the live-buffer total, the model size, and peak device memory
    before and after training. Expects device 0 to expose ``memory_stats()``.
    """

    def report_peak():
        stats = jax.local_devices()[0].memory_stats()
        print("peak usage", human_bytes(stats["peak_bytes_in_use"]))

    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))
    model = quax.examples.lora.loraify(model, rank=64, scale=0.1, key=jax.random.PRNGKey(1))

    print_live_buffer_total()
    print("model size", human_bytes(tree_size(model)))
    report_peak()
    for _step in range(40):
        sample = np.random.rand(dim1).astype(np.float32)
        model = lora_prim_step(model, sample)
    report_peak()

def no_op_main():
    """Run 40 constant-loss steps on a plain linear layer as a memory baseline.

    Expects device 0 to expose ``memory_stats()`` with ``"peak_bytes_in_use"``.
    """

    def report_peak():
        stats = jax.local_devices()[0].memory_stats()
        print("peak usage", human_bytes(stats["peak_bytes_in_use"]))

    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))

    print("model size", human_bytes(tree_size(model)))
    report_peak()
    for _step in range(40):
        sample = np.random.rand(dim1).astype(np.float32)
        model = no_op_step(model, sample)
    report_peak()

[image not preserved in this copy]

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.1
python: 3.11.9 (main, Apr  6 2024, 17:59:24) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='npjfe11cq9', release='5.19.0-45-generic', version='#46~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Jun 7 15:06:04 UTC 20', machine='x86_64')

$ nvidia-smi
Sun Aug 11 01:03:45 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   42C    P8    27W /  300W |  36785MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
colehaus commented 2 months ago

Actually, I think those peak usages may be misleading; the problem may be something else. Even with an explicitly split model whose gradient pytree is demonstrably small, we get very similar behavior:

@jax.value_and_grad
@jax.jit
@ft.partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
def split_lora_grad_fn(
    malleable: PartOf[eqx.nn.Linear[Dim1, Dim2, Float]],
    frozen: PartOf[eqx.nn.Linear[Dim1, Dim2, Float]],
    input_: ndarray[Dim1, Float],
) -> ndarray[Float]:
    """Loss ``(mean(model(input_)) - 1) ** 2`` for a model split into two halves.

    ``jax.value_and_grad`` differentiates only its first positional argument
    by default. The gradient pytree therefore covers ``malleable`` and never
    ``frozen``.

    NOTE(review): ``PartOf`` is not defined in this snippet; it is presumably a
    user-side alias for "a partition of" — confirm.
    """
    recombined = eqx.combine(malleable, frozen)
    output = quax.quaxify(recombined)(input_)
    return jnp.square(jnp.mean(output) - 1)

@ft.partial(jax.jit, donate_argnums=0)
def split_lora_prim_step(
    model: eqx.nn.Linear[Dim1, Dim2, Float],
    input_: ndarray[Dim1, Float],
) -> eqx.nn.Linear[Dim1, Dim2, Float]:
    """One SGD step (learning rate 1e-3) that only differentiates the trainable half.

    Steps:
    1. Every leaf whose path ends in ``.weight._w`` goes to ``frozen``
       (presumably the LoRA base matrix). Everything else goes to
       ``malleable``.
    2. Only ``malleable`` is updated.
    3. The halves are recombined.

    ``model`` (argument 0) is donated.

    The ``print`` runs in Python while ``jax.jit`` traces. It therefore
    presumably fires once per compilation, not on every call.
    """
    frozen_suffix = (jtu.GetAttrKey("weight"), jtu.GetAttrKey("_w"))

    def is_trainable(path, _leaf):
        return path[-2:] != frozen_suffix

    trainable_mask = jtu.tree_map_with_path(is_trainable, model)  # type: ignore
    malleable, frozen = eqx.partition(model, trainable_mask)
    del trainable_mask, model
    _loss, grads = split_lora_grad_fn(malleable, frozen, input_)
    print("grad size", human_bytes(tree_size(grads)))
    learning_rate = 1e-3
    malleable = jax.tree.map(lambda param, grad: param - learning_rate * grad, malleable, grads)  # type: ignore
    return eqx.combine(malleable, frozen)

def split_lora_prim_main():
    """Run a single split-LoRA training step and report model size and peak memory.

    Expects device 0 to expose ``memory_stats()`` with ``"peak_bytes_in_use"``.
    """

    def report_peak():
        stats = jax.local_devices()[0].memory_stats()
        print("peak usage", human_bytes(stats["peak_bytes_in_use"]))

    model = eqx.nn.Linear(dim1, dim2, dtype=np.float32, key=jax.random.PRNGKey(0))
    model = quax.examples.lora.loraify(model, rank=64, scale=0.1, key=jax.random.PRNGKey(1))

    # To inspect the lowered program instead of running it:
    #   ir = split_lora_prim_step.lower(model, np.random.rand(dim1).astype(np.float32)).compiler_ir()
    #   ir.dump()

    print("model size", human_bytes(tree_size(model)))
    report_peak()
    sample = np.random.rand(dim1).astype(np.float32)
    model = split_lora_prim_step(model, sample)
    report_peak()
model size 272.25 MB
peak usage 272.25 MB
grad size 16.25 MB
peak usage 628.51 MB

And they OOM in similar ways.

colehaus commented 2 months ago

Sorry to be noisy, but I'm thoroughly confused now. A 34 GB model with a 34 GB gradient pytree somehow fits on a 48 GB GPU, even though holding both at once would need roughly 68 GB? (We finally OOM at 35 layers and 35 GB.)

from __future__ import annotations

import functools as ft
from typing import Any, Generic, TypeVar

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import quax
from numpy import ndarray

jax.config.update("jax_threefry_partitionable", True)

Dim1 = TypeVar("Dim1", bound=int)
Float = TypeVar("Float", bound=float)

def tree_size(tree: Any) -> int:
    """Return the total byte count of all array leaves in ``tree``.

    Non-array leaves (as judged by ``eqx.is_array``) contribute nothing.
    """
    total = 0
    for leaf in jtu.tree_leaves(tree):
        if eqx.is_array(leaf):
            total += leaf.nbytes
    return total

def human_bytes(size: float, decimal_places: int = 2) -> str:
    """Format a byte count as a short human-readable string such as ``" 1.5 KB"``.

    Args:
        size: Number of bytes.
        decimal_places: Maximum digits after the decimal point. Trailing zeros
            after the point (and a dangling ".") are trimmed.

    Returns:
        The number right-aligned to width 4, a space, then the unit. Values of
        1024 TB or more stay expressed in TB rather than being mislabelled.
    """
    units = ("B", "KB", "MB", "GB", "TB")
    unit = units[0]
    for unit in units:
        # Stop once the value fits the current unit, or when no larger unit exists
        # (otherwise the last division would silently turn TB into PB).
        if size < 1024.0 or unit == units[-1]:  # noqa: PLR2004
            break
        size /= 1024.0

    formatted_num = f"{size:.{decimal_places}f}"
    # Only trim zeros that follow a decimal point: with decimal_places=0 there is
    # no point, and stripping would turn "100" into "1" and "0" into "".
    if "." in formatted_num:
        formatted_num = formatted_num.rstrip("0").rstrip(".")
    return f"{formatted_num:>4} {unit}"

class Model(eqx.Module, Generic[Dim1, Float]):
    """A stack of square ``dim x dim`` linear layers applied one after another."""

    layers: tuple[eqx.nn.Linear[Dim1, Dim1, Float], ...]

    def __init__(self, dim: Dim1, *, num_layers: int, dtype: type[Float], key: jax.Array):
        # One independent PRNG key per layer, derived from `key`.
        layer_keys = jax.random.split(key, num_layers)
        built = []
        for layer_key in layer_keys:
            built.append(eqx.nn.Linear(dim, dim, dtype=dtype, key=layer_key))
        self.layers = tuple(built)

    def __call__(self, x: ndarray[Dim1, Float]) -> ndarray[Dim1, Float]:
        # Thread `x` through every layer in order; an empty stack returns `x` unchanged.
        return ft.reduce(lambda hidden, layer: layer(hidden), self.layers, x)

@jax.value_and_grad
@jax.jit
def grad_fn(model: Model[Dim1, Float], input_: ndarray[Dim1, Float]) -> ndarray[Float]:
    """Loss ``(mean(model(input_)) - 1) ** 2`` for the layer-stack ``Model``.

    ``jax.jit`` compiles the loss. The outer ``jax.value_and_grad`` makes
    callers receive ``(loss, grads)``, with ``grads`` structured like
    ``model``.
    """
    mean_output = jnp.mean(model(input_))
    return jnp.square(mean_output - 1)

@ft.partial(jax.jit, donate_argnums=0)
def step(
    model: Model[Dim1, Float],
    input_: ndarray[Dim1, Float],
) -> Model[Dim1, Float]:
    """One full SGD step (learning rate 1e-3) over every parameter of ``model``.

    ``model`` (argument 0) is donated to the compiled computation. The
    ``print`` executes in Python during ``jax.jit`` tracing, so it presumably
    reports once per compilation rather than on every step.
    """
    learning_rate = 1e-3
    _loss, grads = grad_fn(model, input_)
    print("grad size", human_bytes(tree_size(grads)))
    return jax.tree.map(lambda param, grad: param - learning_rate * grad, model, grads)  # type: ignore

dim1 = 2**14  # 16_384: width of every square layer built by main()

def print_live_buffer_total():
    """Print the combined size of every array that ``jax.live_arrays()`` reports."""
    live_bytes = sum(arr.nbytes for arr in jax.live_arrays())
    print(human_bytes(live_bytes))

def main():
    """Build a 34-layer ``Model`` and run two SGD steps, reporting peak device memory.

    Expects device 0 to expose ``memory_stats()`` with ``"peak_bytes_in_use"``.
    """

    def report_peak():
        stats = jax.local_devices()[0].memory_stats()
        print("peak usage", human_bytes(stats["peak_bytes_in_use"]))

    model = Model(dim1, num_layers=34, dtype=np.float32, key=jax.random.PRNGKey(0))

    print("model size", human_bytes(tree_size(model)))
    report_peak()
    for _step in range(2):
        sample = np.random.rand(dim1).astype(np.float32)
        model = step(model, sample)
    report_peak()
model size   34 GB
peak usage   34 GB
grad size   34 GB
peak usage 35.1 GB