Open ynotzort opened 1 year ago
Hi @ynotzort
It appears that this issue has been resolved in later versions of JAX. I tested it on Colab with a Tesla T4 GPU and on WSL2 with RTX A5000 and GeForce RTX 2060 GPUs, and it works fine without any failures.
from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.experimental.maps import xmap, SerialLoop, serial_loop


class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x


model = MLP([6, 3, 2])
batch = jnp.ones((2, 4))
variables = model.init(jax.random.PRNGKey(0), batch)

# Map the leading batch axis of the input to the logical axis 'x_batch',
# and assign that axis to the resource 't'.
x_apply = xmap(
    model.apply,
    in_axes=(
        [...],              # variables: no mapped axes
        ['x_batch', ...],   # input: leading axis is mapped
    ),
    out_axes=['x_batch', ...],
    axis_resources={'x_batch': 't'},
)

b_batch = jnp.ones((2, 4, 4))
# Evaluate the mapped axis as a serial loop of length 2 over resource 't'.
with serial_loop('t', 2):
    output = x_apply(variables, b_batch)
output
Output:
Array([[[1.7634937, 1.2312332],
        [1.7634937, 1.2312332],
        [1.7634937, 1.2312332],
        [1.7634937, 1.2312332]],

       [[1.7634937, 1.2312332],
        [1.7634937, 1.2312332],
        [1.7634937, 1.2312332],
        [1.7634937, 1.2312332]]], dtype=float32)
Attaching the Colab gist for reference. Please also find below a screenshot from WSL2 with the GeForce RTX 2060 GPU.
Note: xmap is deprecated in JAX 0.4.26 and will be removed in a future release. It is recommended to use jax.experimental.shard_map or jax.vmap with the spmd_axis_name argument to express SPMD device-parallel computations.
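For anyone migrating off xmap, here is a minimal sketch of the jax.vmap route. It reuses the model, variables, and b_batch defined in the snippet above (an assumption on my part, not part of the original comment); spmd_axis_name only has an effect when the mapped function is itself partitioned (e.g. via shard_map or sharding constraints), so it is optional for this plain batched apply.
# Hedged migration sketch: jax.vmap instead of xmap.
v_apply = jax.vmap(
    lambda x: model.apply(variables, x),  # parameters are closed over, not mapped
    in_axes=0,                            # map over the leading batch axis of b_batch
    out_axes=0,
    spmd_axis_name='x_batch',             # only relevant when combined with sharding
)
output = v_apply(b_batch)                 # shape (2, 4, 2), matching the xmap result above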
Thank you.
Description
When trying to xmap an axis and use a SerialLoop (in order to reduce memory usage) on a function that evaluates a neural network, the following error is raised:
This does not happen if a Mesh is used, or if no resource mapping is used at all. The same happens for Equinox, but not for very simple functions such as jnp.sin.
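For reference, a minimal sketch of the Mesh-based resource mapping that reportedly does not trigger the failure. It assumes the x_apply, variables, and b_batch definitions from the snippet earlier in this thread, plus at least one visible device; it is an illustration, not code from the original report.
import numpy as np
from jax.sharding import Mesh

# Map the 'x_batch' resource axis 't' onto a one-device mesh instead of a serial loop.
devices = np.array(jax.devices()[:1])
with Mesh(devices, ('t',)):
    output = x_apply(variables, b_batch)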
How to reproduce with FLAX:
What jax/jaxlib version are you using?
jax v0.4.5, jaxlib 0.4.4+cuda11.cudnn82
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.10.9, Linux
NVIDIA GPU info