google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Problem in usage of multi-gpu (Four RTX 4090) for FLAX #3931

psj1866 opened this issue 1 month ago (status: Open)

psj1866 commented 1 month ago

Dear FLAX community,

System information

[Screenshot: system information (not reproduced)]

Problem you have encountered:

As shown in the image above, my server is equipped with 4 RTX 4090 GPUs. I tried to run batch training on multiple GPUs, but it did not work; the error message is below. To me, the problem seems to come from the NVIDIA GPU setup rather than from Python.
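Since the screenshot is not reproduced here, a quick way to confirm from Python which devices JAX actually sees is to list them. This is a minimal sketch; on the server described above it should list four GPU devices, assuming a CUDA-enabled jaxlib is installed.

```python
import jax

# List every device JAX can see, with its platform and hardware kind.
for d in jax.devices():
    print(d.id, d.platform, d.device_kind)

print("backend:", jax.default_backend())
print("device_count:", jax.device_count())
print("local_device_count:", jax.local_device_count())
```

If the backend reports `cpu` or fewer than four devices, the environment needs fixing before any pmap code can use all GPUs.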

What you expected to happen:

I want to use multiple GPUs for batch training with Flax on my server. How can I fix my code or rebuild the environment? (I am quite new to Linux...)

Logs, error messages, etc:

The error message is as follows:

cgroup-gpux4:1539012:1539157 [3] NCCL INFO Bootstrap : Using enp36s0f1:192.168.1.45<0>
cgroup-gpux4:1539012:1539157 [3] NCCL INFO NET/Plugin : dlerror=libnccl-net.so: cannot open shared object file: No such file or directory No plugin found (libnccl-net.so), using internal implementation
cgroup-gpux4:1539012:1539157 [0] NCCL INFO cudaDriverVersion 12030
NCCL version 2.20.5+cuda12.4
cgroup-gpux4:1539012:1539157 [0] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1539012:1539157 [1] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1539012:1539157 [2] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1539012:1539157 [3] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1539012:1539300 [0] NCCL INFO NET/IB : No device found.
cgroup-gpux4:1539012:1539300 [0] NCCL INFO NET/Socket : Using [0]enp36s0f1:192.168.1.45<0>
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Using network Socket
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Using network Socket
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Using network Socket
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Using network Socket
cgroup-gpux4:1539012:1539302 [2] NCCL INFO comm 0x7f3b64264db0 rank 2 nranks 4 cudaDev 2 nvmlDev 2 busId 41000 commId 0x80ed69e629feb92e - Init START
cgroup-gpux4:1539012:1539301 [1] NCCL INFO comm 0x7f3b6425b6d0 rank 1 nranks 4 cudaDev 1 nvmlDev 1 busId 2a000 commId 0x80ed69e629feb92e - Init START
cgroup-gpux4:1539012:1539300 [0] NCCL INFO comm 0x7f3b64253c50 rank 0 nranks 4 cudaDev 0 nvmlDev 0 busId 1000 commId 0x80ed69e629feb92e - Init START
cgroup-gpux4:1539012:1539303 [3] NCCL INFO comm 0x7f3b6426e360 rank 3 nranks 4 cudaDev 3 nvmlDev 3 busId 61000 commId 0x80ed69e629feb92e - Init START
cgroup-gpux4:1539012:1539302 [2] NCCL INFO NVLS multicast support is not available on dev 2
cgroup-gpux4:1539012:1539303 [3] NCCL INFO NVLS multicast support is not available on dev 3
cgroup-gpux4:1539012:1539300 [0] NCCL INFO NVLS multicast support is not available on dev 0
cgroup-gpux4:1539012:1539301 [1] NCCL INFO NVLS multicast support is not available on dev 1
cgroup-gpux4:1539012:1539300 [0] NCCL INFO comm 0x7f3b64253c50 rank 0 nRanks 4 nNodes 1 localRanks 4 localRank 0 MNNVL 0
cgroup-gpux4:1539012:1539303 [3] NCCL INFO comm 0x7f3b6426e360 rank 3 nRanks 4 nNodes 1 localRanks 4 localRank 3 MNNVL 0
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Channel 00/04 :    0   1   2   3
cgroup-gpux4:1539012:1539301 [1] NCCL INFO comm 0x7f3b6425b6d0 rank 1 nRanks 4 nNodes 1 localRanks 4 localRank 1 MNNVL 0
cgroup-gpux4:1539012:1539302 [2] NCCL INFO comm 0x7f3b64264db0 rank 2 nRanks 4 nNodes 1 localRanks 4 localRank 2 MNNVL 0
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Channel 01/04 :    0   1   2   3
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Trees [0] 2/-1/-1->1->0 [1] 2/-1/-1->1->0 [2] 2/-1/-1->1->0 [3] 2/-1/-1->1->0
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Channel 02/04 :    0   1   2   3
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Channel 03/04 :    0   1   2   3
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Trees [0] 3/-1/-1->2->1 [1] 3/-1/-1->2->1 [2] 3/-1/-1->2->1 [3] 3/-1/-1->2->1
cgroup-gpux4:1539012:1539301 [1] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1539012:1539302 [2] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1 [2] 1/-1/-1->0->-1 [3] 1/-1/-1->0->-1
cgroup-gpux4:1539012:1539300 [0] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Trees [0] -1/-1/-1->3->2 [1] -1/-1/-1->3->2 [2] -1/-1/-1->3->2 [3] -1/-1/-1->3->2
cgroup-gpux4:1539012:1539303 [3] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Channel 00/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Channel 00/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Channel 00/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Channel 01/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Channel 01/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Channel 01/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Channel 02/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Channel 02/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Channel 02/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Channel 03/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Channel 03/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Channel 03/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Channel 02/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Channel 03/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Connected all rings
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Connected all rings
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Connected all rings
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Channel 00/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Connected all rings
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Channel 01/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Channel 02/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Channel 03/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Channel 00/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Channel 01/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Channel 02/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Channel 02/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Channel 03/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Channel 03/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1539012:1539303 [3] NCCL INFO Connected all trees
cgroup-gpux4:1539012:1539303 [3] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1539012:1539303 [3] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1539012:1539300 [0] NCCL INFO Connected all trees
cgroup-gpux4:1539012:1539302 [2] NCCL INFO Connected all trees
cgroup-gpux4:1539012:1539302 [2] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1539012:1539302 [2] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1539012:1539300 [0] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1539012:1539300 [0] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1539012:1539301 [1] NCCL INFO Connected all trees
cgroup-gpux4:1539012:1539301 [1] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1539012:1539301 [1] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1539012:1539303 [3] NCCL INFO comm 0x7f3b6426e360 rank 3 nranks 4 cudaDev 3 nvmlDev 3 busId 61000 commId 0x80ed69e629feb92e - Init COMPLETE
cgroup-gpux4:1539012:1539301 [1] NCCL INFO comm 0x7f3b6425b6d0 rank 1 nranks 4 cudaDev 1 nvmlDev 1 busId 2a000 commId 0x80ed69e629feb92e - Init COMPLETE
cgroup-gpux4:1539012:1539302 [2] NCCL INFO comm 0x7f3b64264db0 rank 2 nranks 4 cudaDev 2 nvmlDev 2 busId 41000 commId 0x80ed69e629feb92e - Init COMPLETE
cgroup-gpux4:1539012:1539300 [0] NCCL INFO comm 0x7f3b64253c50 rank 0 nranks 4 cudaDev 0 nvmlDev 0 busId 1000 commId 0x80ed69e629feb92e - Init COMPLETE

Steps to reproduce:

I followed this benchmark code (https://colab.research.google.com/drive/1hXns2b6T8T393zSrKCSoUktye1YlSe8U?usp=sharing#scrollTo=oKcRiQ89xQkF) and fixed several issues. The code I run on my server is as follows:

from functools import partial
import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state
from flax.struct import PyTreeNode

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import tensorflow_datasets as tfds     # TFDS for MNIST

DEVICE_COUNT = jax.device_count()

import os
os.environ['NCCL_DEBUG'] = 'INFO'

import pickle

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x, training):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.Dropout(0.2, deterministic=not training)(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.Dropout(0.2, deterministic=not training)(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x

def cross_entropy_loss(*, logits, labels):
  labels_onehot = jax.nn.one_hot(labels, num_classes=10)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

class Metrics(PyTreeNode):
  count: int
  acc_total: float
  loss_total: float

  @classmethod
  def new(cls):
    return cls(
        count=jnp.array(0, jnp.int32),
        acc_total=jnp.array(0.0, jnp.float32),
        loss_total=jnp.array(0.0, jnp.float32),
    )

  def update(self, *, logits, labels):
    return self.replace(
        count=self.count + labels.size,
        acc_total=self.acc_total + jnp.sum(jnp.argmax(logits, -1) == labels),
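        # cross_entropy_loss returns a batch mean, so scale it by the batch size;
        # compute() then divides the summed loss by the example count.
        loss_total=self.loss_total + cross_entropy_loss(logits=logits, labels=labels) * labels.size,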
    )

  def compute(self):
    return {
      'loss': self.loss_total / self.count,
      'accuracy': self.acc_total / self.count,
    }

def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds

@jax.pmap
def init_step(key):
  x = jnp.ones([1, 28, 28, 1])
  variables = module.init(key, x, training=False)
  opt_state = optimizer.init(variables['params'])
  metrics = Metrics.new()
  return variables, opt_state, metrics

@partial(jax.pmap, axis_name="device", out_axes=(None, 0, 0, 0))
def train_step(key, batch, variables, opt_state, metrics):
  params = variables['params']

  def loss_fn(params, variables):
    variables['params'] = params
    logits, updates = module.apply(variables, batch['image'], training=True, 
                                   mutable='batch_stats', rngs={'dropout': key})
    variables['updates'] = updates
    loss = cross_entropy_loss(logits=logits, labels=batch['label'])
    return loss, (logits, variables)

  # compute predictions and gradients
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (logits, variables)), grads = grad_fn(params, variables)

  # sync gradients
  grads = jax.lax.pmean(grads, "device")

  # update params
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  variables['params'] = params

  # sync batch stats
  batch_stats = jax.lax.pmean(variables['batch_stats'], "device")
  variables['batch_stats'] = batch_stats

  # compute metrics
  metrics = metrics.update(logits=logits, labels=batch['label'])
  logs = jax.lax.psum(metrics, "device").compute() # <== sync metrics

  return logs, variables, opt_state, metrics

@partial(jax.pmap, axis_name="device", out_axes=(None, 0))
def eval_step(batch, variables, metrics):
  logits = module.apply(variables, batch['image'], training=False)
  metrics = metrics.update(logits=logits, labels=batch['label'])
  logs = jax.lax.psum(metrics, "device").compute() # <== sync metrics
  return logs, metrics

def train_epoch(key, variables, opt_state, metrics, train_ds, batch_size, epoch):
  batch_size *= DEVICE_COUNT

  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(key, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    step_key, key = jax.random.split(key)

    # split step_key and reshape data
    step_key = jax.random.split(step_key, DEVICE_COUNT)
    batch = jax.tree_util.tree_map(lambda x: x.reshape((DEVICE_COUNT, -1, *x.shape[1:])), batch)

    logs, variables, opt_state, metrics = train_step(step_key, batch, variables, opt_state, metrics)
    batch_metrics.append(logs)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  print(f'train epoch: {epoch}, loss: {epoch_metrics_np["loss"]:.8f}, accuracy: {epoch_metrics_np["accuracy"] * 100:.2f}')

  return variables, opt_state

def eval_model(test_ds, variables, metrics):
  logs, metrics = eval_step(test_ds, variables, metrics)
  logs = jax.device_get(logs)
  summary = jax.tree_util.tree_map(lambda x: x.item(), logs)
  return summary['loss'], summary['accuracy']
learning_rate = 0.01
momentum = 0.9

module = CNN()
optimizer = optax.sgd(learning_rate, momentum)
'''
train_ds, test_ds = get_datasets()
N = 16
train_ds = {'image': train_ds['image'][:N], 'label' : train_ds['label'][:N]}
test_ds = {'image': test_ds['image'][:N], 'label' : test_ds['label'][:N]}
'''
with open('train.pickle', 'rb') as f:
    train_ds = pickle.load(f)

with open('test.pickle', 'rb') as f:
    test_ds = pickle.load(f)
key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)

# replicate init_key (same initial weights on all devices)
init_key = jnp.tile(init_key[None], (DEVICE_COUNT, 1))

variables, opt_state, metrics0 = init_step(init_key)
jax.tree_util.tree_map(lambda x: x.shape, variables)
num_epochs = 10
batch_size = 4

for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  key, epoch_key = jax.random.split(key)
  # Run an optimization step over a training batch
  variables, opt_state = train_epoch(epoch_key, variables, opt_state, metrics0, train_ds, batch_size, epoch)
  # Evaluate on the test set after each training epoch
  batch = jax.tree_util.tree_map(lambda x: x.reshape((DEVICE_COUNT, -1, *x.shape[1:])), test_ds)
  test_loss, test_accuracy = eval_model(batch, variables, metrics0)
  print(f' test epoch: {epoch}, loss: {test_loss:.8f}, accuracy: {test_accuracy * 100:.2f}')

(The messages shown above appear while this training loop runs.)

It seems that many people are struggling with multi-GPU setups on RTX 4090s.

Thanks for reading!

IvyZX commented 1 month ago

The error message you posted all looks like normal INFO printouts. Is there a more specific error message or stack trace, or did the program just crash after these printouts?

In general, hardware issues such as multi-GPU problems are more likely rooted in JAX, as Flax rarely touches lower-level APIs directly. I'd also recommend trying some smaller, pure-JAX code (like the examples in this multi-device guide or other JAX sample code) to pin the error down to more specific lines.
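As a starting point for such a pure-JAX check, here is a minimal sketch with no Flax involved: it runs a trivial elementwise function on every local device with `jax.pmap`, with no cross-device communication, and waits for the result. The values are illustrative only.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# One row per device: pmap maps over the leading axis, so row i runs on device i.
x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

# A trivial per-device computation with no cross-device communication.
y = jax.pmap(lambda row: row * 2.0 + 1.0)(x)

# Wait for the devices to finish; a hang here would point below Flax.
y = y.block_until_ready()
print(y)
```

If this runs but a version with collectives hangs, the problem is more likely in device-to-device communication than in per-device compute.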

psj1866 commented 1 month ago

Thanks for your reply! I tried a simple piece of code using pmap, but I am not sure it is the right one for comparison. The code comes from this discussion: https://github.com/google/flax/discussions/2121

import jax
import jax.numpy as jnp

from flax import linen as nn
from flax import jax_utils
import optax
from flax.training.train_state import TrainState

model = nn.Dense(1)
x = jnp.ones((jax.device_count(), 3))
params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(
    apply_fn=model.apply, params=params, tx=tx,
)
state = jax_utils.replicate(state)

def loss_fn(state, x):
    return (model.apply(state.params, x) ** 2.0).mean()

jax.pmap(loss_fn)(state, x)

Actually, the pmap worked without any error in this code.

As for your question ("is there a more specific error message or stack trace, or did the program just crash after these printouts?"): the program just stops after this printout and makes no further progress (I waited more than 6 hours and nothing happened).

I found that the RTX 4090 series does not have NVLink. Could this be the reason for the error?

Thanks!

IvyZX commented 1 month ago

From your description, it sounds like the program is blocked rather than failing and exiting immediately? If it is blocked, the GPU devices (or their CPU hosts?) might be out of sync. Maybe try running the jax.lax.p* collectives (such as psum and pmean) in your pmapped pure-JAX function?
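If the program is blocked, running each collective in isolation can show whether cross-device communication itself stalls. The sketch below exercises `psum` and `pmean` (the collectives `train_step` uses) plus `all_gather` over an axis named `"device"`; the function name `collectives` is illustrative, and the expected values are noted in the comments.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

def collectives(v):
    # v is this device's scalar; each collective communicates across "device".
    total = jax.lax.psum(v, "device")          # expected: n * (n - 1) / 2
    mean = jax.lax.pmean(v, "device")          # expected: (n - 1) / 2
    gathered = jax.lax.all_gather(v, "device")  # expected: [0, 1, ..., n - 1]
    return total, mean, gathered

values = jnp.arange(n, dtype=jnp.float32)  # device i holds the value i
total, mean, gathered = jax.pmap(collectives, axis_name="device")(values)
total.block_until_ready()  # blocks until every device has finished
print(total, mean, gathered)
```

Every device should return the same `total` and `mean`, and each row of `gathered` should equal `values`.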

Another thing worth doing is adding a ton of prints in your code to bisect which line it is blocked at.
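A caveat when bisecting with prints: inside a function transformed by `jax.pmap` or `jax.jit`, a plain Python `print` runs only while JAX traces the function, not while the GPUs execute it, so a print that shows up only proves that tracing reached that line. `jax.debug.print` is staged into the compiled computation and fires on every execution. A small sketch contrasting the two (the function `step` is purely illustrative):

```python
import jax
import jax.numpy as jnp

@jax.pmap
def step(x):
    # Python print: runs only while JAX traces the function.
    print("tracing step")
    y = x * 2.0
    # jax.debug.print: part of the compiled computation, runs on every call.
    jax.debug.print("executing step: y={y}", y=y)
    return y

x = jnp.ones((jax.local_device_count(), 3))
out = step(x).block_until_ready()  # prints "tracing step" and the debug lines
out = step(x).block_until_ready()  # cached: only the debug lines are printed
```

If the Python prints all appear but the `jax.debug.print` lines stop, the stall is in compilation or device execution rather than in Python.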

Also, just FYI: jax.pmap is outdated, and JAX generally recommends jax.shard_map for per-device code. Or, if you just want to try MNIST with Flax, there is a version without pmap in the quickstart.
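For reference, here is a hedged sketch of a cross-device reduction written with `shard_map` instead of `pmap`. The mesh axis name `"device"` and the function `global_sum` are illustrative; the import is guarded because `shard_map` lives under `jax.experimental.shard_map` in older releases and at the top level of `jax` in newer ones.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

# A 1-D mesh over all devices, with the axis named "device" as in the pmap code.
mesh = Mesh(np.array(jax.devices()), ("device",))

def global_sum(block):
    # block is this device's slice of x; psum combines the partial sums.
    return jax.lax.psum(jnp.sum(block), "device")

sharded_sum = jax.jit(
    shard_map(global_sum, mesh=mesh, in_specs=P("device"), out_specs=P())
)

x = jnp.arange(4 * jax.device_count(), dtype=jnp.float32)  # divisible by #devices
total = sharded_sum(x)
print(total)
```

Unlike `pmap`, the input is a single global array; `in_specs=P("device")` splits it along the mesh axis and `out_specs=P()` marks the summed result as replicated.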

psj1866 commented 1 month ago

Thanks for your reply! First, I apologize that the code I uploaded was too long and did not point to where the error occurs. By adding print calls inside the functions, I found that the program stops when train_step returns its output.

@partial(jax.pmap, axis_name="device", out_axes=(None, 0, 0, 0))
def train_step(key, batch, variables, opt_state, metrics):
  params = variables['params']

  def loss_fn(params, variables):
    variables['params'] = params
    logits, updates = module.apply(variables, batch['image'], training=True, 
                                   mutable='batch_stats', rngs={'dropout': key})
    variables['updates'] = updates
    loss = cross_entropy_loss(logits=logits, labels=batch['label'])
    return loss, (logits, variables)

  # compute predictions and gradients
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (logits, variables)), grads = grad_fn(params, variables)
  print(0)
  # sync gradients
  grads = jax.lax.pmean(grads, "device")
  print(1)

  # update params
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  variables['params'] = params
  print(2)

  # sync batch stats
  batch_stats = jax.lax.pmean(variables['batch_stats'], "device")
  variables['batch_stats'] = batch_stats
  print(3)

  # compute metrics
  metrics = metrics.update(logits=logits, labels=batch['label'])
  logs = jax.lax.psum(metrics, "device").compute() # <== sync metrics
  print(4)

  return logs, variables, opt_state, metrics

num_epochs = 10
batch_size = 2

for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  key, epoch_key = jax.random.split(key)
  # Run an optimization step over a training batch
  print('a')
  variables, opt_state = train_epoch(epoch_key, variables, opt_state, metrics0, train_ds, batch_size, epoch)
  # Evaluate on the test set after each training epoch
  print('b')
  batch = jax.tree_util.tree_map(lambda x: x.reshape((DEVICE_COUNT, -1, *x.shape[1:])), test_ds)
  test_loss, test_accuracy = eval_model(batch, variables, metrics0)
  print(f' test epoch: {epoch}, loss: {test_loss:.8f}, accuracy: {test_accuracy * 100:.2f}')

The output is as follows:

a
/tmp/ipykernel_1856689/328536960.py:131: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  batch = jax.tree_map(lambda x: x.reshape((DEVICE_COUNT, -1, *x.shape[1:])), batch)
0
1
2
3
4
cgroup-gpux4:1856689:1856816 [1] NCCL INFO Bootstrap : Using enp36s0f1:192.168.1.45<0>
cgroup-gpux4:1856689:1856816 [1] NCCL INFO NET/Plugin : dlerror=libnccl-net.so: cannot open shared object file: No such file or directory No plugin found (libnccl-net.so), using internal implementation
cgroup-gpux4:1856689:1856816 [0] NCCL INFO cudaDriverVersion 12030
NCCL version 2.20.5+cuda12.4
cgroup-gpux4:1856689:1856816 [0] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1856689:1856816 [1] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1856689:1856816 [2] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1856689:1856816 [3] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1856689:1856962 [3] NCCL INFO NET/IB : No device found.
cgroup-gpux4:1856689:1856962 [3] NCCL INFO NET/Socket : Using [0]enp36s0f1:192.168.1.45<0>
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Using network Socket
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Using network Socket
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Using network Socket
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Using network Socket
cgroup-gpux4:1856689:1856962 [3] NCCL INFO comm 0x7f72c8bd62d0 rank 3 nranks 4 cudaDev 3 nvmlDev 3 busId 61000 commId 0x438f56a3cde03111 - Init START
cgroup-gpux4:1856689:1856959 [0] NCCL INFO comm 0x7f72c8bbbbc0 rank 0 nranks 4 cudaDev 0 nvmlDev 0 busId 1000 commId 0x438f56a3cde03111 - Init START
cgroup-gpux4:1856689:1856960 [1] NCCL INFO comm 0x7f72c8bc3640 rank 1 nranks 4 cudaDev 1 nvmlDev 1 busId 2a000 commId 0x438f56a3cde03111 - Init START
cgroup-gpux4:1856689:1856961 [2] NCCL INFO comm 0x7f72c8bccd20 rank 2 nranks 4 cudaDev 2 nvmlDev 2 busId 41000 commId 0x438f56a3cde03111 - Init START
cgroup-gpux4:1856689:1856961 [2] NCCL INFO NVLS multicast support is not available on dev 2
cgroup-gpux4:1856689:1856959 [0] NCCL INFO NVLS multicast support is not available on dev 0
cgroup-gpux4:1856689:1856960 [1] NCCL INFO NVLS multicast support is not available on dev 1
cgroup-gpux4:1856689:1856962 [3] NCCL INFO NVLS multicast support is not available on dev 3
cgroup-gpux4:1856689:1856961 [2] NCCL INFO comm 0x7f72c8bccd20 rank 2 nRanks 4 nNodes 1 localRanks 4 localRank 2 MNNVL 0
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Trees [0] 3/-1/-1->2->1 [1] 3/-1/-1->2->1 [2] 3/-1/-1->2->1 [3] 3/-1/-1->2->1
cgroup-gpux4:1856689:1856960 [1] NCCL INFO comm 0x7f72c8bc3640 rank 1 nRanks 4 nNodes 1 localRanks 4 localRank 1 MNNVL 0
cgroup-gpux4:1856689:1856961 [2] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1856689:1856962 [3] NCCL INFO comm 0x7f72c8bd62d0 rank 3 nRanks 4 nNodes 1 localRanks 4 localRank 3 MNNVL 0
cgroup-gpux4:1856689:1856959 [0] NCCL INFO comm 0x7f72c8bbbbc0 rank 0 nRanks 4 nNodes 1 localRanks 4 localRank 0 MNNVL 0
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Trees [0] 2/-1/-1->1->0 [1] 2/-1/-1->1->0 [2] 2/-1/-1->1->0 [3] 2/-1/-1->1->0
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Trees [0] -1/-1/-1->3->2 [1] -1/-1/-1->3->2 [2] -1/-1/-1->3->2 [3] -1/-1/-1->3->2
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 00/04 :    0   1   2   3
cgroup-gpux4:1856689:1856960 [1] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 01/04 :    0   1   2   3
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 02/04 :    0   1   2   3
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 03/04 :    0   1   2   3
cgroup-gpux4:1856689:1856962 [3] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1 [2] 1/-1/-1->0->-1 [3] 1/-1/-1->0->-1
cgroup-gpux4:1856689:1856959 [0] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 00/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 00/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 01/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 01/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 02/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 02/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 00/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 03/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 03/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 02/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 01/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 03/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 02/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 03/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Connected all rings
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Connected all rings
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Connected all rings
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Connected all rings
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 00/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 01/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 02/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 03/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 00/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 01/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 02/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 02/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 03/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 03/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Connected all trees
cgroup-gpux4:1856689:1856959 [0] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1856689:1856959 [0] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Connected all trees
cgroup-gpux4:1856689:1856960 [1] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1856689:1856960 [1] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Connected all trees
cgroup-gpux4:1856689:1856962 [3] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1856689:1856962 [3] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Connected all trees
cgroup-gpux4:1856689:1856961 [2] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1856689:1856961 [2] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO comm 0x7f72c8bccd20 rank 2 nranks 4 cudaDev 2 nvmlDev 2 busId 41000 commId 0x438f56a3cde03111 - Init COMPLETE
cgroup-gpux4:1856689:1856959 [0] NCCL INFO comm 0x7f72c8bbbbc0 rank 0 nranks 4 cudaDev 0 nvmlDev 0 busId 1000 commId 0x438f56a3cde03111 - Init COMPLETE
cgroup-gpux4:1856689:1856960 [1] NCCL INFO comm 0x7f72c8bc3640 rank 1 nranks 4 cudaDev 1 nvmlDev 1 busId 2a000 commId 0x438f56a3cde03111 - Init COMPLETE
cgroup-gpux4:1856689:1856962 [3] NCCL INFO comm 0x7f72c8bd62d0 rank 3 nranks 4 cudaDev 3 nvmlDev 3 busId 61000 commId 0x438f56a3cde03111 - Init COMPLETE
/tmp/ipykernel_1856689/328536960.py:131: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  batch = jax.tree_map(lambda x: x.reshape((DEVICE_COUNT, -1, *x.shape[1:])), batch)
0
1
2
3
4

... and no further progress.
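The second `0 1 2 3 4` block in this log suggests that `train_step` was traced (and compiled) a second time. JAX re-traces a transformed function whenever the pytree structure of its inputs changes, and `loss_fn` stores the mutated collections with `variables['updates'] = updates`, so the `variables` returned by the first step carries an extra `'updates'` key that the initial `variables` did not have (the new running statistics also never reach `variables['batch_stats']`; `variables['batch_stats'] = updates['batch_stats']` would be one way to address both). The sketch below reproduces the re-tracing effect with `jax.jit` on a toy, hypothetical `step` function; `pmap` keys its compilation cache on input structure in the same way. This accounts for the repeated prints, not necessarily for the hang.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def step(variables):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces
    new_variables = dict(variables)
    # Mimics `variables['updates'] = updates`: the returned dict has one
    # more key than the dict passed in on the first call.
    new_variables["updates"] = variables["params"] * 2.0
    return new_variables

v = {"params": jnp.ones(3)}
v = step(v)  # traced: input keys {"params"}
v = step(v)  # traced again: input keys are now {"params", "updates"}
v = step(v)  # cache hit: same keys, shapes and dtypes as the previous call
print("traces:", trace_count)  # traces: 2
```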

Thanks!

IvyZX commented 1 month ago

Do you have any printouts in the train_epoch function to pinpoint the line where it blocks? We would really benefit from a smaller piece of code that reproduces the problem, so we can narrow down our search. If it seems to come from train_step, maybe try calling it directly with some fake input?
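Following that suggestion, a minimal harness can call a pmapped step on fake, MNIST-shaped inputs and time each call with `block_until_ready()`, so that a hang is tied to one specific call. `toy_step` below is a hypothetical stand-in (a linear softmax classifier) that uses the same `pmean` over `"device"` as `train_step`; the real `train_step` could be swapped in once this runs.

```python
import time
import jax
import jax.numpy as jnp

n = jax.local_device_count()
per_device = 2  # matches batch_size = 2 in the loop above

# Fake inputs with MNIST-like shapes; the values do not matter for a hang test.
images = jnp.zeros((n, per_device, 28, 28, 1), jnp.float32)
labels = jnp.zeros((n, per_device), jnp.int32)
w = jnp.zeros((n, 28 * 28, 10), jnp.float32)  # one copy of the weights per device

def toy_step(w, images, labels):
    def loss_fn(w):
        logits = images.reshape((images.shape[0], -1)) @ w
        onehot = jax.nn.one_hot(labels, 10)
        return -jnp.mean(jnp.sum(onehot * jax.nn.log_softmax(logits), axis=-1))
    loss, grads = jax.value_and_grad(loss_fn)(w)
    grads = jax.lax.pmean(grads, "device")  # same collective as train_step
    return w - 0.1 * grads, jax.lax.pmean(loss, "device")

step = jax.pmap(toy_step, axis_name="device")

for i in range(3):
    start = time.perf_counter()
    w, loss = step(w, images, labels)
    loss.block_until_ready()  # wait for the devices, not just the dispatch
    print(f"step {i}: loss={float(loss[0]):.4f} in {time.perf_counter() - start:.2f}s")
```

The first call includes compilation time; if a later call never prints, the blockage is in that step's execution.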

psj1866 commented 1 month ago

Thanks for your comment! Following your advice, I would like to start with the most fundamental step: replicating the train state.

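import jax
import numpy as np
import optax
from flax import jax_utils
from flax import linen as nn
from flax.training.train_state import TrainState

class MLP(nn.Module):

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=2)(x)
    return x

x = np.ones((jax.device_count(), 3))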
y = np.zeros((jax.device_count(), 3))
print(x, y)

model = MLP()
params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
print(state)
print('###########################################################################')
state_rep = jax_utils.replicate(state)
print(state_rep)

def loss_fn(state, x, y):
    print((model.apply(state.params, x)))
    return (model.apply(state.params, x))

jax.pmap(loss_fn)(state_rep, x, y)

Output:

[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]] [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
TrainState(step=0, apply_fn=<bound method Module.apply of MLP()>, params={'params': {'Dense_0': {'kernel': Array([[-0.3194899 ,  0.9700081 ],
       [-1.1898965 , -0.02842531],
       [ 0.05931681,  0.38353   ]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f4190147130>, update=<function chain.<locals>.update_fn at 0x7f4190198670>), opt_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu={'params': {'Dense_0': {'bias': Array([0., 0.], dtype=float32), 'kernel': Array([[0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)}}}, nu={'params': {'Dense_0': {'bias': Array([0., 0.], dtype=float32), 'kernel': Array([[0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)}}}), EmptyState()))
###########################################################################
TrainState(step=Array([         0, 1065353216, 1065353216, 1065353216],      dtype=int32, weak_type=True), apply_fn=<bound method Module.apply of MLP()>, params={'params': {'Dense_0': {'bias': Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), 'kernel': Array([[[-0.3194899 ,  0.9700081 ],
        [-1.1898965 , -0.02842531],
        [ 0.05931681,  0.38353   ]],

       [[ 0.        ,  0.        ],
        [ 0.        ,  0.        ],
        [ 0.        ,  0.        ]],

       [[ 0.        ,  0.        ],
        [ 0.        ,  0.        ],
        [ 0.        ,  0.        ]],

       [[ 0.        ,  0.        ],
        [ 0.        ,  0.        ],
        [ 0.        ,  0.        ]]], dtype=float32)}}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f4190147130>, update=<function chain.<locals>.update_fn at 0x7f4190198670>), opt_state=(ScaleByAdamState(count=Array([         0, 1065353216, 1065353216, 1065353216], dtype=int32), mu={'params': {'Dense_0': {'bias': Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), 'kernel': Array([[[0., 0.],
        [0., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 0.],
        [0., 0.]]], dtype=float32)}}}, nu={'params': {'Dense_0': {'bias': Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), 'kernel': Array([[[0., 0.],
        [0., 0.],
        [0., 0.]],

       [[1., 1.],
        [1., 0.],
        [0., 0.]],

       [[1., 1.],
        [1., 0.],
        [0., 0.]],

       [[1., 1.],
        [1., 0.],
        [0., 0.]]], dtype=float32)}}}), EmptyState()))
Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=0/1)>
Array([[-1.4500695,  1.3251127],
       [ 0.       ,  0.       ],
       [ 0.       ,  0.       ],
       [ 0.       ,  0.       ]], dtype=float32)
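The value `1065353216` in the replicated `step` and Adam `count` fields is not random: it is the int32 bit pattern of the float32 value `1.0`. One plausible reading (an interpretation, not a confirmed diagnosis) is that the copies on devices 1 to 3 received bytes from some float buffer holding ones rather than a copy of `step = 0`; the `nu` arrays showing ones on those devices point the same way. The arithmetic can be checked with NumPy:

```python
import numpy as np

# Reinterpret the bytes of float32 1.0 as int32 (no numeric conversion).
one_bits = np.array(1.0, dtype=np.float32).view(np.int32)
print(int(one_bits))  # 1065353216

# And the reverse: the replicated "step" value decoded as float32.
print(float(np.array(1065353216, dtype=np.int32).view(np.float32)))  # 1.0
```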

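It seems the state is not replicated correctly: only device 0 holds the real parameters, while the copies on the other devices contain zeros, ones, or garbage values such as step=1065353216, and pmap returns zeros for devices 1 to 3. Could you also check whether there is a problem in pmap?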
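A direct way to answer the replication question is to bring every per-device copy back to the host and compare it with the original. The sketch below uses `jax.device_put_replicated` (the JAX function that `flax.jax_utils.replicate` wraps) on a small, hypothetical stand-in pytree; on a healthy setup it should report zero mismatching copies, whereas the output above would correspond to mismatches on devices 1 to 3.

```python
import numpy as np
import jax
import jax.numpy as jnp

# A small stand-in pytree for the TrainState (hypothetical values).
tree = {
    "step": jnp.array(0, dtype=jnp.int32),
    "kernel": jnp.arange(6, dtype=jnp.float32).reshape(3, 2),
}

# Put one copy of every leaf on each local device (stacked on a new axis 0).
replicated = jax.device_put_replicated(tree, jax.local_devices())

def count_bad_replicas(original, replicated):
    """Count per-device copies that differ from the original leaf."""
    bad = 0
    for orig_leaf, rep_leaf in zip(jax.tree_util.tree_leaves(original),
                                   jax.tree_util.tree_leaves(replicated)):
        expected = np.asarray(jax.device_get(orig_leaf))
        for device_index, copy in enumerate(np.asarray(jax.device_get(rep_leaf))):
            if not np.array_equal(copy, expected):
                print(f"device {device_index}: expected {expected}, got {copy}")
                bad += 1
    return bad

bad = count_bad_replicas(tree, replicated)
print("mismatching copies:", bad)
```

Running the same check on the real `state` and `state_rep` would show which leaves and devices are affected.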