A sufficiently large (7x7) nnx.Conv with a float32 parameter dtype, when sharded across multiple devices, produces nan, seemingly due to overflow.
The nan is avoided by making any one of the following changes:
Use a smaller convolution, e.g. 3x3
Set param_dtype='float64' inside a with jax.experimental.enable_x64() block
Disable sharding and run on a single device only
float32 is sufficient for a 7x7 convolution when not sharded, which suggests it ought to work when sharded as well.
The issue can be worked around, but every workaround is unsatisfactory in some way: it either shrinks the convolution, doubles memory usage, or restricts training to a single device.
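For reference, here is a minimal diagnostic sketch (not part of the original report) that can help narrow down whether the nan first appears in the forward pass or only in the gradients. It assumes the Model and data definitions from the repro scripts below, and that the jax_debug_nans / jax_debug_infs flags are left disabled so the checks run instead of raising:

import jax
import jax.numpy as jnp
import optax
from flax import nnx

def locate_nan(m, small, large):
    # Forward pass alone (on sharded inputs if device_put was applied upstream).
    pred = m(small)
    print("nan in forward output:", bool(jnp.isnan(pred).any()))

    # Same loss as in train_step, but inspect the gradient tree directly.
    def loss_fn(m):
        return jnp.mean(optax.squared_error(m(small), large))

    l, grads = nnx.value_and_grad(loss_fn)(m)
    print("nan in loss:", bool(jnp.isnan(l)))
    grad_nan = jax.tree_util.tree_map(lambda g: bool(jnp.isnan(g).any()), grads)
    print("nan per gradient leaf:", grad_nan)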
Minimal example. This version works because the parameters are moved to float64:
from flax import nnx
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import optax
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)
W_SMALL=700
H_SMALL=700
W_LARGE=1400
H_LARGE=1400
C=3
DEVICES = jax.devices('gpu')[1:]
# Make sure the batch size is a multiple of the number of devices.
# This lets us shard by the batch dimension.
TRAIN_BATCH=5 * len(DEVICES)
class Model(nnx.Module):
    def __init__(self):
        init_fn = nnx.initializers.lecun_normal()
        self.deep = nnx.Conv(
            in_features=3,
            out_features=3,
            kernel_size=(7, 7),
            padding='SAME',
            rngs=nnx.Rngs(8439),
            use_bias=False,
            # disable this to move to float32
            param_dtype='float64',
            kernel_init=nnx.with_partitioning(init_fn, (None,)),
        )
    def __call__(self, x: jax.Array):
        out = self.deep(x)
        return jnn.sigmoid(out)
@nnx.jit
def loss(pred: jax.Array, large: jax.Array) -> jax.Array:
    return jnp.mean(optax.squared_error(pred, large))
@nnx.jit
def train_step(
    m: Model,
    opt: nnx.Optimizer,
    small: jax.Array,
    large: jax.Array
):
    def loss_fn(m: Model):
        pred = m(small)
        return loss(pred, large)
    l, grads = nnx.value_and_grad(loss_fn)(m)
    opt.update(grads)
    return l
def data(key):
    while True:
        key, subkey = jax.random.split(key)
        large = jax.random.uniform(subkey, (W_LARGE, H_LARGE, C)) * 255
        small = jax.image.resize(large, (W_SMALL, H_SMALL, C), 'nearest')
        yield small, large
@nnx.jit
def create_sharded_model():
    model = Model()
    state = nnx.state(model)  # The model's state, a pure pytree.
    pspecs = nnx.get_partition_spec(state)  # Strip out the annotations from state.
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)  # The model is sharded now!
    return model
if __name__ == "__main__":
    with jax.experimental.enable_x64():
        with Mesh(
            devices=DEVICES,
            axis_names=('gpu',),
        ) as mesh:
            data_sharding = NamedSharding(mesh, PartitionSpec('gpu'))
            # Disable this to disable sharding
            m = create_sharded_model()
            # Use this to disable sharding
            # m = Model()
            print(f"model sharding: {m.deep.kernel.sharding}")
            print(f"model devices: {m.deep.kernel.devices()}")
            print(f"model dtype: {m.deep.kernel.dtype}")
            base_opt = optax.sgd(1e-2)
            opt = nnx.Optimizer(m, base_opt)
            for epoch in range(100):
                batch_small: list[jax.Array] = list()
                batch_large: list[jax.Array] = list()
                for small, large in data(jax.random.key(4389)):
                    # Initial upscale
                    new_shape = (W_LARGE, H_LARGE, 3)
                    upscaled = jax.image.resize(small, new_shape, "nearest")
                    batch_small.append(upscaled)
                    batch_large.append(large)
                    if len(batch_small) >= TRAIN_BATCH:
                        X = jnp.stack(batch_small)
                        Y = jnp.stack(batch_large)
                        batch_small = list()
                        batch_large = list()
                        # Disable this to disable sharding
                        X = jax.device_put(X, data_sharding)
                        Y = jax.device_put(Y, data_sharding)
                        print(f"X devices: {X.devices()}")
                        print(f"Y devices: {Y.devices()}")
                        print(f"X shape: {X.shape}")
                        print(f"Y shape: {Y.shape}")
                        print(f"X dtype: {X.dtype}")
                        print(f"Y dtype: {Y.dtype}")
                        print(train_step(m, opt, X, Y))
This version fails due to the combination of sharding, float32, and a 7x7 convolution:
from flax import nnx
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import optax
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)
W_SMALL=700
H_SMALL=700
W_LARGE=1400
H_LARGE=1400
C=3
DEVICES = jax.devices('gpu')[1:]
# Make sure the batch size is a multiple of the number of devices.
# This lets us shard by the batch dimension.
TRAIN_BATCH=5 * len(DEVICES)
class Model(nnx.Module):
    def __init__(self):
        init_fn = nnx.initializers.lecun_normal()
        self.deep = nnx.Conv(
            in_features=3,
            out_features=3,
            kernel_size=(7, 7),
            padding='SAME',
            rngs=nnx.Rngs(8439),
            use_bias=False,
            # disable this to move to float32
            # param_dtype='float64',
            kernel_init=nnx.with_partitioning(init_fn, (None,)),
        )
    def __call__(self, x: jax.Array):
        out = self.deep(x)
        return jnn.sigmoid(out)
@nnx.jit
def loss(pred: jax.Array, large: jax.Array) -> jax.Array:
    return jnp.mean(optax.squared_error(pred, large))
@nnx.jit
def train_step(
    m: Model,
    opt: nnx.Optimizer,
    small: jax.Array,
    large: jax.Array
):
    def loss_fn(m: Model):
        pred = m(small)
        return loss(pred, large)
    l, grads = nnx.value_and_grad(loss_fn)(m)
    opt.update(grads)
    return l
def data(key):
    while True:
        key, subkey = jax.random.split(key)
        large = jax.random.uniform(subkey, (W_LARGE, H_LARGE, C)) * 255
        small = jax.image.resize(large, (W_SMALL, H_SMALL, C), 'nearest')
        yield small, large
@nnx.jit
def create_sharded_model():
    model = Model()
    state = nnx.state(model)  # The model's state, a pure pytree.
    pspecs = nnx.get_partition_spec(state)  # Strip out the annotations from state.
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)  # The model is sharded now!
    return model
if __name__ == "__main__":
    # with jax.experimental.enable_x64():
    with Mesh(
        devices=DEVICES,
        axis_names=('gpu',),
    ) as mesh:
        data_sharding = NamedSharding(mesh, PartitionSpec('gpu'))
        # Disable this to disable sharding
        m = create_sharded_model()
        # Use this to disable sharding
        # m = Model()
        print(f"model sharding: {m.deep.kernel.sharding}")
        print(f"model devices: {m.deep.kernel.devices()}")
        print(f"model dtype: {m.deep.kernel.dtype}")
        base_opt = optax.sgd(1e-2)
        opt = nnx.Optimizer(m, base_opt)
        for epoch in range(100):
            batch_small: list[jax.Array] = list()
            batch_large: list[jax.Array] = list()
            for small, large in data(jax.random.key(4389)):
                # Initial upscale
                new_shape = (W_LARGE, H_LARGE, 3)
                upscaled = jax.image.resize(small, new_shape, "nearest")
                batch_small.append(upscaled)
                batch_large.append(large)
                if len(batch_small) >= TRAIN_BATCH:
                    X = jnp.stack(batch_small)
                    Y = jnp.stack(batch_large)
                    batch_small = list()
                    batch_large = list()
                    # Disable this to disable sharding
                    X = jax.device_put(X, data_sharding)
                    Y = jax.device_put(Y, data_sharding)
                    print(f"X devices: {X.devices()}")
                    print(f"Y devices: {Y.devices()}")
                    print(f"X shape: {X.shape}")
                    print(f"Y shape: {Y.shape}")
                    print(f"X dtype: {X.dtype}")
                    print(f"Y dtype: {Y.dtype}")
                    print(train_step(m, opt, X, Y))
This version works because it uses a smaller (3x3) convolution, still with float32:
from flax import nnx
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import optax
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)
W_SMALL=700
H_SMALL=700
W_LARGE=1400
H_LARGE=1400
C=3
DEVICES = jax.devices('gpu')[1:]
# Make sure the batch size is a multiple of the number of devices.
# This lets us shard by the batch dimension.
TRAIN_BATCH=5 * len(DEVICES)
class Model(nnx.Module):
    def __init__(self):
        init_fn = nnx.initializers.lecun_normal()
        self.deep = nnx.Conv(
            in_features=3,
            out_features=3,
            kernel_size=(3, 3),
            padding='SAME',
            rngs=nnx.Rngs(8439),
            use_bias=False,
            # disable this to move to float32
            # param_dtype='float64',
            kernel_init=nnx.with_partitioning(init_fn, (None,)),
        )
    def __call__(self, x: jax.Array):
        out = self.deep(x)
        return jnn.sigmoid(out)
@nnx.jit
def loss(pred: jax.Array, large: jax.Array) -> jax.Array:
    return jnp.mean(optax.squared_error(pred, large))
@nnx.jit
def train_step(
    m: Model,
    opt: nnx.Optimizer,
    small: jax.Array,
    large: jax.Array
):
    def loss_fn(m: Model):
        pred = m(small)
        return loss(pred, large)
    l, grads = nnx.value_and_grad(loss_fn)(m)
    opt.update(grads)
    return l
def data(key):
    while True:
        key, subkey = jax.random.split(key)
        large = jax.random.uniform(subkey, (W_LARGE, H_LARGE, C)) * 255
        small = jax.image.resize(large, (W_SMALL, H_SMALL, C), 'nearest')
        yield small, large
@nnx.jit
def create_sharded_model():
    model = Model()
    state = nnx.state(model)  # The model's state, a pure pytree.
    pspecs = nnx.get_partition_spec(state)  # Strip out the annotations from state.
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)  # The model is sharded now!
    return model
if __name__ == "__main__":
    # with jax.experimental.enable_x64():
    with Mesh(
        devices=DEVICES,
        axis_names=('gpu',),
    ) as mesh:
        data_sharding = NamedSharding(mesh, PartitionSpec('gpu'))
        # Disable this to disable sharding
        m = create_sharded_model()
        # Use this to disable sharding
        # m = Model()
        print(f"model sharding: {m.deep.kernel.sharding}")
        print(f"model devices: {m.deep.kernel.devices()}")
        print(f"model dtype: {m.deep.kernel.dtype}")
        base_opt = optax.sgd(1e-2)
        opt = nnx.Optimizer(m, base_opt)
        for epoch in range(100):
            batch_small: list[jax.Array] = list()
            batch_large: list[jax.Array] = list()
            for small, large in data(jax.random.key(4389)):
                # Initial upscale
                new_shape = (W_LARGE, H_LARGE, 3)
                upscaled = jax.image.resize(small, new_shape, "nearest")
                batch_small.append(upscaled)
                batch_large.append(large)
                if len(batch_small) >= TRAIN_BATCH:
                    X = jnp.stack(batch_small)
                    Y = jnp.stack(batch_large)
                    batch_small = list()
                    batch_large = list()
                    # Disable this to disable sharding
                    X = jax.device_put(X, data_sharding)
                    Y = jax.device_put(Y, data_sharding)
                    print(f"X devices: {X.devices()}")
                    print(f"Y devices: {Y.devices()}")
                    print(f"X shape: {X.shape}")
                    print(f"Y shape: {Y.shape}")
                    print(f"X dtype: {X.dtype}")
                    print(f"Y dtype: {Y.dtype}")
                    print(train_step(m, opt, X, Y))
This version works because it runs on a single device, even with float32 and a 7x7 convolution:
from flax import nnx
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import optax
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)
W_SMALL=700
H_SMALL=700
W_LARGE=1400
H_LARGE=1400
C=3
DEVICES = jax.devices('gpu')[1:]
# Make sure the batch size is a multiple of the number of devices.
# This lets us shard by the batch dimension.
TRAIN_BATCH=5 * len(DEVICES)
class Model(nnx.Module):
    def __init__(self):
        init_fn = nnx.initializers.lecun_normal()
        self.deep = nnx.Conv(
            in_features=3,
            out_features=3,
            kernel_size=(7, 7),
            padding='SAME',
            rngs=nnx.Rngs(8439),
            use_bias=False,
            # disable this to move to float32
            # param_dtype='float64',
            kernel_init=nnx.with_partitioning(init_fn, (None,)),
        )
    def __call__(self, x: jax.Array):
        out = self.deep(x)
        return jnn.sigmoid(out)
@nnx.jit
def loss(pred: jax.Array, large: jax.Array) -> jax.Array:
    return jnp.mean(optax.squared_error(pred, large))
@nnx.jit
def train_step(
    m: Model,
    opt: nnx.Optimizer,
    small: jax.Array,
    large: jax.Array
):
    def loss_fn(m: Model):
        pred = m(small)
        return loss(pred, large)
    l, grads = nnx.value_and_grad(loss_fn)(m)
    opt.update(grads)
    return l
def data(key):
    while True:
        key, subkey = jax.random.split(key)
        large = jax.random.uniform(subkey, (W_LARGE, H_LARGE, C)) * 255
        small = jax.image.resize(large, (W_SMALL, H_SMALL, C), 'nearest')
        yield small, large
@nnx.jit
def create_sharded_model():
    model = Model()
    state = nnx.state(model)  # The model's state, a pure pytree.
    pspecs = nnx.get_partition_spec(state)  # Strip out the annotations from state.
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)  # The model is sharded now!
    return model
if __name__ == "__main__":
    # with jax.experimental.enable_x64():
    with Mesh(
        devices=DEVICES,
        axis_names=('gpu',),
    ) as mesh:
        data_sharding = NamedSharding(mesh, PartitionSpec('gpu'))
        # Disable this to disable sharding
        # m = create_sharded_model()
        # Use this to disable sharding
        m = Model()
        print(f"model sharding: {m.deep.kernel.sharding}")
        print(f"model devices: {m.deep.kernel.devices()}")
        print(f"model dtype: {m.deep.kernel.dtype}")
        base_opt = optax.sgd(1e-2)
        opt = nnx.Optimizer(m, base_opt)
        for epoch in range(100):
            batch_small: list[jax.Array] = list()
            batch_large: list[jax.Array] = list()
            for small, large in data(jax.random.key(4389)):
                # Initial upscale
                new_shape = (W_LARGE, H_LARGE, 3)
                upscaled = jax.image.resize(small, new_shape, "nearest")
                batch_small.append(upscaled)
                batch_large.append(large)
                if len(batch_small) >= TRAIN_BATCH:
                    X = jnp.stack(batch_small)
                    Y = jnp.stack(batch_large)
                    batch_small = list()
                    batch_large = list()
                    # Disable this to disable sharding
                    # X = jax.device_put(X, data_sharding)
                    # Y = jax.device_put(Y, data_sharding)
                    print(f"X devices: {X.devices()}")
                    print(f"Y devices: {Y.devices()}")
                    print(f"X shape: {X.shape}")
                    print(f"Y shape: {Y.shape}")
                    print(f"X dtype: {X.dtype}")
                    print(f"Y dtype: {Y.dtype}")
                    print(train_step(m, opt, X, Y))
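One further variant that might be worth testing (not from the original report, and unverified): nnx.Conv also accepts a dtype argument that sets the computation dtype separately from param_dtype. Keeping the parameters in float32 while computing the convolution in float64 under jax.experimental.enable_x64() would avoid doubling parameter and optimizer-state memory, although the convolution itself would still run in double precision. A sketch of the changed module only, with everything else as in the scripts above:

from flax import nnx
import jax
import jax.nn as jnn
import jax.numpy as jnp

class Model(nnx.Module):
    def __init__(self):
        init_fn = nnx.initializers.lecun_normal()
        self.deep = nnx.Conv(
            in_features=3,
            out_features=3,
            kernel_size=(7, 7),
            padding='SAME',
            rngs=nnx.Rngs(8439),
            use_bias=False,
            # Parameters (and optimizer state) stay float32...
            param_dtype='float32',
            # ...but the convolution is computed in float64.
            # Requires running inside jax.experimental.enable_x64().
            dtype='float64',
            kernel_init=nnx.with_partitioning(init_fn, (None,)),
        )

    def __call__(self, x: jax.Array):
        out = self.deep(x)
        # Cast back down so the rest of the pipeline stays float32.
        return jnn.sigmoid(out).astype(jnp.float32)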
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.1.2
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
device info: NVIDIA RTX 6000 Ada Generation-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='[redacted]', release='6.8.0-48-generic', version='#48~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Mon Oct 7 11:24:13 UTC 2', machine='x86_64')
$ nvidia-smi
Mon Nov 11 17:11:54 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX 6000 Ada Gene... On | 00000000:01:00.0 Off | Off |
| 30% 60C P0 75W / 300W | 450MiB / 49140MiB | 1% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA RTX 6000 Ada Gene... On | 00000000:02:00.0 Off | Off |
| 30% 44C P0 42W / 300W | 450MiB / 49140MiB | 3% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA RTX 6000 Ada Gene... On | 00000000:C1:00.0 Off | Off |
| 30% 60C P0 42W / 300W | 450MiB / 49140MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA RTX 6000 Ada Gene... On | 00000000:E1:00.0 On | Off |
| 30% 52C P0 57W / 300W | 2570MiB / 49140MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 3206 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 1432463 C ../venv/bin/python 426MiB |
| 1 N/A N/A 3206 G /usr/lib/xorg/Xorg 4MiB |
| 1 N/A N/A 1432463 C ../venv/bin/python 426MiB |
| 2 N/A N/A 3206 G /usr/lib/xorg/Xorg 4MiB |
| 2 N/A N/A 1432463 C ../venv/bin/python 426MiB |
| 3 N/A N/A 3206 G /usr/lib/xorg/Xorg 770MiB |
| 3 N/A N/A 3614 G cinnamon 411MiB |
| 3 N/A N/A 4764 G alacritty 22MiB |
| 3 N/A N/A 19845 G /usr/lib/thunderbird/thunderbird 274MiB |
| 3 N/A N/A 1432463 C ../venv/bin/python 426MiB |
| 3 N/A N/A 2642116 G /usr/lib/firefox/firefox 441MiB |
+-----------------------------------------------------------------------------------------+