psj1866 opened this issue 6 months ago
The error message you posted looks like ordinary NCCL INFO printouts. Is there a more specific error message or stack trace, or did the program simply crash after these printouts?
In general, hardware issues such as multi-GPU problems are more likely rooted in JAX, since Flax rarely touches lower-level APIs directly. I'd also recommend trying some smaller, pure-JAX code (like the examples in this multi-device guide or other sample code on the JAX website) to narrow the error down to specific lines.
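A minimal pure-JAX starting point, as a sketch that assumes only a working jax install (no Flax or optax). It lists the visible devices and runs a pmapped computation with no cross-device communication. If this completes but the training script still hangs, the problem more likely sits in the collectives than in basic per-device execution.

```python
import jax
import jax.numpy as jnp


def per_device_check():
    """Run a trivial pmapped computation that needs no cross-device communication."""
    n = jax.local_device_count()
    print("devices:", jax.devices())
    # One row per local device; each device squares only its own row.
    x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
    out = jax.pmap(lambda row: row ** 2)(x)
    # block_until_ready() makes the host wait for every device to finish,
    # so a hang shows up on this line rather than somewhere later.
    out = out.block_until_ready()
    return x, out


x, out = per_device_check()
print(out)
```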
Thanks for your reply!
Well, I tried simple code using pmap, but I'm not sure this is the right comparison. The code can be found at this link: https://github.com/google/flax/discussions/2121
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import jax_utils
import optax
from flax.training.train_state import TrainState
model = nn.Dense(1)
x = jnp.ones((jax.device_count(), 3))
params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(
    apply_fn=model.apply, params=params, tx=tx,
)
state = jax_utils.replicate(state)

def loss_fn(state, x):
    return (model.apply(state.params, x) ** 2.0).mean()

jax.pmap(loss_fn)(state, x)
Actually, the pmap call worked without any error in this code.
As for your question ("is there a more specific error message or stack trace, or did the program just crash after these printouts?"): the program just hung after this printout and made no further progress. I waited more than 6 hours, but nothing happened.
I found that the RTX 4090 series does not support NVLink. Could this be the reason for this error?
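If you want to test the NVLink/P2P idea directly, one option is NCCL's `NCCL_P2P_DISABLE` environment variable. This is my assumption, not something confirmed in this thread. The variable tells NCCL not to use direct GPU peer-to-peer transport. NCCL reads it when it initializes its communicators, so it has to be set before JAX runs the first cross-GPU collective. Setting it before importing jax is the simplest way. On a machine without GPUs, NCCL is not used, so the variable has no effect there.

```python
import os

# Assumption: with NCCL_P2P_DISABLE=1, NCCL avoids direct GPU peer-to-peer
# copies and falls back to another transport. It must be set before JAX
# initialises NCCL, which is why it comes before `import jax`.
os.environ.setdefault("NCCL_P2P_DISABLE", "1")

import jax  # noqa: E402  (deliberately imported after setting the env var)
import jax.numpy as jnp  # noqa: E402

n = jax.local_device_count()
# Device i contributes the value i; psum adds the values across all devices.
summed = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(
    jnp.arange(n, dtype=jnp.float32)
)
summed = summed.block_until_ready()
print(summed)  # every entry should equal n * (n - 1) / 2
```

From a shell, the equivalent is `NCCL_P2P_DISABLE=1 python your_script.py` (the script name is hypothetical). If the hang goes away with this setting, the peer-to-peer path between the cards becomes the prime suspect.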
Thanks!
From your description, it sounds like the program is blocked rather than failing and exiting immediately?
If it is blocked, the GPU devices (or their CPU hosts?) might be out of sync. Maybe try running those jax.lax.p* collectives (psum, pmean, ...) in your pmapped pure-JAX function?
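A sketch of that idea, assuming plain JAX only. It runs each collective on its own in a small pmapped function and waits for it with block_until_ready(), printing before and after. If one of them hangs, the last "running ..." line tells you which one. The axis name "device" and the selection of collectives are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

AXIS = "device"

# Each entry wraps a single jax.lax collective over the pmapped axis.
COLLECTIVES = {
    "psum": lambda v: jax.lax.psum(v, AXIS),
    "pmean": lambda v: jax.lax.pmean(v, AXIS),
    "pmax": lambda v: jax.lax.pmax(v, AXIS),
    "all_gather": lambda v: jax.lax.all_gather(v, AXIS),
}


def run_collectives():
    n = jax.local_device_count()
    x = jnp.arange(n, dtype=jnp.float32)  # device i holds the scalar i
    results = {}
    for name, fn in COLLECTIVES.items():
        print(f"running {name} ...", flush=True)
        out = jax.pmap(fn, axis_name=AXIS)(x)
        # If this collective hangs, execution stops here, right after its name was printed.
        results[name] = out.block_until_ready()
        print(f"{name} done", flush=True)
    return n, results


n, results = run_collectives()
```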
Another thing worth doing is adding lots of prints to your code to bisect which line it blocks at.
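One caveat when bisecting with prints, which is standard JAX behavior: a plain Python `print` inside a jit/pmap-compiled function runs only while JAX traces the function. It does not run each time the compiled code executes on the GPUs. `jax.debug.print` runs at execution time. Calling `.block_until_ready()` on a result makes the host wait for the devices, so the last message printed before a hang points at the step that actually blocks. A small sketch (the function and counter names are illustrative):

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how often the Python body of `step` is executed


def step(v):
    global trace_count
    trace_count += 1
    # Plain print: runs only when JAX (re)traces `step`, and shows a tracer, not data.
    print("python print (trace time):", v)
    # jax.debug.print: runs every time the compiled function executes.
    jax.debug.print("jax.debug.print (run time): {v}", v=v)
    return v * 2.0


pstep = jax.pmap(step)
x = jnp.ones((jax.local_device_count(), 3))

out = pstep(x).block_until_ready()
count_after_first = trace_count
for _ in range(2):
    # Same shapes and dtypes: the cached compilation is reused, so no retracing.
    out = pstep(x).block_until_ready()
```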
Also, just FYI: jax.pmap is outdated, and JAX generally recommends using jax.shard_map for per-device code. Or, if you just want to try MNIST on Flax, you can use the version without pmap in the quickstart.
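For comparison with pmap, here is a hedged sketch of similar per-device code written with shard_map. It splits a batch across a one-axis device mesh, computes a local sum of squares on each device, and combines the results with psum. The import location is the main assumption. Older JAX releases such as the 0.4.27 used in this issue provide `jax.experimental.shard_map.shard_map`, while newer releases also expose `jax.shard_map`, so the code tries both.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases (e.g. 0.4.x) keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

# A 1-D mesh over all devices, with the mesh axis named "device".
mesh = Mesh(np.array(jax.devices()), ("device",))


def per_device_sum_of_squares(x_block):
    # x_block is this device's slice of the rows (the batch is split along "device").
    local = jnp.sum(x_block ** 2, axis=0)
    return jax.lax.psum(local, "device")  # identical result on every device


global_sum = jax.jit(
    shard_map(
        per_device_sum_of_squares,
        mesh=mesh,
        in_specs=P("device"),  # shard dimension 0 across the mesh axis
        out_specs=P(),         # output is replicated, not sharded
    )
)

n = len(jax.devices())
x = jnp.ones((2 * n, 3))  # two rows per device
result = global_sum(x)
```

Unlike pmap, the function sees ordinary arrays and the sharding is stated once through the mesh and the partition specs, which is one reason shard_map is the suggested replacement.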
Thanks for your reply!
First, I apologize that the code I uploaded was too long and did not pinpoint where the error occurs. By adding print calls inside the functions, I found that the error occurs when train_step returns its output.
from functools import partial

import jax
import optax

@partial(jax.pmap, axis_name="device", out_axes=(None, 0, 0, 0))
def train_step(key, batch, variables, opt_state, metrics):
    params = variables['params']

    def loss_fn(params, variables):
        variables['params'] = params
        logits, updates = module.apply(variables, batch['image'], training=True,
                                       mutable='batch_stats', rngs={'dropout': key})
        # `updates` is keyed by collection name; keep the new BatchNorm statistics
        variables['batch_stats'] = updates['batch_stats']
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, (logits, variables)

    # compute predictions and gradients
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, (logits, variables)), grads = grad_fn(params, variables)
    print(0)
    # sync gradients
    grads = jax.lax.pmean(grads, "device")
    print(1)
    # update params
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    variables['params'] = params
    print(2)
    # sync batch stats
    batch_stats = jax.lax.pmean(variables['batch_stats'], "device")
    variables['batch_stats'] = batch_stats
    print(3)
    # compute metrics
    metrics = metrics.update(logits=logits, labels=batch['label'])
    logs = jax.lax.psum(metrics, "device").compute()  # <== sync metrics
    print(4)
    return logs, variables, opt_state, metrics
num_epochs = 10
batch_size = 2
for epoch in range(1, num_epochs + 1):
    # Use a separate PRNG key to permute image data during shuffling
    key, epoch_key = jax.random.split(key)
    # Train for one epoch
    print('a')
    variables, opt_state = train_epoch(epoch_key, variables, opt_state, metrics0, train_ds, batch_size, epoch)
    # Evaluate on the test set after each training epoch
    print('b')
    batch = jax.tree_util.tree_map(lambda x: x.reshape((DEVICE_COUNT, -1, *x.shape[1:])), test_ds)
    test_loss, test_accuracy = eval_model(batch, variables, metrics0)
    print(f' test epoch: {epoch}, loss: {test_loss:.8f}, accuracy: {test_accuracy * 100:.2f}')
The printed output is as follows:
a
/tmp/ipykernel_1856689/328536960.py:131: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
batch = jax.tree_map(lambda x: x.reshape((DEVICE_COUNT, -1, *x.shape[1:])), batch)
0
1
2
3
4
cgroup-gpux4:1856689:1856816 [1] NCCL INFO Bootstrap : Using enp36s0f1:192.168.1.45<0>
cgroup-gpux4:1856689:1856816 [1] NCCL INFO NET/Plugin : dlerror=libnccl-net.so: cannot open shared object file: No such file or directory No plugin found (libnccl-net.so), using internal implementation
cgroup-gpux4:1856689:1856816 [0] NCCL INFO cudaDriverVersion 12030
NCCL version 2.20.5+cuda12.4
cgroup-gpux4:1856689:1856816 [0] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1856689:1856816 [1] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1856689:1856816 [2] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1856689:1856816 [3] NCCL INFO Comm config Split share set to 1
cgroup-gpux4:1856689:1856962 [3] NCCL INFO NET/IB : No device found.
cgroup-gpux4:1856689:1856962 [3] NCCL INFO NET/Socket : Using [0]enp36s0f1:192.168.1.45<0>
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Using network Socket
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Using network Socket
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Using network Socket
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Using non-device net plugin version 0
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Using network Socket
cgroup-gpux4:1856689:1856962 [3] NCCL INFO comm 0x7f72c8bd62d0 rank 3 nranks 4 cudaDev 3 nvmlDev 3 busId 61000 commId 0x438f56a3cde03111 - Init START
cgroup-gpux4:1856689:1856959 [0] NCCL INFO comm 0x7f72c8bbbbc0 rank 0 nranks 4 cudaDev 0 nvmlDev 0 busId 1000 commId 0x438f56a3cde03111 - Init START
cgroup-gpux4:1856689:1856960 [1] NCCL INFO comm 0x7f72c8bc3640 rank 1 nranks 4 cudaDev 1 nvmlDev 1 busId 2a000 commId 0x438f56a3cde03111 - Init START
cgroup-gpux4:1856689:1856961 [2] NCCL INFO comm 0x7f72c8bccd20 rank 2 nranks 4 cudaDev 2 nvmlDev 2 busId 41000 commId 0x438f56a3cde03111 - Init START
cgroup-gpux4:1856689:1856961 [2] NCCL INFO NVLS multicast support is not available on dev 2
cgroup-gpux4:1856689:1856959 [0] NCCL INFO NVLS multicast support is not available on dev 0
cgroup-gpux4:1856689:1856960 [1] NCCL INFO NVLS multicast support is not available on dev 1
cgroup-gpux4:1856689:1856962 [3] NCCL INFO NVLS multicast support is not available on dev 3
cgroup-gpux4:1856689:1856961 [2] NCCL INFO comm 0x7f72c8bccd20 rank 2 nRanks 4 nNodes 1 localRanks 4 localRank 2 MNNVL 0
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Trees [0] 3/-1/-1->2->1 [1] 3/-1/-1->2->1 [2] 3/-1/-1->2->1 [3] 3/-1/-1->2->1
cgroup-gpux4:1856689:1856960 [1] NCCL INFO comm 0x7f72c8bc3640 rank 1 nRanks 4 nNodes 1 localRanks 4 localRank 1 MNNVL 0
cgroup-gpux4:1856689:1856961 [2] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1856689:1856962 [3] NCCL INFO comm 0x7f72c8bd62d0 rank 3 nRanks 4 nNodes 1 localRanks 4 localRank 3 MNNVL 0
cgroup-gpux4:1856689:1856959 [0] NCCL INFO comm 0x7f72c8bbbbc0 rank 0 nRanks 4 nNodes 1 localRanks 4 localRank 0 MNNVL 0
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Trees [0] 2/-1/-1->1->0 [1] 2/-1/-1->1->0 [2] 2/-1/-1->1->0 [3] 2/-1/-1->1->0
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Trees [0] -1/-1/-1->3->2 [1] -1/-1/-1->3->2 [2] -1/-1/-1->3->2 [3] -1/-1/-1->3->2
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 00/04 : 0 1 2 3
cgroup-gpux4:1856689:1856960 [1] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 01/04 : 0 1 2 3
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 02/04 : 0 1 2 3
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 03/04 : 0 1 2 3
cgroup-gpux4:1856689:1856962 [3] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1 [2] 1/-1/-1->0->-1 [3] 1/-1/-1->0->-1
cgroup-gpux4:1856689:1856959 [0] NCCL INFO P2P Chunksize set to 131072
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 00/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 00/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 01/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 01/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 02/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 02/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 00/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 03/0 : 2[2] -> 3[3] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 03/0 : 1[1] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 02/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 01/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Channel 03/0 : 0[0] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 02/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 03/0 : 3[3] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Connected all rings
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Connected all rings
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Connected all rings
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Connected all rings
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 00/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 01/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 02/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Channel 03/0 : 3[3] -> 2[2] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 00/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 01/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 02/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 02/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Channel 03/0 : 1[1] -> 0[0] via P2P/direct pointer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Channel 03/0 : 2[2] -> 1[1] via P2P/direct pointer
cgroup-gpux4:1856689:1856959 [0] NCCL INFO Connected all trees
cgroup-gpux4:1856689:1856959 [0] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1856689:1856959 [0] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1856689:1856960 [1] NCCL INFO Connected all trees
cgroup-gpux4:1856689:1856960 [1] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1856689:1856960 [1] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1856689:1856962 [3] NCCL INFO Connected all trees
cgroup-gpux4:1856689:1856962 [3] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1856689:1856962 [3] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO Connected all trees
cgroup-gpux4:1856689:1856961 [2] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
cgroup-gpux4:1856689:1856961 [2] NCCL INFO 4 coll channels, 0 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
cgroup-gpux4:1856689:1856961 [2] NCCL INFO comm 0x7f72c8bccd20 rank 2 nranks 4 cudaDev 2 nvmlDev 2 busId 41000 commId 0x438f56a3cde03111 - Init COMPLETE
cgroup-gpux4:1856689:1856959 [0] NCCL INFO comm 0x7f72c8bbbbc0 rank 0 nranks 4 cudaDev 0 nvmlDev 0 busId 1000 commId 0x438f56a3cde03111 - Init COMPLETE
cgroup-gpux4:1856689:1856960 [1] NCCL INFO comm 0x7f72c8bc3640 rank 1 nranks 4 cudaDev 1 nvmlDev 1 busId 2a000 commId 0x438f56a3cde03111 - Init COMPLETE
cgroup-gpux4:1856689:1856962 [3] NCCL INFO comm 0x7f72c8bd62d0 rank 3 nranks 4 cudaDev 3 nvmlDev 3 busId 61000 commId 0x438f56a3cde03111 - Init COMPLETE
/tmp/ipykernel_1856689/328536960.py:131: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
batch = jax.tree_map(lambda x: x.reshape((DEVICE_COUNT, -1, *x.shape[1:])), batch)
0
1
2
3
4
... and no further progress.
Thanks!
Do you have any printouts in the train_epoch function to pinpoint the line where it blocks?
We would really benefit from a smaller piece of code that reproduces the problem and narrows down our search. If the problem seems to come from train_step, maybe try calling it directly with some fake input?
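A sketch of such a fake-input call, written in pure JAX so it does not depend on Flax or optax. The model (a single linear layer trained with plain SGD), the learning rate and the shapes are all hypothetical stand-ins. The structure does follow train_step: value_and_grad, then jax.lax.pmean over a "device" axis inside jax.pmap. Python `print` calls inside a pmapped function run only while JAX traces it, so it is waiting on the result with block_until_ready() that shows whether the compiled step actually finished on all devices.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the real train_step: linear model + plain SGD,
# keeping the pmap + pmean structure of the original.
@partial(jax.pmap, axis_name="device")
def fake_train_step(w, x, y):
    def loss_fn(w):
        return jnp.mean((x @ w - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(w)
    grads = jax.lax.pmean(grads, "device")  # same collective as the real step
    loss = jax.lax.pmean(loss, "device")
    return w - 0.1 * grads, loss


n = jax.local_device_count()
w = jnp.zeros((n, 3, 1))  # one parameter copy per device, built directly
x = jnp.ones((n, 8, 3))   # fake per-device batch of 8 examples
y = jnp.ones((n, 8, 1))

new_w, loss = fake_train_step(w, x, y)
new_w.block_until_ready()  # wait for all devices before declaring success
print("loss per device:", loss)
```

Building the per-device copies with `jnp.zeros((n, ...))` keeps this check independent of `jax_utils.replicate`. If this small step also hangs, the real train_step code can probably be ruled out.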
Thanks for your comment! Following your advice, I would like to start with the very fundamental problem: replication of the train state.
import jax
import numpy as np
import optax
from flax import jax_utils
from flax import linen as nn
from flax.training.train_state import TrainState

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=2)(x)
        return x

x = np.ones((jax.device_count(), 3))
y = np.zeros((jax.device_count(), 3))
print(x, y)
model = MLP()
params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
print(state)
print('###########################################################################')
state_rep = jax_utils.replicate(state)
print(state_rep)

def loss_fn(state, x, y):
    print((model.apply(state.params, x)))
    return (model.apply(state.params, x))

jax.pmap(loss_fn)(state_rep, x, y)
Output:
[[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]] [[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
TrainState(step=0, apply_fn=<bound method Module.apply of MLP()>, params={'params': {'Dense_0': {'kernel': Array([[-0.3194899 , 0.9700081 ],
[-1.1898965 , -0.02842531],
[ 0.05931681, 0.38353 ]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f4190147130>, update=<function chain.<locals>.update_fn at 0x7f4190198670>), opt_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu={'params': {'Dense_0': {'bias': Array([0., 0.], dtype=float32), 'kernel': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32)}}}, nu={'params': {'Dense_0': {'bias': Array([0., 0.], dtype=float32), 'kernel': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32)}}}), EmptyState()))
###########################################################################
TrainState(step=Array([ 0, 1065353216, 1065353216, 1065353216], dtype=int32, weak_type=True), apply_fn=<bound method Module.apply of MLP()>, params={'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32), 'kernel': Array([[[-0.3194899 , 0.9700081 ],
[-1.1898965 , -0.02842531],
[ 0.05931681, 0.38353 ]],
[[ 0. , 0. ],
[ 0. , 0. ],
[ 0. , 0. ]],
[[ 0. , 0. ],
[ 0. , 0. ],
[ 0. , 0. ]],
[[ 0. , 0. ],
[ 0. , 0. ],
[ 0. , 0. ]]], dtype=float32)}}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f4190147130>, update=<function chain.<locals>.update_fn at 0x7f4190198670>), opt_state=(ScaleByAdamState(count=Array([ 0, 1065353216, 1065353216, 1065353216], dtype=int32), mu={'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32), 'kernel': Array([[[0., 0.],
[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.],
[0., 0.]]], dtype=float32)}}}, nu={'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32), 'kernel': Array([[[0., 0.],
[0., 0.],
[0., 0.]],
[[1., 1.],
[1., 0.],
[0., 0.]],
[[1., 1.],
[1., 0.],
[0., 0.]],
[[1., 1.],
[1., 0.],
[0., 0.]]], dtype=float32)}}}), EmptyState()))
Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=0/1)>
Array([[-1.4500695, 1.3251127],
[ 0. , 0. ],
[ 0. , 0. ],
[ 0. , 0. ]], dtype=float32)
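It seems the state is not replicated correctly: only the copy on the first device matches the original state, while the other copies contain zeros or unexpected values (e.g. step = 1065353216). Also, could you check whether there is a problem in pmap?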
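The replicated output has a recognizable pattern. Device 0 holds the right values, while the other copies hold zeros or odd numbers. For instance, 1065353216 is exactly the float32 bit pattern of 1.0 (0x3F800000) read as an int32. This suggests that the bytes arriving on devices 1 to 3 are wrong, rather than the replication logic. A hedged sketch to check this directly, assuming only plain JAX and NumPy: copy a known array to every device, once from the host and once from device 0 (a device-to-device copy), then read each copy back.

```python
import numpy as np
import jax


def check_transfers(x):
    """Copy x to each local device (from the host and from device 0) and
    compare what comes back with the original."""
    devices = jax.local_devices()
    on_dev0 = jax.device_put(x, devices[0])
    report = {}
    for d in devices:
        from_host = np.asarray(jax.device_get(jax.device_put(x, d)))
        # Device-to-device copy: source array lives on devices[0].
        from_dev0 = np.asarray(jax.device_get(jax.device_put(on_dev0, d)))
        report[str(d)] = {
            "host_to_device_ok": bool(np.array_equal(from_host, x)),
            "dev0_to_device_ok": bool(np.array_equal(from_dev0, x)),
        }
    return report


x = np.arange(12, dtype=np.float32).reshape(3, 4)
report = check_transfers(x)
for name, status in report.items():
    print(name, status)
```

Suppose the host-to-device copies come back intact but the copies from device 0 do not. That would point at the GPU peer-to-peer path rather than at Flax or pmap.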
Dear FLAX community,
System information
OS Platform and Distribution: Ubuntu 22.04.3 LTS
Flax, jax, jaxlib versions : Flax: 0.8.1 / jax: 0.4.27 / jaxlib: 0.4.27+cuda12.cudnn89
Python version: 3.10
GPU/TPU model and memory & CUDA version (if applicable): 4 × RTX 4090
Problem you have encountered:
As shown in the image above, my server is equipped with 4 RTX 4090 GPUs. I tried to run batch training across multiple GPUs, but it failed with the error message below. To me, the problem seems to come from the NVIDIA GPU setup rather than from Python.
What you expected to happen:
I want to use multiple GPUs for batch training in Flax on my server. How can I fix my code or rebuild the environment? (I am quite new to Linux.)
Logs, error messages, etc:
The error message is as follows:
Steps to reproduce:
I followed this benchmark code (https://colab.research.google.com/drive/1hXns2b6T8T393zSrKCSoUktye1YlSe8U?usp=sharing#scrollTo=oKcRiQ89xQkF) and fixed several issues. The code used on my server is as follows:
(The error message appears at this point.)
It seems that many people are struggling with multi-GPU setups using RTX 4090s.
Thanks for reading!