jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

xmap with SerialLoop and FLAX (same for equinox) #14877

Open ynotzort opened 1 year ago

ynotzort commented 1 year ago

Description

When I xmap an axis and use a SerialLoop (to reduce memory usage) on a function that evaluates a neural network, the following error is raised:

NameError: unbound axis name: t. The following axis names (e.g. defined by pmap) are available to collective operations: []

This does not happen if a Mesh is used, or if no resource mapping is used at all. The same happens with Equinox, but not with very simple functions such as jnp.sin.
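For context, the point of SerialLoop here is to evaluate the x_batch axis as a sequential loop instead of all at once. Below is a minimal sketch of that idea in plain JAX, without xmap. It splits the batch into chunks, runs jax.lax.map over the chunks (sequential), and uses jax.vmap within each chunk. The names init_mlp, mlp_apply and serial_chunked_apply are hypothetical stand-ins for the Flax model. This is a swapped-in technique, not how xmap implements SerialLoop.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes=(4, 6, 3, 2)):
    """Hypothetical stand-in for MLP([6, 3, 2]).init: one (W, b) pair per layer."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((jax.random.normal(sub, (n_in, n_out)) / n_in**0.5, jnp.zeros(n_out)))
    return params


def mlp_apply(params, x):
    """ReLU after every layer except the last, like the Flax MLP in the repro."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def serial_chunked_apply(params, xs, n_chunks):
    """Hypothetical helper: evaluate xs in n_chunks sequential steps.

    Each step vectorizes over one chunk of the leading (batch) axis, so peak
    activation memory scales roughly with the chunk size, not the full batch.
    """
    batch = xs.shape[0]
    assert batch % n_chunks == 0, "batch size must be divisible by n_chunks"
    chunks = xs.reshape((n_chunks, batch // n_chunks) + xs.shape[1:])
    per_chunk = jax.vmap(mlp_apply, in_axes=(None, 0))
    # lax.map is implemented with a scan, i.e. a sequential loop over the chunks.
    out = jax.lax.map(lambda c: per_chunk(params, c), chunks)
    return out.reshape((batch,) + out.shape[2:])


params = init_mlp(jax.random.PRNGKey(0))
b_batch = jnp.ones((2, 4, 4))
out = serial_chunked_apply(params, b_batch, n_chunks=2)
print(out.shape)  # (2, 4, 2)
```

With n_chunks=2 and a batch of 2, each loop step handles one x_batch slice. This mirrors what serial_loop('t', 2) is asked to do in the repro below.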

How to reproduce with FLAX:

# this is just the hello world from FLAX
from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.experimental.maps import xmap, SerialLoop, serial_loop

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([6, 3, 2])
batch = jnp.ones((2, 4))
variables = model.init(jax.random.PRNGKey(0), batch)
# output = model.apply(variables, batch)

x_apply = xmap(
    model.apply,
    in_axes=(
        [...],
        ['x_batch', ...],
    ),
    out_axes=['x_batch', ...],
    axis_resources={'x_batch': 't'},
    # axis_resources={'x_batch': SerialLoop(2)},
)

b_batch = jnp.ones((2, 4, 4))
with serial_loop('t', 2):
  x_apply(variables, b_batch)

# Fails with:
# NameError: unbound axis name: t. The following axis names (e.g. defined by pmap) are available to collective operations: []

# Does not fail if serial_loop is replaced with Mesh

What jax/jaxlib version are you using?

jax v0.4.5, jaxlib 0.4.4+cuda11.cudnn82

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.10.9, Linux

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.05    Driver Version: 525.85.05    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
| N/A   55C    P0    12W /  50W |      5MiB /  4096MiB |     24%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      7914      G   /usr/lib/xorg/Xorg                  4MiB |
+-----------------------------------------------------------------------------+
rajasekharporeddy commented 5 months ago

Hi @ynotzort

It appears that this issue has been resolved in later versions of JAX. I tested it on Colab with a Tesla T4 GPU and on WSL2 with RTX A5000 and GeForce RTX 2060 GPUs, and it works fine without any failures.

from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.experimental.maps import xmap, SerialLoop, serial_loop

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([6, 3, 2])
batch = jnp.ones((2, 4))
variables = model.init(jax.random.PRNGKey(0), batch)

x_apply = xmap(
    model.apply,
    in_axes=(
        [...],
        ['x_batch', ...],
    ),
    out_axes=['x_batch', ...],
    axis_resources={'x_batch': 't'},
)

b_batch = jnp.ones((2, 4, 4))
with serial_loop('t', 2):
  output = x_apply(variables, b_batch)

output

Output:

Array([[[1.7634937, 1.2312332],
        [1.7634937, 1.2312332],
        [1.7634937, 1.2312332],
        [1.7634937, 1.2312332]],

       [[1.7634937, 1.2312332],
        [1.7634937, 1.2312332],
        [1.7634937, 1.2312332],
        [1.7634937, 1.2312332]]], dtype=float32)
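The output shape follows from the axis specification. x_batch (size 2) is mapped back to the leading output axis, each slice keeps its 4 rows, and the last Dense layer has 2 features, which gives (2, 4, 2). Every input entry is 1.0 and the parameters are shared across the mapped axis, so all rows are identical. Dense layers contract only the last axis, so mapping over x_batch should give the same result as applying the network to the whole (2, 4, 4) array. The sketch below checks this with a pure-JAX stand-in (hypothetical init_mlp / mlp_apply). Its initialization differs from Flax's, so its numbers will not match the values above.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes=(4, 6, 3, 2)):
    """Hypothetical stand-in for MLP([6, 3, 2]).init: one (W, b) pair per layer."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((jax.random.normal(sub, (n_in, n_out)) / n_in**0.5, jnp.zeros(n_out)))
    return params


def mlp_apply(params, x):
    """Dense layers act on the last axis only; leading axes are batch axes."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


params = init_mlp(jax.random.PRNGKey(0))
b_batch = jnp.ones((2, 4, 4))

# One call per x_batch slice (each slice is 4 x 4), stacked back along axis 0.
per_slice = jnp.stack([mlp_apply(params, b_batch[i]) for i in range(b_batch.shape[0])])
# One call on the whole array: x @ w broadcasts over the two leading axes.
whole = mlp_apply(params, b_batch)

print(per_slice.shape)  # (2, 4, 2)
print(bool(jnp.allclose(per_slice, whole, atol=1e-6)))  # True
# All inputs are 1.0 and the weights are shared, so every row is the same.
print(bool(jnp.allclose(whole, whole[0, 0], atol=1e-6)))  # True
```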

Attaching the Colab gist for reference. Please also find the screenshot below from WSL2 with a GeForce RTX 2060 GPU.

Note: xmap is deprecated as of JAX 0.4.26 and will be removed in a future release. It is recommended to use jax.experimental.shard_map, or jax.vmap with the spmd_axis_name argument, to express SPMD device-parallel computations.
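Here is a minimal migration sketch for the repro above, assuming a single device and no memory constraint. The xmap call maps only over the leading axis of the batch and broadcasts the parameters, which jax.vmap expresses with in_axes=(None, 0). The spmd_axis_name argument mainly matters when the mapped axis is meant to be partitioned over a named device-mesh axis, so it is left out of this single-device sketch. init_mlp and mlp_apply are hypothetical pure-JAX stand-ins for the Flax model.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes=(4, 6, 3, 2)):
    """Hypothetical stand-in for MLP([6, 3, 2]).init: one (W, b) pair per layer."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((jax.random.normal(sub, (n_in, n_out)) / n_in**0.5, jnp.zeros(n_out)))
    return params


def mlp_apply(params, x):
    """ReLU after every layer except the last, like the Flax MLP in the repro."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


# xmap(apply, in_axes=([...], ['x_batch', ...]), out_axes=['x_batch', ...])
# becomes: broadcast the parameters (None) and map over axis 0 of the batch.
# With Flax this would read jax.vmap(model.apply, in_axes=(None, 0)).
batched_apply = jax.jit(jax.vmap(mlp_apply, in_axes=(None, 0), out_axes=0))

params = init_mlp(jax.random.PRNGKey(0))
b_batch = jnp.ones((2, 4, 4))
out = batched_apply(params, b_batch)
print(out.shape)  # (2, 4, 2)
```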

Thank you.