alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

Invalid argument passed in nccl_all_reduce_thunk.cc to ncclReduceScatter and ncclAllReduce with bfloat16 #883

Open · samblouir opened this issue 1 year ago

samblouir commented 1 year ago

Please describe the bug

When I train a toy model using ShardParallel, Zero2, or PipeshardParallel with bfloat16, the first step works, but subsequent steps crash with an "invalid argument" error from the NCCL call in nccl_all_reduce_thunk.cc.

The same code works as expected using jnp.float32 or jnp.float16.

Please describe the expected behavior

Training continues without crashing or hanging.

System information and environment

To Reproduce

Steps to reproduce the behavior:

  1. Create input IDs for a language model.
  2. Cast the embeddings to bfloat16.
  3. Train for more than one step.
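Step 2 is what distinguishes the failing run from the working float32 and float16 runs. As background, bfloat16 keeps float32's sign bit and all 8 exponent bits but only the top 7 of its 23 mantissa bits, so a cast amounts to rounding away the lower half of the float32 bit pattern. The stdlib-only sketch below illustrates that cast with round-to-nearest-even. The function names are hypothetical, it assumes finite inputs within float32 range (NaN is not handled), and it is only an illustration, not JAX's or XLA's conversion code.

```python
import struct


def float32_to_bfloat16_bits(x: float) -> int:
    """Round x (as float32) to bfloat16 and return the 16-bit pattern.

    Assumes x is finite and within float32 range; NaN is not handled.
    """
    # Reinterpret the float32 value as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round-to-nearest-even on the 16 bits that are dropped: add 0x7FFF,
    # plus 1 more when the lowest kept bit is odd, so exact ties round to even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    # Keep the upper 16 bits: sign, 8 exponent bits, top 7 mantissa bits.
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bfloat16_bits_to_float(b: int) -> float:
    """Widen a bfloat16 bit pattern back to a Python float (exactly)."""
    # Zero-filling the low 16 bits gives the equivalent float32 pattern.
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]
```

For example, `bfloat16_bits_to_float(float32_to_bfloat16_bits(3.14159265))` returns `3.140625`, since only 7 mantissa bits survive the cast.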

Screenshots

2023-02-19 20:44:48,654 INFO worker.py:1342 -- Connecting to existing Ray cluster at address: 172.16.130.119:6379...
2023-02-19 20:44:48,662 INFO worker.py:1528 -- Connected to Ray cluster.
-------------------- Layer slicing stats --------------------
layer_num: 2
 - Number of Jaxpr eqns in each stage:
Layer 0: #eqns=53, flop=0.005 TFlop, #heavy_ops=3
Layer 1: #eqns=17, flop=0.000 TFlop, #heavy_ops=4
 - Invars of each stage:
Layer 0 has inputs:
Layer 1 has inputs:
cr (16, 256, 16) from layer 0
cq (16, 256, 768) from layer 0
-------------------------------------------------------------
compile_pipeshard_executable::trace: 1.88 s
compile_pipeshard_executable::jaxpr operations: 0.01 s
compile_pipeshard_executable::stage construction: 0.02 s
compile_pipeshard_executable::apply grad: 0.01 s
compile_pipeshard_executable::shard stages: 13.06 s
compile_pipeshard_executable::launch meshes: 0.48 s
compile_pipeshard_executable::runtime emitter: 15.57 s
compile_pipeshard_executable::driver executable: 0.81 s
  loss: DistributedArray(sharding_spec=ShardingSpec((), (Replicated(replicas=2),)), value=5190279.0)
(MeshHostWorker pid=1860383) 2023-02-19 20:45:30.941890: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2156] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:458: NCCL operation ncclReduceScatter(send_buffer, recv_buffer, recv_count, dtype, reduce_op, comm, gpu_stream) failed: invalid argument
(MeshHostWorker pid=1860383) 2023-02-19 20:45:30.941928: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2156] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:458: NCCL operation ncclReduceScatter(send_buffer, recv_buffer, recv_count, dtype, reduce_op, comm, gpu_stream) failed: invalid argument

With ShardParallel and XLA_GPU_SKIP_ALLREDUCE = 1, a similar error is raised for ncclAllReduce:

HostWorker pid=1928057) 2023-02-19 21:24:48.552272: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2156] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:120: NCCL operation ncclAllReduce(send_buffer, recv_buffer, element_count, dtype, reduce_op, comm, gpu_stream) failed: invalid argument
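One hypothesis worth ruling out, not confirmed anywhere in this thread: NCCL reports "invalid argument" when it is handed a data type it does not recognise, which could happen if the NCCL library loaded at runtime predates bfloat16 support. Assuming bfloat16 first shipped in NCCL 2.10, and that the version string is taken from the `NCCL version ...` line NCCL prints when run with `NCCL_DEBUG=INFO`, a minimal check could look like this (the helper names are hypothetical):

```python
from typing import Tuple

# Assumption: ncclBfloat16 first appeared in NCCL 2.10.
# Verify this against the NCCL release notes for your build.
MIN_NCCL_FOR_BF16 = (2, 10)


def parse_nccl_version(version: str) -> Tuple[int, ...]:
    """Parse a version such as '2.10.3+cuda11.4' into (2, 10, 3)."""
    core = version.strip().split("+", 1)[0]  # drop a '+cudaXX.Y' build suffix
    return tuple(int(part) for part in core.split("."))


def nccl_supports_bfloat16(version: str) -> bool:
    """Return True if the (major, minor) version meets the assumed minimum."""
    return parse_nccl_version(version)[:2] >= MIN_NCCL_FOR_BF16
```

A version below 2.10 would point at the NCCL build; a newer one would rule this hypothesis out.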

If I do not pass the updated state back into the training step with PipeshardParallel, this error is raised instead:

compile_pipeshard_executable::trace: 1.88 s
compile_pipeshard_executable::jaxpr operations: 0.01 s
compile_pipeshard_executable::stage construction: 0.02 s
compile_pipeshard_executable::apply grad: 0.01 s
compile_pipeshard_executable::shard stages: 12.41 s
compile_pipeshard_executable::launch meshes: 0.49 s
compile_pipeshard_executable::runtime emitter: 16.26 s
compile_pipeshard_executable::driver executable: 0.90 s
  loss: DistributedArray(sharding_spec=ShardingSpec((), (Replicated(replicas=2),)), value=5190279.0)
Traceback (most recent call last):
  File "/scratch/sblouir/mount/code/f22/sAlpa2.py", line 212, in <module>
    x = t.run_n_layer_bert(
  File "/scratch/sblouir/mount/code/f22/sAlpa2.py", line 203, in run_n_layer_bert
    _, loss = parallel_train_step(state, batch)
  File "/home/sblouir/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/scratch/sblouir/alpa/alpa/api.py", line 122, in __call__
    out = executable.launch_on_driver(*args_flat)
  File "/scratch/sblouir/alpa/alpa/pipeline_parallel/pipeshard_executable.py", line 165, in launch_on_driver
    tmp_bufs = physical_mesh.shard_args_to_bufs(
  File "/scratch/sblouir/alpa/alpa/device_mesh.py", line 1297, in shard_args_to_bufs
    ref = shard_arg_handlers[type(arg)](arg, self, indices)[0]
  File "/scratch/sblouir/alpa/alpa/device_mesh.py", line 2474, in _shard_device_array
    return _shard_array(np.asarray(array), device_mesh, indices, num_batch,
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: DeviceArray has been deleted.
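This second traceback is consistent with input-buffer donation rather than with the bfloat16 problem itself. Assuming Alpa's `parallelize` donates the train state by default (its `donate_argnums` parameter defaults to `"auto"`), the buffers behind the old `state` are released after each step, so a loop that keeps passing the original `state` reads freed memory. The toy, library-free model below mimics only that contract; `ToyBuffer` and `toy_step` are hypothetical names and do not reflect Alpa's internals.

```python
class ToyBuffer:
    """Hypothetical stand-in for a device buffer that can be donated."""

    def __init__(self, value):
        self._value = value
        self.deleted = False

    def read(self):
        # Reading a donated buffer fails, like the DeviceArray in the traceback.
        if self.deleted:
            raise RuntimeError("DeviceArray has been deleted.")
        return self._value


def toy_step(state: ToyBuffer) -> ToyBuffer:
    """Mimic a donating train step: read the input, free it, return a new buffer."""
    new_state = ToyBuffer(state.read() + 1)
    state.deleted = True  # donation: the runtime takes ownership of the input
    return new_state


# Rebinding the result on every iteration, as the reproduction's loop does, works.
state = ToyBuffer(0)
for _ in range(3):
    state = toy_step(state)
```

Reusing the first `ToyBuffer` after one step raises the same message, which matches the behaviour reported above when the updated state is not passed back in.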

Code snippet to reproduce the problem


"""Utilities for testing."""
from functools import partial
import unittest
from collections.abc import Iterable
from typing import Callable, Optional

import jax
import jax.numpy as jnp
from jax.tree_util import tree_leaves
from jax.experimental.maps import FrozenDict as FrozenDictJax
import numpy as np
import optax
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict as FrozenDictFlax

import alpa
from alpa.api import init, shutdown, parallelize, value_and_grad
from alpa.model.bert_model import BertConfig, FlaxBertLayer
from alpa.model.model_util import FlaxBaseModelOutput, DynamicScale, TrainState
from alpa.parallel_method import PipeshardParallel
from alpa.pipeline_parallel.layer_construction import (AutoLayerOption, ManualLayerOption)
from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary
from alpa.pipeline_parallel.stage_construction import (UniformStageOption, StageOption)
from alpa.shard_parallel.auto_sharding import AutoShardingOption

from typing import Any, Union
from optax._src import base

ScalarOrSchedule = Union[float, base.Schedule]
MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]

class BasicModel(nn.Module):
    num_layers:int
    vocab_size=384
    hidden_size=768
    embedding_dimensions=768

    dtype=jnp.bfloat16
    # dtype=jnp.float32
    # dtype=jnp.float16

    @nn.compact
    def __call__(self, inputs=None, labels=None, attention_mask=None, loss_mask=None, target_input_ids=None, train=False,*args, **kwargs):

        embedding_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0)
        embeddings_layer = nn.Embed(self.vocab_size, self.embedding_dimensions, name="embeddings", dtype=self.dtype, embedding_init=embedding_init,)
        embeddings = embeddings_layer(inputs).astype(self.dtype)
        x = embeddings
        x = nn.LayerNorm()(x)

        target_embeddings = embeddings_layer(target_input_ids).astype(self.dtype)
        target_embeddings = nn.LayerNorm()(target_embeddings)

        for _ in range(self.num_layers):
            x = nn.Dense(16)(x)
        x = nn.Dense(target_embeddings.shape[-1])(x)

        x = x.astype(jnp.float32)
        loss = (target_embeddings - x)**2
        return jnp.sum(loss)

def get_bert_layer_train_state_and_step(batch_size, seq_len, num_layers,
                                        hidden_size, num_heads,
                                        clip_by_global_norm, use_dynamic_scale,
                                        add_manual_pipeline_marker):
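    # Note: the same key is reused for every randint call below, so inputs,
    # labels and target_input_ids all hold identical values.
    rngkey = jax.random.PRNGKey(0)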
    inputs = jax.random.randint(rngkey, (batch_size, seq_len,), 0, 384,)
    labels = jax.random.randint(rngkey, (batch_size, seq_len,), 0, 384,)
    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)
    loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)
    target_input_ids = jax.random.randint(rngkey, (batch_size, seq_len,), 0, 384,)

    batch = {
        "inputs": inputs, 
        "attention_mask": attention_mask,
        "labels":labels, 
        "loss_mask":loss_mask, 
        "target_input_ids":target_input_ids, 
        }

    model = BasicModel(num_layers=num_layers,)
    params = model.init(rngkey, **batch,)

    tx = optax.adam(learning_rate=1e-2)

    if use_dynamic_scale:
        use_master_copy = False
        dynamic_scale = DynamicScale()
    else:
        dynamic_scale = None
        use_master_copy = False

    state = TrainState.create(apply_fn=model.apply,
                              params=params,
                              tx=tx,
                              dynamic_scale=dynamic_scale,
                              use_master_copy=use_master_copy)

    def train_step(state, batch):

        def loss_func(params):
            loss = state.apply_fn(params, **batch,)
            return loss

        dynamic_scale = state.dynamic_scale
        if dynamic_scale:
            grad_fn = dynamic_scale.value_and_grad(loss_func)
            dynamic_scale, is_fin, val, grads = grad_fn(state.params)
        else:
            grad_fn = value_and_grad(loss_func)
            val, grads = grad_fn(state.params)

        new_state = state.apply_gradients(grads=grads)

        if dynamic_scale:
            new_state = new_state.replace(
                opt_state=jax.tree_map(partial(jnp.where, is_fin),
                                       new_state.opt_state, state.opt_state),
                params=jax.tree_map(partial(jnp.where, is_fin),
                                    new_state.params, state.params),
                master_copy=jax.tree_map(partial(jnp.where,
                                                 is_fin), new_state.master_copy,
                                         state.master_copy),
                dynamic_scale=dynamic_scale)
        return new_state, val

    return state, batch, train_step

class PipelineBasicTest(unittest.TestCase):

    def setUp(self):
        init(cluster="ray")

    def tearDown(self):
        shutdown()

    def run_n_layer_bert(self,
                         num_layers=20,
                         batch_size=16,
                         seq_len=256,
                         hidden_size=512,
                         num_heads=512 // 64,
                         use_remat=False,
                         clip_by_global_norm=False,
                         use_dynamic_scale=False,
                         inject_train_step=None,
                         manual_pipeline_layer=True,
                         stage_option: Optional[StageOption] = None,
                         as_option: Optional[AutoShardingOption] = None,
                         do_numerical_test: bool = True):
        method = PipeshardParallel(num_micro_batches=1,)

        # Init model
        state, batch, train_step = get_bert_layer_train_state_and_step(
            batch_size=batch_size,
            seq_len=seq_len,
            num_layers=num_layers,
            hidden_size=hidden_size,
            num_heads=num_heads,
            clip_by_global_norm=clip_by_global_norm,
            use_dynamic_scale=use_dynamic_scale,
            add_manual_pipeline_marker=manual_pipeline_layer)
        if inject_train_step is not None:
            assert isinstance(inject_train_step, Callable)
            train_step = inject_train_step

        # Compile
        serial_train_step = train_step
        parallel_train_step = parallelize(train_step, method=method)
        executable = parallel_train_step.get_executable(state, batch)

        for _ in range(100):
            state, loss = parallel_train_step(state, batch)
            print(f"  loss: {loss}")

        hlo_text = executable.get_hlo_text()
        return hlo_text

if __name__ == "__main__":
    t = PipelineBasicTest()
    t.setUp()
    x = t.run_n_layer_bert(
        num_layers=alpa.get_global_num_devices(),
        manual_pipeline_layer=False,
        do_numerical_test=False,
    )
    print("*" * 60)
    # print(f"  x: {x}")
    print("*" * 60)
    t.tearDown()

Additional information

Please let me know if I can add anything else to help.

samblouir commented 1 year ago

Edit: I have spoken too soon about the workaround.

There seems to be an issue with buffer_dict when using bfloat16: some UUIDs that appear in input_uuids are missing from it.
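To narrow down the buffer_dict observation, it can help to log exactly which UUIDs are missing at the point of failure. The helper below is a hypothetical, library-free sketch; it assumes buffer_dict is a mapping keyed by UUID and input_uuids is an iterable of UUIDs, since their exact types in Alpa's runtime are not shown here.

```python
from typing import Any, Dict, Iterable, List


def find_missing_uuids(buffer_dict: Dict[Any, Any],
                       input_uuids: Iterable[Any]) -> List[Any]:
    """Return UUIDs from input_uuids that have no entry in buffer_dict.

    Order of first appearance is preserved and duplicates are reported once.
    """
    missing = []
    seen = set()
    for uuid in input_uuids:
        if uuid not in buffer_dict and uuid not in seen:
            missing.append(uuid)
            seen.add(uuid)
    return missing
```

Comparing its output for a bfloat16 run against a float32 run would show whether the missing entries are specific to the bfloat16 path.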