alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

Using a bfloat16 causes Double Free Exception and Crash #881

Open samblouir opened 1 year ago

samblouir commented 1 year ago

Please describe the bug

Hi. Using bfloat16, whether by initializing an embedding layer with dtype bfloat16 or by casting a float32 array to bfloat16, causes a double-free exception and a crash. Sometimes the output only reports a segmentation fault or that a worker died, without a verbose explanation; which message appears can depend on the method selected for alpa's parallelize function. This happens to me with both Shard Parallel and Pipeshard Parallel, but it used to work in an earlier version of Alpa.

Please describe the expected behavior

The data is used as bfloat16, or cast to bfloat16 from float32, without issue.

System information and environment

To Reproduce

Steps to reproduce the behavior (I am starting this using a SLURM script):

  1. Run this py file with ray (it's a modified version of an included test file).

  2. If you comment out line 52, it will work again:

x = x.astype(jnp.bfloat16) ## Comment this line out to fix the problem

  3. This line also causes an issue:

    x = nn.Embed(256, 768, dtype=jnp.bfloat16)(inputs) # <-- This fails, too

Screenshots

-------------------------------------------------------------
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::trace: 3.10 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::jaxpr operations: 0.02 s
alpa.pipeline_parallel.stage_construction.cluster_layers_and_slice_mesh():  num_devices: 20,  num_stages: 20
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::stage construction: 0.05 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::apply grad: 0.05 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::shard stages: 28.51 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::launch meshes: 1.06 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::runtime emitter: 49.50 s
alpa.compile_executable.debug_compilation_time(): compile_pipeshard_executable::driver executable: 8.64 s
(MeshHostWorker pid=1468195, ip=172.16.130.110) free(): double free detected in tcache 2
(MeshHostWorker pid=1468195, ip=172.16.130.110) *** SIGABRT received at time=1676287900 on cpu 22 ***
(MeshHostWorker pid=1468195, ip=172.16.130.110) PC: @     0x7febf204337f  (unknown)  raise
(MeshHostWorker pid=1468195, ip=172.16.130.110)     @     0x7febf2b6dc20  309695680  (unknown)
(MeshHostWorker pid=1468195, ip=172.16.130.110)     @     0x7febf208d5ec  (unknown)  malloc_printerr
(MeshHostWorker pid=1468195, ip=172.16.130.110) [2023-02-13 06:31:40,764 E 1468195 1468195] logging.cc:361: *** SIGABRT received at time=1676287900 on cpu 22 ***
(MeshHostWorker pid=1468195, ip=172.16.130.110) [2023-02-13 06:31:40,764 E 1468195 1468195] logging.cc:361: PC: @     0x7febf204337f  (unknown)  raise
(MeshHostWorker pid=1468195, ip=172.16.130.110) [2023-02-13 06:31:40,764 E 1468195 1468195] logging.cc:361:     @     0x7febf2b6dc20  309695680  (unknown)
(MeshHostWorker pid=1468195, ip=172.16.130.110) [2023-02-13 06:31:40,764 E 1468195 1468195] logging.cc:361:     @     0x7febf208d5ec  (unknown)  malloc_printerr
(MeshHostWorker pid=1468195, ip=172.16.130.110) Fatal Python error: Aborted
(MeshHostWorker pid=1468195, ip=172.16.130.110) 
(MeshHostWorker pid=1468195, ip=172.16.130.110) Stack (most recent call first):
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/alpa/alpa/collective/worker_nccl_util_cupy.py", line 225 in xla_buffer_to_cupy
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/alpa/alpa/collective/worker_nccl_util_cupy.py", line 91 in recv_tile
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/alpa/alpa/collective/worker_nccl_util.py", line 11 in _switch_impl
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/alpa/alpa/collective/worker_nccl_util.py", line 27 in recv_tile
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/alpa/alpa/device_mesh.py", line 493 in recv_tile
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/.local/lib/python3.9/site-packages/ray/util/tracing/tracing_helper.py", line 466 in _resume_span
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/alpa/alpa/device_mesh.py", line 457 in run_resharding_recv_task
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/.local/lib/python3.9/site-packages/ray/util/tracing/tracing_helper.py", line 466 in _resume_span
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/alpa/alpa/pipeline_parallel/pipeshard_executable.py", line 543 in execute_on_worker
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/alpa/alpa/device_mesh.py", line 279 in run_executable
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/.local/lib/python3.9/site-packages/ray/util/tracing/tracing_helper.py", line 466 in _resume_span
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/.local/lib/python3.9/site-packages/ray/_private/function_manager.py", line 674 in actor_method_executor
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/.local/lib/python3.9/site-packages/ray/_private/worker.py", line 763 in main_loop
(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/.local/lib/python3.9/site-packages/ray/_private/workers/default_worker.py", line 231 in <module>
2023-02-13 06:31:40,994 WARNING worker.py:1839 -- A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffffeeb376d8eb140bbb2046b3f001000000 Worker ID: a69bcad3d3518114e11fec2e1e8259b4cf28a5c5fa8f66dcbacf55a7 Node ID: b42a40ba7bcb1b78f72ba5de6b2cae8d5bf5652a51d18ae67b496891 Worker IP address: 172.16.130.110 Worker port: 10008 Worker PID: 1468195 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
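
The Python stack in the worker log above can be pulled out mechanically, which makes it easier to compare crashes across runs. This is a minimal sketch using only the standard library. It assumes the frame format shown in the log (`File "...", line N in func`); `extract_frames`, `first_frame_in` and the `/alpa/alpa/` path marker are hypothetical names chosen for illustration, not part of Alpa or Ray.

```python
import re

# Matches the frame lines printed after "Stack (most recent call first):".
FRAME_RE = re.compile(r'File "(?P<path>[^"]+)", line (?P<line>\d+) in (?P<func>\S+)')


def extract_frames(log_text):
    """Return (path, line, function) tuples in the order they appear in the log."""
    frames = []
    for raw in log_text.splitlines():
        match = FRAME_RE.search(raw)
        if match:
            frames.append((match["path"], int(match["line"]), match["func"]))
    return frames


def first_frame_in(frames, marker="/alpa/alpa/"):
    """Return the first frame whose path contains `marker`, or None.

    The default marker is an assumption based on the checkout path in this log.
    """
    for frame in frames:
        if marker in frame[0]:
            return frame
    return None


if __name__ == "__main__":
    # Two lines copied from the worker log in this issue.
    excerpt = (
        '(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/alpa/alpa/collective/worker_nccl_util_cupy.py", line 225 in xla_buffer_to_cupy\n'
        '(MeshHostWorker pid=1468195, ip=172.16.130.110)   File "/home/sblouir/alpa/alpa/collective/worker_nccl_util_cupy.py", line 91 in recv_tile\n'
    )
    print(first_frame_in(extract_frames(excerpt)))
```

On the log above, the first frame under the Alpa checkout is xla_buffer_to_cupy in worker_nccl_util_cupy.py, reached through recv_tile, which is the most recent Python call before the abort.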

Code snippet to reproduce the problem

"""Utilities for testing."""
from functools import partial
import unittest
from collections.abc import Iterable
from typing import Callable, Optional

import jax
import jax.numpy as jnp
from jax.tree_util import tree_leaves
from jax.experimental.maps import FrozenDict as FrozenDictJax
import numpy as np
import optax
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict as FrozenDictFlax

import alpa
from alpa.api import init, shutdown, parallelize, value_and_grad
from alpa.model.bert_model import BertConfig, FlaxBertLayer
from alpa.model.model_util import FlaxBaseModelOutput, DynamicScale, TrainState
from alpa.parallel_method import PipeshardParallel
from alpa.pipeline_parallel.layer_construction import (AutoLayerOption,
                                                       ManualLayerOption)
from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary
from alpa.pipeline_parallel.stage_construction import (UniformStageOption,
                                                       StageOption)
from alpa.shard_parallel.auto_sharding import AutoShardingOption

import functools
from typing import Any, Callable, NamedTuple, Optional, Tuple, Union

import chex
from optax._src import base
from optax._src import clipping
from optax._src import combine
from optax._src import factorized
from optax._src import linear_algebra
from optax._src import transform

ScalarOrSchedule = Union[float, base.Schedule]
MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]

class BasicModel(nn.Module):
    num_layers: int
    # This doesn't fix it, either
    # def setup(self) -> None:
    #     self.embeddings = nn.Embed(256, 768, dtype=jnp.bfloat16)

    @nn.compact
    def __call__(self, inputs=None, labels=None, attention_mask=None, loss_mask=None, target_input_ids=None, *args, **kwargs):
        # x = nn.Embed(256, 768, dtype=jnp.bfloat16)(inputs)  # <-- This fails, too
        x = nn.Embed(256, 768)(inputs)
        x = x.astype(jnp.bfloat16) ## Comment this line out to fix the problem
        x = nn.LayerNorm()(x)
        for _ in range(self.num_layers):
            x = nn.Dense(16)(x)
        x = nn.Dense(768)(x)

        y = nn.Embed(256, 768)(labels)
        y = nn.LayerNorm()(y)

        x = x.astype(jnp.float32)
        loss = (y-x)**2
        return jnp.sum(loss)

def get_bert_layer_train_state_and_step(batch_size, seq_len, num_layers,
                                        hidden_size, num_heads,
                                        clip_by_global_norm, use_dynamic_scale,
                                        add_manual_pipeline_marker):
    rngkey = jax.random.PRNGKey(0)
    inputs = jax.random.randint(rngkey, (batch_size, seq_len,), 0, 384,)
    labels = jax.random.randint(rngkey, (batch_size, seq_len,), 0, 384,)
    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)
    loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)
    target_input_ids = jax.random.randint(rngkey, (batch_size, seq_len,), 0, 384,)

    batch = {
        "inputs": inputs, 
        "attention_mask": attention_mask,
        "labels":labels, 
        "loss_mask":loss_mask, 
        "target_input_ids":target_input_ids, 
        }

    model = BasicModel(num_layers=num_layers,)
    params = model.init(rngkey, **batch,)

    tx = optax.adam(learning_rate=1e-2)

    if use_dynamic_scale:
        use_master_copy = False
        dynamic_scale = DynamicScale()
    else:
        dynamic_scale = None
        use_master_copy = False

    state = TrainState.create(apply_fn=model.apply,
                              params=params,
                              tx=tx,
                              dynamic_scale=dynamic_scale,
                              use_master_copy=use_master_copy)

    def train_step(state, batch):

        def loss_func(params):
            loss = state.apply_fn(params, **batch,)
            return loss

        dynamic_scale = state.dynamic_scale
        if dynamic_scale:
            grad_fn = dynamic_scale.value_and_grad(loss_func)
            dynamic_scale, is_fin, val, grads = grad_fn(state.params)
        else:
            grad_fn = value_and_grad(loss_func)
            val, grads = grad_fn(state.params)

        new_state = state.apply_gradients(grads=grads)

        if dynamic_scale:
            new_state = new_state.replace(
                opt_state=jax.tree_map(partial(jnp.where, is_fin),
                                       new_state.opt_state, state.opt_state),
                params=jax.tree_map(partial(jnp.where, is_fin),
                                    new_state.params, state.params),
                master_copy=jax.tree_map(partial(jnp.where,
                                                 is_fin), new_state.master_copy,
                                         state.master_copy),
                dynamic_scale=dynamic_scale)
        return new_state, val

    return state, batch, train_step

class PipelineBasicTest(unittest.TestCase):

    def setUp(self):
        init(cluster="ray")

    def tearDown(self):
        shutdown()

    def run_n_layer_bert(self,
                         num_layers=20,
                         batch_size=16,
                         seq_len=256,
                         hidden_size=512,
                         num_heads=512 // 64,
                         use_remat=False,
                         clip_by_global_norm=False,
                         use_dynamic_scale=False,
                         inject_train_step=None,
                         manual_pipeline_layer=True,
                         stage_option: Optional[StageOption] = None,
                         as_option: Optional[AutoShardingOption] = None,
                         do_numerical_test: bool = True):
        method = PipeshardParallel(
            num_micro_batches=4,
            default_auto_sharding_option=as_option or AutoShardingOption(),
            layer_option=ManualLayerOption(remat_layer=use_remat)
            if manual_pipeline_layer else AutoLayerOption(
                layer_num=num_layers,
                remat_mode="coarse_grained_remat" if use_remat else "none"),
            stage_option=stage_option or UniformStageOption())
        # method = alpa.ShardParallel()

        # Init model
        state, batch, train_step = get_bert_layer_train_state_and_step(
            batch_size=batch_size,
            seq_len=seq_len,
            num_layers=num_layers,
            hidden_size=hidden_size,
            num_heads=num_heads,
            clip_by_global_norm=clip_by_global_norm,
            use_dynamic_scale=use_dynamic_scale,
            add_manual_pipeline_marker=manual_pipeline_layer)
        if inject_train_step is not None:
            assert isinstance(inject_train_step, Callable)
            train_step = inject_train_step

        # Compile
        serial_train_step = train_step
        parallel_train_step = parallelize(train_step, method=method)
        executable = parallel_train_step.get_executable(state, batch)

        for _ in range(100):
            state, loss = parallel_train_step(state, batch)
            print(f"  loss: {loss}")

        hlo_text = executable.get_hlo_text()
        return hlo_text

if __name__ == "__main__":
    t = PipelineBasicTest()
    t.setUp()
    x = t.run_n_layer_bert(
        num_layers=alpa.get_global_num_devices(),
        manual_pipeline_layer=False,
        do_numerical_test=False,
    )
    print(f"*" * 60,)
    print(f"  x: {x}")
    print(f"*" * 60,)
    t.tearDown()

Additional information

Please let me know if any more information can help. This doesn't seem to happen when casting a float32 to float16.
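
Since float16 works and bfloat16 does not, it may help to confirm that the cast itself is fine outside Alpa. The sketch below is a plain-JAX check, with no Alpa or Ray involved, that round-trips float32 values through a lower-precision dtype under jit. `cast_roundtrip` and `max_roundtrip_error` are hypothetical helper names, not part of the repro file or of Alpa.

```python
from functools import partial

import jax
import jax.numpy as jnp


def cast_roundtrip(x, dtype):
    # Mirror the repro model: cast activations down, then back up to float32.
    return x.astype(dtype).astype(jnp.float32)


def max_roundtrip_error(dtype):
    """Largest absolute error after a float32 -> dtype -> float32 round trip."""
    x = jnp.linspace(-1.0, 1.0, 10, dtype=jnp.float32)
    y = jax.jit(partial(cast_roundtrip, dtype=dtype))(x)
    return float(jnp.max(jnp.abs(y - x)))


if __name__ == "__main__":
    for name, dtype in (("bfloat16", jnp.bfloat16), ("float16", jnp.float16)):
        print(name, max_roundtrip_error(dtype))
```

If this runs cleanly on the same machine, the crash is more likely to come from Alpa's runtime path (the stack in the log ends in xla_buffer_to_cupy) than from the cast itself. That is an inference from the trace, not a confirmed root cause.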

serach24 commented 1 year ago

I hit the same problem with CUDA 11.3 and cuDNN 8.2.0 after installing by following the documentation guide. When I ran python -m alpa.test_install, the second test (pipeshard) gave exactly the same error, but I did not find anything related to bfloat16 in the test code.

Yanivmd commented 10 months ago

Bumping this in the hope that someone has managed to resolve it. I'm hitting a similar problem in the tests with CUDA 11.8 (cuDNN 8).