microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] The NCCL timed out while using the zero3 model. How can I solve this problem? #5066

Open awzhgw opened 9 months ago

awzhgw commented 9 months ago

NCCL times out when I train with ZeRO-3. How can I solve this problem?

I started from the large Mixtral 8x7B model, which uses a Llama-style architecture, and augmented it with multi-modal capabilities for video and audio.

The architecture of my model is as follows:

LlavaMixtralForCausalLM(
  (model): LlavaMixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): MixtralDecoderLayer(
        (self_attn): MixtralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBLockSparseTop2MLP(
              (w1): Linear(in_features=4096, out_features=14336, bias=False)
              (w2): Linear(in_features=14336, out_features=4096, bias=False)
              (w3): Linear(in_features=4096, out_features=14336, bias=False)
              (act_fn): SiLU()
            )
          )
        )
        (input_layernorm): MixtralRMSNorm()
        (post_attention_layernorm): MixtralRMSNorm()
      )
    )
    (norm): MixtralRMSNorm()
    (image_tower): CLIPVisionTower(
      (image_tower): CLIPVisionModel(
        (vision_model): CLIPVisionTransformer(
          (embeddings): CLIPVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
            (position_embedding): Embedding(577, 1024)
          )
          (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (encoder): CLIPEncoder(
            (layers): ModuleList(
              (0-23): 24 x CLIPEncoderLayer(
                (self_attn): CLIPAttention(
                  (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
                )
                (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (mlp): CLIPMLP(
                  (activation_fn): QuickGELUActivation()
                  (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                  (fc2): Linear(in_features=4096, out_features=1024, bias=True)
                )
                (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              )
            )
          )
          (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (video_tower): LanguageBindVideoTower(
      (video_tower): CLIPVisionTransformer(
        (embeddings): CLIPVisionEmbeddings(
          (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
          (position_embedding): Embedding(257, 1024)
        )
        (patch_dropout): PatchDropout()
        (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (encoder): CLIPEncoder(
          (layers): ModuleList(
            (0-23): 24 x CLIPEncoderLayer(
              (self_attn): CLIPAttention(
                (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
              )
              (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (mlp): CLIPMLP(
                (activation_fn): GELUActivation()
                (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                (fc2): Linear(in_features=4096, out_features=1024, bias=True)
              )
              (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (temporal_attn): CLIPAttention(
                (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
              )
              (temporal_layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
          )
        )
        (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
    )
    (mm_projector): build_projector(
      (image_spatial_proj): Sequential(
        (0): Linear(in_features=1024, out_features=4096, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=4096, out_features=4096, bias=True)
      )
      (video_patch_proj): Linear(in_features=1024, out_features=4096, bias=True)
    )
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

After initializing the model, I called set_z3_leaf_modules and printed the result:

    deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
    print('model z3_leaf_model is ', deepspeed.utils.get_z3_leaf_modules(model))

The printed result is as follows:

model z3_leaf_model is  [MixtralSparseMoeBlock(
  (gate): Linear(in_features=4096, out_features=8, bias=False)
  (experts): ModuleList(
    (0-7): 8 x MixtralBLockSparseTop2MLP(
      (w1): Linear(in_features=4096, out_features=14336, bias=False)
      (w2): Linear(in_features=14336, out_features=4096, bias=False)
      (w3): Linear(in_features=4096, out_features=14336, bias=False)
      (act_fn): SiLU()
    )
  )
)]

This shows that the z3 leaf modules have been set correctly.

My DeepSpeed version is the master branch.

The training behaves as follows:

  Scenario 1: When I train with ZeRO-3 and the training data contains only images, there are no issues and training proceeds normally.

  Scenario 2: When I train with ZeRO-3 and the training data contains both images and videos, training gets stuck after 270 steps and eventually fails with an NCCL timeout.

The error message is as follows:

{'loss': 6.8843, 'learning_rate': 0.0009838432886246189, 'epoch': 0.03}
  3%|▎         | 270/9847 [06:18<3:39:12,  1.37s/it]Invalidate trace cache @ step 1: expected module 15, but got module 313
[E ProcessGroupNCCL.cpp:467] [Rank 6] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800376 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800392 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800436 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=75264, NumelOut=602112, Timeout(ms)=1800000) ran for 1800709 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800860 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800869 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 5] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800882 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800988 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:852] [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=75264, NumelOut=602112, Timeout(ms)=1800000) ran for 1800709 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=75264, NumelOut=602112, Timeout(ms)=1800000) ran for 1800709 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:852] [Rank 1] NCCL watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800392 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 1] NCCL watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800392 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:852] [Rank 7] NCCL watchdog thread terminated with exception: [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800860 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 7] NCCL watchdog thread terminated with exception: [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800860 milliseconds before timing out.
[2024-02-03 11:48:17,107] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694488
[2024-02-03 11:48:17,108] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694489
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:852] [Rank 2] NCCL watchdog thread terminated with exception: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800436 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:852] [Rank 6] NCCL watchdog thread terminated with exception: [Rank 6] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800376 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 2] NCCL watchdog thread terminated with exception: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800436 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 6] NCCL watchdog thread terminated with exception: [Rank 6] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800376 milliseconds before timing out.
[2024-02-03 11:48:20,395] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694490
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:852] [Rank 3] NCCL watchdog thread terminated with exception: [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800869 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 3] NCCL watchdog thread terminated with exception: [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800869 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:852] [Rank 4] NCCL watchdog thread terminated with exception: [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800988 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 4] NCCL watchdog thread terminated with exception: [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800988 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:852] [Rank 5] NCCL watchdog thread terminated with exception: [Rank 5] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800882 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 5] NCCL watchdog thread terminated with exception: [Rank 5] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800882 milliseconds before timing out.
[2024-02-03 11:48:25,872] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694491
[2024-02-03 11:48:27,258] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694492
[2024-02-03 11:48:27,261] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694493
[2024-02-03 11:48:27,263] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694494
[2024-02-03 11:48:27,265] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694495
[2024-02-03 11:48:27,267] [ERROR] [launch.py:322:sigkill_handler] ['/usr/bin/python', '-u', 'moellava/train/train_mem.py', '--local_rank=7', '--deepspeed', './scripts/zero3_offload.json', '--model_name_or_path', '/export/App/training_platform/PinoModel/mixtral/Mixtral-8x7B-Instruct-v0.1', '--version', 'mixtral', '--data_path', '/mnt/moe/moe/dataset/data_root/train_json/pretrain/valley_llavaimage.json', '--image_folder', '/mnt/moe/moe/dataset/data_root', '--image_tower', '/export/App/training_platform/PinoModel/openai/clip-vit-large-patch14-336', '--image_projector_type', 'mlp2x_gelu', '--video_tower', '/export/App/training_platform/PinoModel/LanguageBind/LanguageBind_Video_merge', '--video_folder', '/mnt/moe/moe/dataset/data_root', '--tune_mm_mlp_adapter', 'True', '--mm_vision_select_layer', '-2', '--mm_use_im_start_end', 'False', '--mm_use_im_patch_token', 'False', '--bf16', 'True', '--output_dir', './checkpoints/llavamixtral-7b-pretrain', '--num_train_epochs', '1', '--per_device_train_batch_size', '16', '--per_device_eval_batch_size', '4', '--gradient_accumulation_steps', '1', '--evaluation_strategy', 'no', '--save_strategy', 'steps', '--save_steps', '2400', '--save_total_limit', '1', '--learning_rate', '1e-3', '--weight_decay', '0.', '--warmup_ratio', '0.03', '--lr_scheduler_type', 'cosine', '--logging_steps', '1', '--tf32', 'True', '--model_max_length', '2048', '--gradient_checkpointing', 'True', '--dataloader_num_workers', '8', '--lazy_preprocess', 'True', '--report_to', 'tensorboard', '--cache_dir', './cache_dir'] exits with return code = -6
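One detail in the log above is worth checking: for the same SeqNum=202558, rank 0 reports NumelIn=75264 / NumelOut=602112 while every other rank reports NumelIn=2359296 / NumelOut=18874368. Ranks that disagree on the size of the same collective are usually no longer executing the same sequence of all-gathers. Note also that 602112 = 3 * 1024 * 14 * 14, the element count of the Conv2d patch embedding weight in the printed architecture, which may indicate that rank 0 was entering a vision tower while the others were gathering a different parameter. The sketch below is only a log-reading helper, not part of DeepSpeed or PyTorch; the names `collect_timeouts` and `find_mismatches` are made up for illustration, and the regex assumes the exact watchdog line format shown above.

```python
import re
from collections import defaultdict

# Matches the watchdog timeout lines printed by ProcessGroupNCCL in the log above.
TIMEOUT_RE = re.compile(
    r"\[Rank (?P<rank>\d+)\] Watchdog caught collective operation timeout: "
    r"WorkNCCL\(SeqNum=(?P<seq>\d+), OpType=(?P<op>\w+), "
    r"NumelIn=(?P<numel_in>\d+), NumelOut=(?P<numel_out>\d+)"
)


def collect_timeouts(log_text):
    """Return {SeqNum: {rank: (OpType, NumelIn, NumelOut)}} for all timeout lines."""
    per_seq = defaultdict(dict)
    for m in TIMEOUT_RE.finditer(log_text):
        signature = (m["op"], int(m["numel_in"]), int(m["numel_out"]))
        # The same rank appears on several lines; later duplicates just overwrite.
        per_seq[int(m["seq"])][int(m["rank"])] = signature
    return per_seq


def find_mismatches(per_seq):
    """Return {SeqNum: {signature: [ranks]}} where ranks disagree on the collective."""
    mismatches = {}
    for seq, by_rank in per_seq.items():
        by_signature = defaultdict(list)
        for rank, signature in sorted(by_rank.items()):
            by_signature[signature].append(rank)
        if len(by_signature) > 1:
            mismatches[seq] = dict(by_signature)
    return mismatches


# Two lines copied from the log above (ranks 1 and 0).
SAMPLE_LOG = """\
[E ProcessGroupNCCL.cpp:467] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800392 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=75264, NumelOut=602112, Timeout(ms)=1800000) ran for 1800709 milliseconds before timing out.
"""

for seq, groups in find_mismatches(collect_timeouts(SAMPLE_LOG)).items():
    for signature, ranks in groups.items():
        print(f"SeqNum={seq} ranks={ranks} -> {signature}")
```

Running it over the full training log groups the ranks by the collective they were waiting on, which makes the odd rank out easy to spot.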

While NCCL was stuck, I used py-spy to dump the stack of one of the hung Python processes:

root@A03-R40-I16-12-8000045:/export/App/training_platform/PinoModel# py-spy dump -p 3261644
Process 3261644: /usr/bin/python -u moellava/train/train_mem.py --local_rank=5 --deepspeed ./scripts/zero3_offload.json --model_name_or_path /export/App/training_platform/PinoModel/mixtral/Mixtral-8x7B-Instruct-v0.1 --version mixtral --data_path /mnt/moe/moe/dataset/data_root/train_json/pretrain/valley_llavaimage.json --image_folder /mnt/moe/moe/dataset/data_root --image_tower /export/App/training_platform/PinoModel/openai/clip-vit-large-patch14-336 --image_projector_type mlp2x_gelu --video_tower /export/App/training_platform/PinoModel/LanguageBind/LanguageBind_Video_merge --video_folder /mnt/moe/moe/dataset/data_root --tune_mm_mlp_adapter True --mm_vision_select_layer -2 --mm_use_im_start_end False --mm_use_im_patch_token False --bf16 True --output_dir ./checkpoints/llavamixtral-7b-pretrain --num_train_epochs 1 --per_device_train_batch_size 16 --per_device_eval_batch_size 4 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 2400 --save_total_limit 1 --learning_rate 1e-3 --weight_decay 0. --warmup_ratio 0.03 --lr_scheduler_type cosine --logging_steps 1 --tf32 True --model_max_length 2048 --gradient_checkpointing True --dataloader_num_workers 8 --lazy_preprocess True --report_to tensorboard --cache_dir ./cache_dir
Python v3.10.12 (/usr/bin/python3.10)

Thread 3261644 (active): "MainThread"
    <listcomp> (deepspeed/runtime/zero/partition_parameters.py:1138)
    _all_gather_dtype (deepspeed/runtime/zero/partition_parameters.py:1138)
    all_gather_coalesced (deepspeed/runtime/zero/partition_parameters.py:1252)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    __all_gather_params_ (deepspeed/runtime/zero/partitioned_param_coordinator.py:458)
    __all_gather_params (deepspeed/runtime/zero/partitioned_param_coordinator.py:429)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    fetch_sub_module (deepspeed/runtime/zero/partitioned_param_coordinator.py:380)
    decorate_context (torch/utils/_contextlib.py:115)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    pre_sub_module_forward_function (deepspeed/runtime/zero/parameter_offload.py:452)
    decorate_context (torch/utils/_contextlib.py:115)
    _pre_forward_module_hook (deepspeed/runtime/zero/parameter_offload.py:340)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _call_impl (torch/nn/modules/module.py:1557)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (transformers/models/clip/modeling_clip.py:263)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (transformers/models/clip/modeling_clip.py:372)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (torch/utils/checkpoint.py:230)
    apply (torch/autograd/function.py:539)
    checkpoint (torch/utils/checkpoint.py:450)
    inner (torch/_dynamo/external_utils.py:17)
    _fn (torch/_dynamo/eval_frame.py:333)
    inner (torch/_compile.py:24)
    forward (transformers/models/clip/modeling_clip.py:622)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (transformers/models/clip/modeling_clip.py:844)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (transformers/models/clip/modeling_clip.py:917)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (clip_encoder.py:50)
    decorate_context (torch/utils/_contextlib.py:115)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    encode_images (moellava/model/llava_arch.py:152)
    prepare_inputs_labels_for_multimodal (moellava/model/llava_arch.py:198)
    forward (llava_mixtral.py:83)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (deepspeed/runtime/engine.py:1842)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _call_impl (torch/nn/modules/module.py:1527)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    compute_loss (transformers/trainer.py:2795)
    training_step (transformers/trainer.py:2772)
    _inner_training_loop (transformers/trainer.py:1868)
    train (transformers/trainer.py:1539)
    train (train.py:1475)
    <module> (train_mem.py:13)
Thread 3262753 (idle): "Thread-1"
    select (selectors.py:416)
    wait (multiprocessing/connection.py:931)
    wait_result_broken_or_wakeup (concurrent/futures/process.py:385)
    run (concurrent/futures/process.py:320)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3264158 (idle): "Thread-2"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3267395 (idle): "Thread-3 (_pin_memory_loop)"
    select (selectors.py:416)
    wait (multiprocessing/connection.py:931)
    _poll (multiprocessing/connection.py:424)
    poll (multiprocessing/connection.py:257)
    get (multiprocessing/queues.py:113)
    do_one_step (torch/utils/data/_utils/pin_memory.py:31)
    _pin_memory_loop (torch/utils/data/_utils/pin_memory.py:54)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268088 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268152 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268153 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268154 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268155 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268156 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268157 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268158 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3303923 (idle)
Thread 3303931 (idle)
Thread 3303916 (idle)
Thread 3303934 (idle)
Thread 3303942 (idle)
Thread 3303945 (idle)
Thread 3303952 (idle)
Thread 3303949 (idle)

awzhgw commented 9 months ago

@tohtana, could you help me with this?

hanxiaotian commented 8 months ago

same here

jingcangcang commented 8 months ago

same here

hanxiaotian commented 8 months ago

I found a potential cause: during training, some experts don't receive any tokens and therefore get no gradients, which leaves all the other processes stuck. After feeding a fake gradient to the experts that receive no tokens, training runs smoothly.
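To check whether this applies to a given run, one option is to log which experts received no tokens in a batch. The snippet below is a minimal sketch in plain Python, assuming the router's top-k choices are available as one tuple of expert indices per token (for example from `selected_experts.tolist()` in the HF Mixtral block). The helper name `idle_experts` is hypothetical.

```python
from collections import Counter
from typing import Iterable, List, Sequence


def idle_experts(selected_experts: Iterable[Sequence[int]], num_experts: int) -> List[int]:
    """Return the indices of experts that were not chosen by any token."""
    counts = Counter(
        expert for token_choice in selected_experts for expert in token_choice
    )
    return [expert for expert in range(num_experts) if counts[expert] == 0]


# Three tokens with top-2 routing over 8 experts (made-up routing decisions).
routing = [(0, 3), (3, 5), (0, 5)]
print(idle_experts(routing, num_experts=8))
```

If this list is non-empty on some ranks at the step where training hangs, that would be consistent with the cause described above.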

WanqiZhong commented 8 months ago

I found a potential cause: during training, some experts don't receive any tokens and therefore get no gradients, which leaves all the other processes stuck. After feeding a fake gradient to the experts that receive no tokens, training runs smoothly.

Could you please provide an example of how to feed fake gradients to the experts? Much appreciated! @hanxiaotian

hanxiaotian commented 8 months ago

Something like the following modification to the expert loop in the HF Mixtral implementation:

    for expert_idx in range(self.num_experts):
        expert_layer = self.experts[expert_idx]
        idx, top_x = torch.where(expert_mask[expert_idx])

        if top_x.shape[0] == 0:
            if not self.training:
                # At inference time an expert with no tokens can simply be skipped.
                continue
            # During training, run the idle expert on a zeroed copy of token 0 so
            # that its parameters still take part in the forward/backward pass.
            top_x_ = torch.zeros(1, dtype=torch.long, device=hidden_states.device)
            top_x_list = top_x_.tolist()
            current_state = hidden_states[None, top_x_list].reshape(
                -1, hidden_dim
            )
            # The experts have no bias terms, so a zero input gives a zero output
            # and the index_add_ below leaves final_hidden_states unchanged.
            fake_state = expert_layer(current_state * 0)
            final_hidden_states.index_add_(
                0, top_x_, fake_state.to(hidden_states.dtype)
            )
        else:
            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = (
                expert_layer(current_state)
                * routing_weights[top_x_list, idx_list, None]
            )

            # However `index_add_` only supports torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )

Hope this can help.
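
For anyone who wants to see the effect in isolation, here is a small self-contained sketch. It is not the HF Mixtral code: `ToyExpert`, `ToyMoE`, the `dummy_forward` flag and `params_without_grad` are hypothetical names, and the sizes are tiny so it runs on CPU. It only shows that, without the dummy forward, the parameters of idle experts end up with no gradient after backward(), and that with it every parameter gets a (zero) gradient while the output stays the same. It does not reproduce the ZeRO-3 hang itself.

```python
import torch
from torch import nn


class ToyExpert(nn.Module):
    """Bias-free gated MLP mirroring the w1/w2/w3 layout printed in the issue."""

    def __init__(self, hidden_dim: int, ffn_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act_fn(self.w1(x)) * self.w3(x))


class ToyMoE(nn.Module):
    """Hypothetical top-k MoE block; `dummy_forward` toggles the workaround."""

    def __init__(self, hidden_dim=8, ffn_dim=16, num_experts=4, top_k=2, dummy_forward=True):
        super().__init__()
        self.top_k = top_k
        self.dummy_forward = dummy_forward
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [ToyExpert(hidden_dim, ffn_dim) for _ in range(num_experts)]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states has shape (num_tokens, hidden_dim).
        probs = torch.softmax(self.gate(hidden_states), dim=-1)
        weights, selected = torch.topk(probs, self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)
        out = torch.zeros_like(hidden_states)
        first_token = torch.zeros(1, dtype=torch.long, device=hidden_states.device)
        for expert_idx, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(selected == expert_idx)
            if token_idx.numel() == 0:
                if self.training and self.dummy_forward:
                    # Idle expert: feed a zeroed token so its weights join the graph.
                    out.index_add_(0, first_token, expert(hidden_states[:1] * 0))
                continue
            expert_out = expert(hidden_states[token_idx]) * weights[token_idx, slot_idx, None]
            out.index_add_(0, token_idx, expert_out)
        return out


def params_without_grad(model: nn.Module) -> list:
    """Names of parameters whose .grad is still None after backward()."""
    return [name for name, p in model.named_parameters() if p.grad is None]


torch.manual_seed(0)
moe = ToyMoE().train()
x = torch.randn(1, 8)  # one token, top-2 of 4 experts: two experts stay idle

moe.dummy_forward = False
out_plain = moe(x)
out_plain.sum().backward()
missing_plain = params_without_grad(moe)

moe.zero_grad(set_to_none=True)
moe.dummy_forward = True
out_dummy = moe(x)
out_dummy.sum().backward()
missing_dummy = params_without_grad(moe)

print(len(missing_plain), len(missing_dummy))
```

With a single token and top-2 routing over four experts, two experts are always idle, so the first count covers their six weight matrices and the second is empty.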