Please describe the bug
When training a toy model with ShardParallel, Zero2, or PipeshardParallel and jnp.bfloat16, the first training step works, but subsequent steps crash with an "invalid argument" error from nccl_all_reduce_thunk.cc.
The same code works as expected using jnp.float32 or jnp.float16.
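For reference, the failing and working runs differ only in the dtype class attribute of BasicModel in the repro script below:

    dtype = jnp.bfloat16   # first step works, second step crashes in NCCL
    # dtype = jnp.float32  # works
    # dtype = jnp.float16  # works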
Please describe the expected behavior
Training continues without crashing or hanging
System information and environment
OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker):
RHEL 8.5, Docker image: cuda_11.8.0-cudnn8-devel-ubuntu22.04
Python version:
3.9.6
CUDA version:
11.8 (also happens with CUDA 11.5.1-75)
Screenshots
2023-02-19 20:44:48,654 INFO worker.py:1342 -- Connecting to existing Ray cluster at address: 172.16.130.119:6379...
2023-02-19 20:44:48,662 INFO worker.py:1528 -- Connected to Ray cluster.
-------------------- Layer slicing stats --------------------
layer_num: 2
- Number of Jaxpr eqns in each stage:
Layer 0: #eqns=53, flop=0.005 TFlop, #heavy_ops=3
Layer 1: #eqns=17, flop=0.000 TFlop, #heavy_ops=4
- Invars of each stage:
Layer 0 has inputs:
Layer 1 has inputs:
cr (16, 256, 16) from layer 0
cq (16, 256, 768) from layer 0
-------------------------------------------------------------
compile_pipeshard_executable::trace: 1.88 s
compile_pipeshard_executable::jaxpr operations: 0.01 s
compile_pipeshard_executable::stage construction: 0.02 s
compile_pipeshard_executable::apply grad: 0.01 s
compile_pipeshard_executable::shard stages: 13.06 s
compile_pipeshard_executable::launch meshes: 0.48 s
compile_pipeshard_executable::runtime emitter: 15.57 s
compile_pipeshard_executable::driver executable: 0.81 s
loss: DistributedArray(sharding_spec=ShardingSpec((), (Replicated(replicas=2),)), value=5190279.0)
(MeshHostWorker pid=1860383) 2023-02-19 20:45:30.941890: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2156] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:458: NCCL operation ncclReduceScatter(send_buffer, recv_buffer, recv_count, dtype, reduce_op, comm, gpu_stream) failed: invalid argument
(MeshHostWorker pid=1860383) 2023-02-19 20:45:30.941928: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2156] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:458: NCCL operation ncclReduceScatter(send_buffer, recv_buffer, recv_count, dtype, reduce_op, comm, gpu_stream) failed: invalid argument
ShardParallel with XLA_GPU_SKIP_ALLREDUCE = 1 causes a similar error for ncclAllReduce:
(MeshHostWorker pid=1928057) 2023-02-19 21:24:48.552272: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2156] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:120: NCCL operation ncclAllReduce(send_buffer, recv_buffer, element_count, dtype, reduce_op, comm, gpu_stream) failed: invalid argument
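The repro script below only shows the PipeshardParallel case; the ShardParallel and Zero2 runs use the same train_step and only swap the parallel method passed to parallelize. A rough sketch of that variant (exact options may differ):

    # Sketch of the ShardParallel variant; the Zero2 run used the same structure
    # with a zero-2 style sharding option.
    method = alpa.ShardParallel(num_micro_batches=1)
    parallel_train_step = parallelize(train_step, method=method)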
If I do not feed the updated state back into PipeshardParallel on subsequent training steps, I get this error instead:
compile_pipeshard_executable::trace: 1.88 s
compile_pipeshard_executable::jaxpr operations: 0.01 s
compile_pipeshard_executable::stage construction: 0.02 s
compile_pipeshard_executable::apply grad: 0.01 s
compile_pipeshard_executable::shard stages: 12.41 s
compile_pipeshard_executable::launch meshes: 0.49 s
compile_pipeshard_executable::runtime emitter: 16.26 s
compile_pipeshard_executable::driver executable: 0.90 s
loss: DistributedArray(sharding_spec=ShardingSpec((), (Replicated(replicas=2),)), value=5190279.0)
Traceback (most recent call last):
File "/scratch/sblouir/mount/code/f22/sAlpa2.py", line 212, in <module>
x = t.run_n_layer_bert(
File "/scratch/sblouir/mount/code/f22/sAlpa2.py", line 203, in run_n_layer_bert
_, loss = parallel_train_step(state, batch)
File "/home/sblouir/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/scratch/sblouir/alpa/alpa/api.py", line 122, in __call__
out = executable.launch_on_driver(*args_flat)
File "/scratch/sblouir/alpa/alpa/pipeline_parallel/pipeshard_executable.py", line 165, in launch_on_driver
tmp_bufs = physical_mesh.shard_args_to_bufs(
File "/scratch/sblouir/alpa/alpa/device_mesh.py", line 1297, in shard_args_to_bufs
ref = shard_arg_handlers[type(arg)](arg, self, indices)[0]
File "/scratch/sblouir/alpa/alpa/device_mesh.py", line 2474, in _shard_device_array
return _shard_array(np.asarray(array), device_mesh, indices, num_batch,
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: DeviceArray has been deleted.
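For context, in that case the training loop reuses the original state every step instead of threading the updated state back in:

    # Variant that triggers the "DeviceArray has been deleted" error:
    for _ in range(100):
        _, loss = parallel_train_step(state, batch)  # updated state is discarded
        print(f" loss: {loss}")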
Code snippet to reproduce the problem
"""Utilities for testing."""
from functools import partial
import unittest
from collections.abc import Iterable
from typing import Callable, Optional
import jax
import jax.numpy as jnp
from jax.tree_util import tree_leaves
from jax.experimental.maps import FrozenDict as FrozenDictJax
import numpy as np
import optax
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict as FrozenDictFlax
import alpa
from alpa.api import init, shutdown, parallelize, value_and_grad
from alpa.model.bert_model import BertConfig, FlaxBertLayer
from alpa.model.model_util import FlaxBaseModelOutput, DynamicScale, TrainState
from alpa.parallel_method import PipeshardParallel
from alpa.pipeline_parallel.layer_construction import (AutoLayerOption, ManualLayerOption)
from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary
from alpa.pipeline_parallel.stage_construction import (UniformStageOption, StageOption)
from alpa.shard_parallel.auto_sharding import AutoShardingOption
from typing import Any, Callable, Optional, Union
from typing import NamedTuple, Optional, Tuple, Callable
from optax._src import base
from optax._src import clipping
from optax._src import combine
from optax._src import factorized
from optax._src import transform
import functools
import numpy as np
from optax._src import base
import optax
import chex
from optax._src import linear_algebra
from optax._src import base
ScalarOrSchedule = Union[float, base.Schedule]
MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]
class BasicModel(nn.Module):
    num_layers: int
    vocab_size = 384
    hidden_size = 768
    embedding_dimensions = 768
    # Switching dtype to float32 or float16 avoids the crash.
    dtype = jnp.bfloat16
    # dtype = jnp.float32
    # dtype = jnp.float16

    @nn.compact
    def __call__(self, inputs=None, labels=None, attention_mask=None,
                 loss_mask=None, target_input_ids=None, train=False,
                 *args, **kwargs):
        embedding_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal',
                                                          out_axis=0)
        embeddings_layer = nn.Embed(self.vocab_size, self.embedding_dimensions,
                                    name="embeddings", dtype=self.dtype,
                                    embedding_init=embedding_init)
        embeddings = embeddings_layer(inputs).astype(self.dtype)
        x = embeddings
        x = nn.LayerNorm()(x)
        target_embeddings = embeddings_layer(target_input_ids).astype(self.dtype)
        target_embeddings = nn.LayerNorm()(target_embeddings)
        for _ in range(self.num_layers):
            x = nn.Dense(16)(x)
            x = nn.Dense(target_embeddings.shape[-1])(x)
        x = x.astype(jnp.float32)
        loss = (target_embeddings - x) ** 2
        return jnp.sum(loss)
def get_bert_layer_train_state_and_step(batch_size, seq_len, num_layers,
                                        hidden_size, num_heads,
                                        clip_by_global_norm, use_dynamic_scale,
                                        add_manual_pipeline_marker):
    rngkey = jax.random.PRNGKey(0)
    inputs = jax.random.randint(rngkey, (batch_size, seq_len), 0, 384)
    labels = jax.random.randint(rngkey, (batch_size, seq_len), 0, 384)
    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)
    loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)
    target_input_ids = jax.random.randint(rngkey, (batch_size, seq_len), 0, 384)
    batch = {
        "inputs": inputs,
        "attention_mask": attention_mask,
        "labels": labels,
        "loss_mask": loss_mask,
        "target_input_ids": target_input_ids,
    }

    model = BasicModel(num_layers=num_layers)
    params = model.init(rngkey, **batch)
    tx = optax.adam(learning_rate=1e-2)

    if use_dynamic_scale:
        use_master_copy = False
        dynamic_scale = DynamicScale()
    else:
        dynamic_scale = None
        use_master_copy = False

    state = TrainState.create(apply_fn=model.apply,
                              params=params,
                              tx=tx,
                              dynamic_scale=dynamic_scale,
                              use_master_copy=use_master_copy)

    def train_step(state, batch):

        def loss_func(params):
            loss = state.apply_fn(params, **batch)
            return loss

        dynamic_scale = state.dynamic_scale
        if dynamic_scale:
            grad_fn = dynamic_scale.value_and_grad(loss_func)
            dynamic_scale, is_fin, val, grads = grad_fn(state.params)
        else:
            grad_fn = value_and_grad(loss_func)
            val, grads = grad_fn(state.params)

        new_state = state.apply_gradients(grads=grads)
        if dynamic_scale:
            new_state = new_state.replace(
                opt_state=jax.tree_map(partial(jnp.where, is_fin),
                                       new_state.opt_state, state.opt_state),
                params=jax.tree_map(partial(jnp.where, is_fin),
                                    new_state.params, state.params),
                master_copy=jax.tree_map(partial(jnp.where, is_fin),
                                         new_state.master_copy,
                                         state.master_copy),
                dynamic_scale=dynamic_scale)
        return new_state, val

    return state, batch, train_step
class PipelineBasicTest(unittest.TestCase):

    def setUp(self):
        init(cluster="ray")

    def tearDown(self):
        shutdown()

    def run_n_layer_bert(self,
                         num_layers=20,
                         batch_size=16,
                         seq_len=256,
                         hidden_size=512,
                         num_heads=512 // 64,
                         use_remat=False,
                         clip_by_global_norm=False,
                         use_dynamic_scale=False,
                         inject_train_step=None,
                         manual_pipeline_layer=True,
                         stage_option: Optional[StageOption] = None,
                         as_option: Optional[AutoShardingOption] = None,
                         do_numerical_test: bool = True):
        method = PipeshardParallel(num_micro_batches=1)

        # Init model
        state, batch, train_step = get_bert_layer_train_state_and_step(
            batch_size=batch_size,
            seq_len=seq_len,
            num_layers=num_layers,
            hidden_size=hidden_size,
            num_heads=num_heads,
            clip_by_global_norm=clip_by_global_norm,
            use_dynamic_scale=use_dynamic_scale,
            add_manual_pipeline_marker=manual_pipeline_layer)

        if inject_train_step is not None:
            assert isinstance(inject_train_step, Callable)
            train_step = inject_train_step

        # Compile
        serial_train_step = train_step
        parallel_train_step = parallelize(train_step, method=method)
        executable = parallel_train_step.get_executable(state, batch)

        # The first step succeeds; the crash happens on the second step.
        for _ in range(100):
            state, loss = parallel_train_step(state, batch)
            print(f" loss: {loss}")

        hlo_text = executable.get_hlo_text()
        return hlo_text
if __name__ == "__main__":
t = PipelineBasicTest()
t.setUp()
x = t.run_n_layer_bert(
num_layers=alpa.get_global_num_devices(),
manual_pipeline_layer=False,
do_numerical_test=False,
)
print(f"*" * 60,)
# print(f" x: {x}")
print(f"*" * 60,)
t.tearDown()
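For reference, the script is launched against an already-running Ray cluster; roughly (exact commands depend on the setup):

    ray start --head                        # on the head node
    ray start --address=<head-ip>:6379      # on each worker node
    python sAlpa2.py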
Additional information
Please let me know if I can add anything to help.