jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Memory bug with slight change in parameterization #11007

Open albertfgu opened 2 years ago

albertfgu commented 2 years ago

I have a model where the core computational kernel is a broadcast-reduce pattern that implicitly requires a large tensor, but XLA usually optimizes it away to not materialize the tensor. However, an innocuous change in the parameter initialization in Flax causes the memory to blow up.

This is the core computation:

from functools import partial
import jax
import jax.numpy as np
import optax
from flax import linen as nn
from flax.training import train_state
from jax.nn.initializers import normal
from tqdm import tqdm

rng = jax.random.PRNGKey(1)

@jax.jit
def cauchy(omega, lambd):
    """ signature: (l), (n) -> (l) """
    cauchy_dot = lambda _omega: (1. / (_omega - lambd)).sum()
    return jax.vmap(cauchy_dot)(omega)

I'll use N to denote the length of lambd, and L to denote the length of omega. Note that this function implicitly produces an intermediate tensor (1/(omega-lambd)) of shape (L, N), but this theoretically does not need to be materialized.
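For illustration, here is an equivalent broadcast-based version that spells out that (L, N) intermediate explicitly (it computes the same values; whether the intermediate actually gets materialized is up to XLA's fusion decisions):

def cauchy_broadcast(omega, lambd):
    """ signature: (l), (n) -> (l) """
    # (omega[:, None] - lambd) has shape (L, N); summing over the last axis
    # collapses it back to (L,).
    return (1. / (omega[:, None] - lambd)).sum(-1)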

We wrap it very thinly in a Flax Module:

class TestLayer(nn.Module):
    N: int
    L: int

    def setup(self):
        self.x = self.param("x", normal(dtype=np.complex64), (self.N,))
        self.y = self.param("y", normal(dtype=np.complex64), (self.N,))
        # self.x = self.param("x", normal(), (self.N, 2))
        # self.x = self.x[..., 0] + 1j * self.x[..., 1]
        # self.y = self.param("y", normal(), (self.N, 2))
        # self.y = self.y[..., 0] + 1j * self.y[..., 1]

        self.z = cauchy(np.arange(self.L), self.x*self.y).real

    def __call__(self, u):
        return u + self.z

We then vmap it to include an extra "channel" dimension called C.

# Broadcast over (C) channels
# The cauchy() call maps params x, y of shape (N, C) to z of shape (L, C)
TestLayer = nn.vmap(
    TestLayer,
    in_axes=1,
    out_axes=1,
    variable_axes={"params": 1},
    split_rngs={"params": True},
)

After this vmap, the call to cauchy takes a parameter of shape (N, C) to an output of shape (L, C) while implicitly defining a tensor of shape (L, N, C). (This intermediate tensor can be confirmed by inspecting the jaxpr.) However, this tensor does not need to be materialized, which is reflected in the memory usage reported below.
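For reference, here is a self-contained sketch (with made-up small shapes) showing that batched intermediate in the traced program:

import jax
import jax.numpy as np

L, N, C = 8, 4, 3
omega = np.arange(L, dtype=np.float32)
lambd = np.ones((N, C), dtype=np.complex64)

def cauchy_ref(omega, lambd):  # same kernel as cauchy above, minus the jit decorator
    cauchy_dot = lambda _omega: (1. / (_omega - lambd)).sum()
    return jax.vmap(cauchy_dot)(omega)

# vmap over the channel axis of lambd, mirroring what nn.vmap does to the params
channel_cauchy = jax.vmap(cauchy_ref, in_axes=(None, 1), out_axes=1)
print(jax.make_jaxpr(channel_cauchy)(omega, lambd))
# Rank-3 intermediates involving all of L, N and C (8, 4 and 3 here) show up in
# the printed jaxpr; whether they get materialized is decided later by XLA.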

The following completes the minimal working example by feeding in dummy inputs to this Module:

# Broadcast over (B) batch
TestLayer = nn.vmap(
    TestLayer,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
)

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-B", type=int, default=16) # batch size
    parser.add_argument("-C", type=int, default=256) # broadcast channelsh
    parser.add_argument("-N", type=int, default=256) # state size N
    parser.add_argument("-L", type=int, default=1024) # seq len L
    args = parser.parse_args()

    in_shape = (args.B, args.L, args.C) # Input shape

    model = TestLayer(N=args.N, L=args.L)
    params = model.init({"params": rng}, np.ones(in_shape))

    # Create optimizer
    tx = optax.adam(learning_rate=1e-3)
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

    @partial(jax.jit, static_argnums=(2,))
    def train_step(state, inputs, model):
        def loss_fn(params):
            outputs = model.apply(params, inputs)
            loss = np.mean(outputs)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
        return state, loss

    # Loop over steps
    for step in tqdm(range(1000000)):
        inputs = jax.random.normal(rng, in_shape)
        state, _ = train_step(state, inputs, model)

Correct Behavior:

I ran the script with XLA_PYTHON_CLIENT_PREALLOCATE=false python -m s4.test2 -B 32 -C 256 -L 1024 -N 1024 on an A100 GPU. As expected, GPU memory usage stays roughly constant across values of N (from 64 to 1024), at around 1.3-1.4GB for me. This indicates that the intermediate tensor of shape (L, N, C) is not being materialized.
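(In case it helps anyone reproduce the numbers: one way to watch GPU memory from inside the training loop is to poll nvidia-smi, e.g. with something like the helper below.)

import subprocess

def gpu_memory_used_mib():
    """Current GPU memory usage in MiB for each visible device, via nvidia-smi."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    )
    return [int(x) for x in out.stdout.split()]

# e.g. inside the training loop:
#     if step % 100 == 0:
#         print(step, gpu_memory_used_mib())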


Strange memory issues:

However, several small changes make the memory usage blow up with the above command.

  1. Removing the jax.jit decorator on cauchy causes memory usage to go up to 5.5GB without changing the speed. Note that the entire train_step is wrapped in a jit call, so cauchy should still be jitted anyway. I'm not sure why the speed stays the same but the memory goes up when the decorator is removed; this might not be a bug, just something I don't understand about JIT (see the sketch below).
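Here's a minimal, model-independent sketch of the trace-level difference I mean: with an inner jax.jit, the body shows up as a separate nested call in the jaxpr instead of being inlined into the surrounding program, which could plausibly change what XLA fuses. Whether that actually explains the memory difference is exactly what I don't understand.

import jax
import jax.numpy as np

def inner(x):
    return (1. / x).sum()

# Without an inner jit: the ops of `inner` are inlined into the traced program.
print(jax.make_jaxpr(lambda x: inner(x) * 2.0)(np.ones(4)))

# With an inner jit: `inner` appears as a separate nested call in the jaxpr.
print(jax.make_jaxpr(lambda x: jax.jit(inner)(x) * 2.0)(np.ones(4)))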

The thing that I think is a bug is the following:

  2. Keeping the jit decorator and replacing the self.x and self.y parameter definitions with the four commented lines (assembled below for concreteness) causes the memory usage to spike to 5.7GB (and to grow with N). These initializations should clearly be equivalent...
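For concreteness, here is that variant of setup(), assembled from the commented lines above; this is the version whose memory blows up:

# Inside TestLayer, replacing the original setup():
def setup(self):
    # Initialize real-valued (N, 2) parameters and combine them into complex
    # numbers, instead of initializing complex64 parameters directly.
    self.x = self.param("x", normal(), (self.N, 2))
    self.x = self.x[..., 0] + 1j * self.x[..., 1]
    self.y = self.param("y", normal(), (self.N, 2))
    self.y = self.y[..., 0] + 1j * self.y[..., 1]

    self.z = cauchy(np.arange(self.L), self.x * self.y).real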

So somehow this tiny change seems to prevent the XLA compiler from optimizing the computation away, and it materializes the (L, N, C) tensor again. Furthermore, starting from this change (which has the memory blowup), the following changes all bring the memory back down to 1.3GB:

  1. Keeping the commented lines but passing 1*self.y instead of self.x*self.y into the cauchy call
  2. Removing the state update state = state.apply_gradients(grads=grads) (but still calling grad_fn!)
  3. Removing the dummy inputs (they're essentially independent of this memory bottleneck), or even just removing the second nn.vmap that adds a batch dimension. Note that the input tensor (shape (B, L, C)) takes far less memory than the intermediate (L, N, C) tensor.

None of these really make sense to me, since they don't affect the main potential memory bottleneck. I'm a relative JAX novice, so maybe some of this is expected behavior, but I'm having trouble figuring it out. It seems like some precise combination of seemingly irrelevant changes prevents the XLA compiler from optimizing this code.

albertfgu commented 2 years ago

These issues go away with different versions of the kernel. For example, replacing cauchy_dot = lambda _omega: (1. / (_omega - lambd)).sum() with cauchy_dot = lambda _omega: (_omega - lambd).sum(), or even with cauchy_dot = lambda _omega: (lambd/np.exp(_omega * lambd)).sum(), makes the memory go back down (written out below).
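Written out, with lambd closed over by cauchy as before:

# Original kernel (memory blows up with the alternative parameterization):
cauchy_dot = lambda _omega: (1. / (_omega - lambd)).sum()

# Either of these replacements avoids the memory issue:
# cauchy_dot = lambda _omega: (_omega - lambd).sum()
# cauchy_dot = lambda _omega: (lambd / np.exp(_omega * lambd)).sum()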

I wondered if it was related specifically to taking powers (there is a power of -2 in the backward pass with the original version); see e.g. https://github.com/deepmind/optax/issues/196#issuecomment-975666744