jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.57k stars 2.81k forks source link

7x7 `nnx.Conv` using `float32` parameter dtype overflows(?) to `nan` when sharded #24848

Closed joshhansen closed 1 week ago

joshhansen commented 1 week ago

Description

A sufficiently-large (7x7) nnx.Conv with float32 parameter dtype, when sharded across multiple devices, generates nan, seemingly due to overflow.

The nan is avoided by making any one of the following changes:

  1. Use smaller convolution, e.g. 3x3
  2. Set param_dtype='float64' and run inside a `with jax.experimental.enable_x64():` block
  3. Disable sharding; run on one device only

float32 is sufficient for 7x7 convolution when not sharded, suggesting that it ought to work when sharded as well.

The issue can be worked around but all of the workarounds are unsatisfactory in some way, either by reducing the size of the convolution, or requiring double the memory usage, or restricting training to a single device.

Minimal example. This version works because the parameters are moved to float64:

from flax import nnx
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import optax

# Turn on JAX's NaN and Inf debugging checks.
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)

# Image dimensions used by data(): full-size samples have shape
# (W_LARGE, H_LARGE, C); their downsampled versions have shape
# (W_SMALL, H_SMALL, C).
W_SMALL=700
H_SMALL=700
W_LARGE=1400
H_LARGE=1400
C=3

# Every GPU that JAX reports except the first one.
DEVICES = jax.devices('gpu')[1:]

# Make the batch size a multiple of the number of devices.
# This lets us shard by the batch dimension.
TRAIN_BATCH=5 * len(DEVICES)

class Model(nnx.Module):
    """Single-layer network: one 7x7 `nnx.Conv` followed by a sigmoid."""

    def __init__(self):
        # LeCun-normal kernel initializer, annotated with the partitioning
        # tuple (None,).
        kernel_init = nnx.with_partitioning(
            nnx.initializers.lecun_normal(),
            (None,),
        )
        self.deep = nnx.Conv(
            in_features=3,
            out_features=3,
            kernel_size=(7, 7),
            padding='SAME',
            use_bias=False,
            # disable this to move to float32
            param_dtype='float64',
            kernel_init=kernel_init,
            rngs=nnx.Rngs(8439),
        )

    def __call__(self, x: jax.Array):
        """Apply the convolution to `x` and pass the result through a sigmoid."""
        return jnn.sigmoid(self.deep(x))

@nnx.jit
def loss(pred: jax.Array, large: jax.Array) -> jax.Array:
    """Return the mean element-wise squared error between `pred` and `large`."""
    per_element = optax.squared_error(pred, large)
    return jnp.mean(per_element)

@nnx.jit
def train_step(
    m: Model,
    opt: nnx.Optimizer,
    small: jax.Array,
    large: jax.Array
):
    """Run one optimizer update on a batch and return the loss.

    Args:
        m: Model whose output on `small` is compared against `large`.
        opt: Optimizer that receives the gradients via `opt.update`.
        small: Batch of model inputs.
        large: Batch of targets.

    Returns:
        The loss value computed before the update is applied.
    """
    def batch_loss(model: Model):
        # Forward pass on the inputs, scored against the targets.
        return loss(model(small), large)

    loss_value, grads = nnx.value_and_grad(batch_loss)(m)
    opt.update(grads)
    return loss_value

def data(key):
    """Yield an endless stream of random `(small, large)` image pairs.

    `large` is uniform noise scaled by 255 with shape (W_LARGE, H_LARGE, C);
    `small` is `large` resized to (W_SMALL, H_SMALL, C) with the 'nearest'
    method. The generator never terminates on its own.
    """
    large_shape = (W_LARGE, H_LARGE, C)
    small_shape = (W_SMALL, H_SMALL, C)
    while True:
        # Derive a fresh subkey per sample and carry the other half forward.
        key, sample_key = jax.random.split(key)
        unit_noise = jax.random.uniform(sample_key, large_shape)
        large = unit_noise * 255
        small = jax.image.resize(large, small_shape, 'nearest')
        yield small, large

@nnx.jit
def create_sharded_model():
    """Build a `Model` and constrain its state to its annotated partition specs."""
    model = Model()
    # Extract the model's state as a pure pytree.
    unsharded_state = nnx.state(model)
    # Derive the partition specs from the state's annotations and apply them
    # as a sharding constraint.
    constrained_state = jax.lax.with_sharding_constraint(
        unsharded_state,
        nnx.get_partition_spec(unsharded_state),
    )
    # Write the constrained state back into the model.
    nnx.update(model, constrained_state)
    return model

# Driver for the float64 variant: sharded model, sharded data, 7x7 kernel.
if __name__ == "__main__":
    # Run with 64-bit types enabled (this variant uses param_dtype='float64').
    with jax.experimental.enable_x64():
        # Device mesh over DEVICES with a single axis named 'gpu'.
        with Mesh(
            devices=DEVICES,
            axis_names=('gpu',),
        ) as mesh:
            # Sharding that splits the leading (batch) dimension over 'gpu'.
            data_sharding = NamedSharding(mesh, PartitionSpec('gpu'))

            # Disable this to disable sharding
            m = create_sharded_model()

            # Use this to disable sharding
            # m = Model()

            # Report where the kernel lives and which dtype it has.
            print(f"model sharding: {m.deep.kernel.sharding}")
            print(f"model devices: {m.deep.kernel.devices()}")
            print(f"model dtype: {m.deep.kernel.dtype}")

            # SGD optimizer with learning rate 1e-2 wrapped around the model.
            base_opt = optax.sgd(1e-2)
            opt = nnx.Optimizer(m, base_opt)
            for epoch in range(100):
                batch_small: list[jax.Array] = list()
                batch_large: list[jax.Array] = list()

                # NOTE(review): data() loops forever and this loop has no
                # `break`, so it never finishes; `epoch` never advances past 0
                # and the script runs until it is interrupted or raises.
                for small, large in data(jax.random.key(4389)):
                    # Initial upscale: resize the small image back to the
                    # large spatial size (channel count hard-coded as 3) so
                    # the model input matches the target shape.
                    new_shape = (W_LARGE, H_LARGE, 3)
                    upscaled = jax.image.resize(small, new_shape, "nearest")

                    batch_small.append(upscaled)
                    batch_large.append(large)

                    # Once a full batch is collected, stack it, reset the
                    # accumulators and run one training step.
                    if len(batch_small) >= TRAIN_BATCH:
                        X = jnp.stack(batch_small)
                        Y = jnp.stack(batch_large)
                        batch_small = list()
                        batch_large = list()

                        # Disable this to disable sharding
                        X = jax.device_put(X, data_sharding)
                        Y = jax.device_put(Y, data_sharding)

                        print(f"X devices: {X.devices()}")
                        print(f"Y devices: {Y.devices()}")

                        print(f"X shape: {X.shape}")
                        print(f"Y shape: {Y.shape}")

                        print(f"X dtype: {X.dtype}")
                        print(f"Y dtype: {Y.dtype}")

                        # Prints the loss returned by the training step.
                        print(train_step(m, opt, X, Y))

This version fails due to use of sharding plus float32 and a 7x7 convolution:

from flax import nnx
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import optax

# Turn on JAX's NaN and Inf debugging checks.
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)

# Image dimensions used by data(): full-size samples have shape
# (W_LARGE, H_LARGE, C); their downsampled versions have shape
# (W_SMALL, H_SMALL, C).
W_SMALL=700
H_SMALL=700
W_LARGE=1400
H_LARGE=1400
C=3

# Every GPU that JAX reports except the first one.
DEVICES = jax.devices('gpu')[1:]

# Make the batch size a multiple of the number of devices.
# This lets us shard by the batch dimension.
TRAIN_BATCH=5 * len(DEVICES)

class Model(nnx.Module):
    """Single-layer network: one 7x7 `nnx.Conv` followed by a sigmoid."""

    def __init__(self):
        # LeCun-normal kernel initializer, annotated with the partitioning
        # tuple (None,).
        kernel_init = nnx.with_partitioning(
            nnx.initializers.lecun_normal(),
            (None,),
        )
        self.deep = nnx.Conv(
            in_features=3,
            out_features=3,
            kernel_size=(7, 7),
            padding='SAME',
            use_bias=False,
            # Left disabled in this variant so the parameters are float32;
            # uncomment to move to float64.
            # param_dtype='float64',
            kernel_init=kernel_init,
            rngs=nnx.Rngs(8439),
        )

    def __call__(self, x: jax.Array):
        """Apply the convolution to `x` and pass the result through a sigmoid."""
        return jnn.sigmoid(self.deep(x))

@nnx.jit
def loss(pred: jax.Array, large: jax.Array) -> jax.Array:
    """Return the mean element-wise squared error between `pred` and `large`."""
    per_element = optax.squared_error(pred, large)
    return jnp.mean(per_element)

@nnx.jit
def train_step(
    m: Model,
    opt: nnx.Optimizer,
    small: jax.Array,
    large: jax.Array
):
    """Run one optimizer update on a batch and return the loss.

    Args:
        m: Model whose output on `small` is compared against `large`.
        opt: Optimizer that receives the gradients via `opt.update`.
        small: Batch of model inputs.
        large: Batch of targets.

    Returns:
        The loss value computed before the update is applied.
    """
    def batch_loss(model: Model):
        # Forward pass on the inputs, scored against the targets.
        return loss(model(small), large)

    loss_value, grads = nnx.value_and_grad(batch_loss)(m)
    opt.update(grads)
    return loss_value

def data(key):
    """Yield an endless stream of random `(small, large)` image pairs.

    `large` is uniform noise scaled by 255 with shape (W_LARGE, H_LARGE, C);
    `small` is `large` resized to (W_SMALL, H_SMALL, C) with the 'nearest'
    method. The generator never terminates on its own.
    """
    large_shape = (W_LARGE, H_LARGE, C)
    small_shape = (W_SMALL, H_SMALL, C)
    while True:
        # Derive a fresh subkey per sample and carry the other half forward.
        key, sample_key = jax.random.split(key)
        unit_noise = jax.random.uniform(sample_key, large_shape)
        large = unit_noise * 255
        small = jax.image.resize(large, small_shape, 'nearest')
        yield small, large

@nnx.jit
def create_sharded_model():
    """Build a `Model` and constrain its state to its annotated partition specs."""
    model = Model()
    # Extract the model's state as a pure pytree.
    unsharded_state = nnx.state(model)
    # Derive the partition specs from the state's annotations and apply them
    # as a sharding constraint.
    constrained_state = jax.lax.with_sharding_constraint(
        unsharded_state,
        nnx.get_partition_spec(unsharded_state),
    )
    # Write the constrained state back into the model.
    nnx.update(model, constrained_state)
    return model

# Driver for the failing variant: sharded model, sharded data, float32, 7x7.
if __name__ == "__main__":
    # 64-bit mode is left disabled in this variant.
    # with jax.experimental.enable_x64():
    # Device mesh over DEVICES with a single axis named 'gpu'.
    with Mesh(
        devices=DEVICES,
        axis_names=('gpu',),
    ) as mesh:
        # Sharding that splits the leading (batch) dimension over 'gpu'.
        data_sharding = NamedSharding(mesh, PartitionSpec('gpu'))

        # Disable this to disable sharding
        m = create_sharded_model()

        # Use this to disable sharding
        # m = Model()

        # Report where the kernel lives and which dtype it has.
        print(f"model sharding: {m.deep.kernel.sharding}")
        print(f"model devices: {m.deep.kernel.devices()}")
        print(f"model dtype: {m.deep.kernel.dtype}")

        # SGD optimizer with learning rate 1e-2 wrapped around the model.
        base_opt = optax.sgd(1e-2)
        opt = nnx.Optimizer(m, base_opt)
        for epoch in range(100):
            batch_small: list[jax.Array] = list()
            batch_large: list[jax.Array] = list()

            # NOTE(review): data() loops forever and this loop has no `break`,
            # so it never finishes; `epoch` never advances past 0 and the
            # script runs until it is interrupted or raises.
            for small, large in data(jax.random.key(4389)):
                # Initial upscale: resize the small image back to the large
                # spatial size (channel count hard-coded as 3) so the model
                # input matches the target shape.
                new_shape = (W_LARGE, H_LARGE, 3)
                upscaled = jax.image.resize(small, new_shape, "nearest")

                batch_small.append(upscaled)
                batch_large.append(large)

                # Once a full batch is collected, stack it, reset the
                # accumulators and run one training step.
                if len(batch_small) >= TRAIN_BATCH:
                    X = jnp.stack(batch_small)
                    Y = jnp.stack(batch_large)
                    batch_small = list()
                    batch_large = list()

                    # Disable this to disable sharding
                    X = jax.device_put(X, data_sharding)
                    Y = jax.device_put(Y, data_sharding)

                    print(f"X devices: {X.devices()}")
                    print(f"Y devices: {Y.devices()}")

                    print(f"X shape: {X.shape}")
                    print(f"Y shape: {Y.shape}")

                    print(f"X dtype: {X.dtype}")
                    print(f"Y dtype: {Y.dtype}")

                    # Prints the loss returned by the training step.
                    print(train_step(m, opt, X, Y))

This version works due to using a smaller convolution, still with float32:

from flax import nnx
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import optax

# Turn on JAX's NaN and Inf debugging checks.
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)

# Image dimensions used by data(): full-size samples have shape
# (W_LARGE, H_LARGE, C); their downsampled versions have shape
# (W_SMALL, H_SMALL, C).
W_SMALL=700
H_SMALL=700
W_LARGE=1400
H_LARGE=1400
C=3

# Every GPU that JAX reports except the first one.
DEVICES = jax.devices('gpu')[1:]

# Make the batch size a multiple of the number of devices.
# This lets us shard by the batch dimension.
TRAIN_BATCH=5 * len(DEVICES)

class Model(nnx.Module):
    """Single-layer network: one 3x3 `nnx.Conv` followed by a sigmoid."""

    def __init__(self):
        # LeCun-normal kernel initializer, annotated with the partitioning
        # tuple (None,).
        kernel_init = nnx.with_partitioning(
            nnx.initializers.lecun_normal(),
            (None,),
        )
        self.deep = nnx.Conv(
            in_features=3,
            out_features=3,
            kernel_size=(3, 3),
            padding='SAME',
            use_bias=False,
            # Left disabled in this variant so the parameters are float32;
            # uncomment to move to float64.
            # param_dtype='float64',
            kernel_init=kernel_init,
            rngs=nnx.Rngs(8439),
        )

    def __call__(self, x: jax.Array):
        """Apply the convolution to `x` and pass the result through a sigmoid."""
        return jnn.sigmoid(self.deep(x))

@nnx.jit
def loss(pred: jax.Array, large: jax.Array) -> jax.Array:
    """Return the mean element-wise squared error between `pred` and `large`."""
    per_element = optax.squared_error(pred, large)
    return jnp.mean(per_element)

@nnx.jit
def train_step(
    m: Model,
    opt: nnx.Optimizer,
    small: jax.Array,
    large: jax.Array
):
    """Run one optimizer update on a batch and return the loss.

    Args:
        m: Model whose output on `small` is compared against `large`.
        opt: Optimizer that receives the gradients via `opt.update`.
        small: Batch of model inputs.
        large: Batch of targets.

    Returns:
        The loss value computed before the update is applied.
    """
    def batch_loss(model: Model):
        # Forward pass on the inputs, scored against the targets.
        return loss(model(small), large)

    loss_value, grads = nnx.value_and_grad(batch_loss)(m)
    opt.update(grads)
    return loss_value

def data(key):
    """Yield an endless stream of random `(small, large)` image pairs.

    `large` is uniform noise scaled by 255 with shape (W_LARGE, H_LARGE, C);
    `small` is `large` resized to (W_SMALL, H_SMALL, C) with the 'nearest'
    method. The generator never terminates on its own.
    """
    large_shape = (W_LARGE, H_LARGE, C)
    small_shape = (W_SMALL, H_SMALL, C)
    while True:
        # Derive a fresh subkey per sample and carry the other half forward.
        key, sample_key = jax.random.split(key)
        unit_noise = jax.random.uniform(sample_key, large_shape)
        large = unit_noise * 255
        small = jax.image.resize(large, small_shape, 'nearest')
        yield small, large

@nnx.jit
def create_sharded_model():
    """Build a `Model` and constrain its state to its annotated partition specs."""
    model = Model()
    # Extract the model's state as a pure pytree.
    unsharded_state = nnx.state(model)
    # Derive the partition specs from the state's annotations and apply them
    # as a sharding constraint.
    constrained_state = jax.lax.with_sharding_constraint(
        unsharded_state,
        nnx.get_partition_spec(unsharded_state),
    )
    # Write the constrained state back into the model.
    nnx.update(model, constrained_state)
    return model

# Driver for the 3x3 variant: sharded model, sharded data, float32.
if __name__ == "__main__":
    # 64-bit mode is left disabled in this variant.
    # with jax.experimental.enable_x64():
    # Device mesh over DEVICES with a single axis named 'gpu'.
    with Mesh(
        devices=DEVICES,
        axis_names=('gpu',),
    ) as mesh:
        # Sharding that splits the leading (batch) dimension over 'gpu'.
        data_sharding = NamedSharding(mesh, PartitionSpec('gpu'))

        # Disable this to disable sharding
        m = create_sharded_model()

        # Use this to disable sharding
        # m = Model()

        # Report where the kernel lives and which dtype it has.
        print(f"model sharding: {m.deep.kernel.sharding}")
        print(f"model devices: {m.deep.kernel.devices()}")
        print(f"model dtype: {m.deep.kernel.dtype}")

        # SGD optimizer with learning rate 1e-2 wrapped around the model.
        base_opt = optax.sgd(1e-2)
        opt = nnx.Optimizer(m, base_opt)
        for epoch in range(100):
            batch_small: list[jax.Array] = list()
            batch_large: list[jax.Array] = list()

            # NOTE(review): data() loops forever and this loop has no `break`,
            # so it never finishes; `epoch` never advances past 0 and the
            # script runs until it is interrupted or raises.
            for small, large in data(jax.random.key(4389)):
                # Initial upscale: resize the small image back to the large
                # spatial size (channel count hard-coded as 3) so the model
                # input matches the target shape.
                new_shape = (W_LARGE, H_LARGE, 3)
                upscaled = jax.image.resize(small, new_shape, "nearest")

                batch_small.append(upscaled)
                batch_large.append(large)

                # Once a full batch is collected, stack it, reset the
                # accumulators and run one training step.
                if len(batch_small) >= TRAIN_BATCH:
                    X = jnp.stack(batch_small)
                    Y = jnp.stack(batch_large)
                    batch_small = list()
                    batch_large = list()

                    # Disable this to disable sharding
                    X = jax.device_put(X, data_sharding)
                    Y = jax.device_put(Y, data_sharding)

                    print(f"X devices: {X.devices()}")
                    print(f"Y devices: {Y.devices()}")

                    print(f"X shape: {X.shape}")
                    print(f"Y shape: {Y.shape}")

                    print(f"X dtype: {X.dtype}")
                    print(f"Y dtype: {Y.dtype}")

                    # Prints the loss returned by the training step.
                    print(train_step(m, opt, X, Y))

This version works due to running on a single device, but with float32 and a 7x7 convolution:

from flax import nnx
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import optax

# Turn on JAX's NaN and Inf debugging checks.
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)

# Image dimensions used by data(): full-size samples have shape
# (W_LARGE, H_LARGE, C); their downsampled versions have shape
# (W_SMALL, H_SMALL, C).
W_SMALL=700
H_SMALL=700
W_LARGE=1400
H_LARGE=1400
C=3

# Every GPU that JAX reports except the first one.
DEVICES = jax.devices('gpu')[1:]

# Make the batch size a multiple of the number of devices.
# This lets us shard by the batch dimension.
TRAIN_BATCH=5 * len(DEVICES)

class Model(nnx.Module):
    """Single-layer network: one 7x7 `nnx.Conv` followed by a sigmoid."""

    def __init__(self):
        # LeCun-normal kernel initializer, annotated with the partitioning
        # tuple (None,).
        kernel_init = nnx.with_partitioning(
            nnx.initializers.lecun_normal(),
            (None,),
        )
        self.deep = nnx.Conv(
            in_features=3,
            out_features=3,
            kernel_size=(7, 7),
            padding='SAME',
            use_bias=False,
            # Left disabled in this variant so the parameters are float32;
            # uncomment to move to float64.
            # param_dtype='float64',
            kernel_init=kernel_init,
            rngs=nnx.Rngs(8439),
        )

    def __call__(self, x: jax.Array):
        """Apply the convolution to `x` and pass the result through a sigmoid."""
        return jnn.sigmoid(self.deep(x))

@nnx.jit
def loss(pred: jax.Array, large: jax.Array) -> jax.Array:
    """Return the mean element-wise squared error between `pred` and `large`."""
    per_element = optax.squared_error(pred, large)
    return jnp.mean(per_element)

@nnx.jit
def train_step(
    m: Model,
    opt: nnx.Optimizer,
    small: jax.Array,
    large: jax.Array
):
    """Run one optimizer update on a batch and return the loss.

    Args:
        m: Model whose output on `small` is compared against `large`.
        opt: Optimizer that receives the gradients via `opt.update`.
        small: Batch of model inputs.
        large: Batch of targets.

    Returns:
        The loss value computed before the update is applied.
    """
    def batch_loss(model: Model):
        # Forward pass on the inputs, scored against the targets.
        return loss(model(small), large)

    loss_value, grads = nnx.value_and_grad(batch_loss)(m)
    opt.update(grads)
    return loss_value

def data(key):
    """Yield an endless stream of random `(small, large)` image pairs.

    `large` is uniform noise scaled by 255 with shape (W_LARGE, H_LARGE, C);
    `small` is `large` resized to (W_SMALL, H_SMALL, C) with the 'nearest'
    method. The generator never terminates on its own.
    """
    large_shape = (W_LARGE, H_LARGE, C)
    small_shape = (W_SMALL, H_SMALL, C)
    while True:
        # Derive a fresh subkey per sample and carry the other half forward.
        key, sample_key = jax.random.split(key)
        unit_noise = jax.random.uniform(sample_key, large_shape)
        large = unit_noise * 255
        small = jax.image.resize(large, small_shape, 'nearest')
        yield small, large

@nnx.jit
def create_sharded_model():
    """Build a `Model` and constrain its state to its annotated partition specs."""
    model = Model()
    # Extract the model's state as a pure pytree.
    unsharded_state = nnx.state(model)
    # Derive the partition specs from the state's annotations and apply them
    # as a sharding constraint.
    constrained_state = jax.lax.with_sharding_constraint(
        unsharded_state,
        nnx.get_partition_spec(unsharded_state),
    )
    # Write the constrained state back into the model.
    nnx.update(model, constrained_state)
    return model

# Driver for the unsharded variant: plain Model(), no device_put, float32, 7x7.
if __name__ == "__main__":
    # 64-bit mode is left disabled in this variant.
    # with jax.experimental.enable_x64():
    # The mesh is still entered here, but neither the model nor the data is
    # placed with a sharding below.
    with Mesh(
        devices=DEVICES,
        axis_names=('gpu',),
    ) as mesh:
        # Unused in this variant because the device_put calls are disabled.
        data_sharding = NamedSharding(mesh, PartitionSpec('gpu'))

        # Disable this to disable sharding
        # m = create_sharded_model()

        # Use this to disable sharding
        m = Model()

        # Report where the kernel lives and which dtype it has.
        print(f"model sharding: {m.deep.kernel.sharding}")
        print(f"model devices: {m.deep.kernel.devices()}")
        print(f"model dtype: {m.deep.kernel.dtype}")

        # SGD optimizer with learning rate 1e-2 wrapped around the model.
        base_opt = optax.sgd(1e-2)
        opt = nnx.Optimizer(m, base_opt)
        for epoch in range(100):
            batch_small: list[jax.Array] = list()
            batch_large: list[jax.Array] = list()

            # NOTE(review): data() loops forever and this loop has no `break`,
            # so it never finishes; `epoch` never advances past 0 and the
            # script runs until it is interrupted or raises.
            for small, large in data(jax.random.key(4389)):
                # Initial upscale: resize the small image back to the large
                # spatial size (channel count hard-coded as 3) so the model
                # input matches the target shape.
                new_shape = (W_LARGE, H_LARGE, 3)
                upscaled = jax.image.resize(small, new_shape, "nearest")

                batch_small.append(upscaled)
                batch_large.append(large)

                # Once a full batch is collected, stack it, reset the
                # accumulators and run one training step.
                if len(batch_small) >= TRAIN_BATCH:
                    X = jnp.stack(batch_small)
                    Y = jnp.stack(batch_large)
                    batch_small = list()
                    batch_large = list()

                    # Disable this to disable sharding
                    # X = jax.device_put(X, data_sharding)
                    # Y = jax.device_put(Y, data_sharding)

                    print(f"X devices: {X.devices()}")
                    print(f"Y devices: {Y.devices()}")

                    print(f"X shape: {X.shape}")
                    print(f"Y shape: {Y.shape}")

                    print(f"X dtype: {X.dtype}")
                    print(f"Y dtype: {Y.dtype}")

                    # Prints the loss returned by the training step.
                    print(train_step(m, opt, X, Y))

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
device info: NVIDIA RTX 6000 Ada Generation-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='[redacted]', release='6.8.0-48-generic', version='#48~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Mon Oct  7 11:24:13 UTC 2', machine='x86_64')

$ nvidia-smi
Mon Nov 11 17:11:54 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX 6000 Ada Gene...    On  |   00000000:01:00.0 Off |                  Off |
| 30%   60C    P0             75W /  300W |     450MiB /  49140MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX 6000 Ada Gene...    On  |   00000000:02:00.0 Off |                  Off |
| 30%   44C    P0             42W /  300W |     450MiB /  49140MiB |      3%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA RTX 6000 Ada Gene...    On  |   00000000:C1:00.0 Off |                  Off |
| 30%   60C    P0             42W /  300W |     450MiB /  49140MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA RTX 6000 Ada Gene...    On  |   00000000:E1:00.0  On |                  Off |
| 30%   52C    P0             57W /  300W |    2570MiB /  49140MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      3206      G   /usr/lib/xorg/Xorg                              4MiB |
|    0   N/A  N/A   1432463      C   ../venv/bin/python                            426MiB |
|    1   N/A  N/A      3206      G   /usr/lib/xorg/Xorg                              4MiB |
|    1   N/A  N/A   1432463      C   ../venv/bin/python                            426MiB |
|    2   N/A  N/A      3206      G   /usr/lib/xorg/Xorg                              4MiB |
|    2   N/A  N/A   1432463      C   ../venv/bin/python                            426MiB |
|    3   N/A  N/A      3206      G   /usr/lib/xorg/Xorg                            770MiB |
|    3   N/A  N/A      3614      G   cinnamon                                      411MiB |
|    3   N/A  N/A      4764      G   alacritty                                      22MiB |
|    3   N/A  N/A     19845      G   /usr/lib/thunderbird/thunderbird              274MiB |
|    3   N/A  N/A   1432463      C   ../venv/bin/python                            426MiB |
|    3   N/A  N/A   2642116      G   /usr/lib/firefox/firefox                      441MiB |
+-----------------------------------------------------------------------------------------+
joshhansen commented 1 week ago

Apologies, didn't realize flax is in a separate repo. I reposted this there.