microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0
35.46k stars · 4.12k forks

[BUG] Deepspeed Crashes when using MoE, Stage 2 Offload with DeepSpeedCPUAdam #5203

Open KyleMylonakisProtopia opened 8 months ago

KyleMylonakisProtopia commented 8 months ago

Describe the bug When training a model that contains Mixture of Experts (MoE) layers using ZeRO stage 2 offload with the DeepSpeedCPUAdam optimizer, the following runtime error is thrown during the parameter update step.

│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/lightning/fabric/wrappers.py:92 in step                                                                                                                                                                │
│                                                                                                                                                                                                                                                                        │
│ ❱  92 │   │   output = self._strategy.optimizer_step(                                                                                                                                                                                                                  │
│                                                                                                                                                                                                                                                                        │
│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/lightning/fabric/strategies/strategy.py:206 in optimizer_step                                                                                                                                          │
│                                                                                                                                                                                                                                                                        │
│ ❱ 206 │   │   return self.precision.optimizer_step(optimizer, **kwargs)                                                                                                                                                                                                │
│                                                                                                                                                                                                                                                                        │
│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/lightning/fabric/plugins/precision/deepspeed.py:100 in optimizer_step                                                                                                                                  │
│                                                                                                                                                                                                                                                                        │
│ ❱ 100 │   │   return optimizer.step(**kwargs)                                                                                                                                                                                                                          │
│                                                                                                                                                                                                                                                                        │
│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/deepspeed/runtime/zero/stage_1_and_2.py:1842 in step                                                                                                                                                   │
│                                                                                                                                                                                                                                                                        │
│ ❱ 1842 │   │   scaled_global_grad_norm = self.scaled_global_norm()                                                                                                                                                                                                     │
│                                                                                                                                                                                                                                                                        │
│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/deepspeed/runtime/zero/stage_1_and_2.py:1786 in scaled_global_norm                                                                                                                                     │
│                                                                                                                                                                                                                                                                        │
│ ❱ 1786 │   │   │   self._average_expert_grad_norms(norm_groups)                                                                                                                                                                                                        │
│                                                                                                                                                                                                                                                                        │
│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/deepspeed/runtime/zero/stage_1_and_2.py:1949 in _average_expert_grad_norms                                                                                                                             │
│                                                                                                                                                                                                                                                                        │
│ ❱ 1949 │   │   │   │   dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i])                                                                                                                                                                        │
│                                                                                                                                                                                                                                                                        │
│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/deepspeed/comm/comm.py:117 in log_wrapper                                                                                                                                                              │
│                                                                                                                                                                                                                                                                        │
│ ❱ 117 │   │   │   return func(*args, **kwargs)                                                                                                                                                                                                                         │
│                                                                                                                                                                                                                                                                        │
│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/deepspeed/comm/comm.py:496 in all_reduce                                                                                                                                                               │
│                                                                                                                                                                                                                                                                        │
│ ❱ 496 │   return cdb.all_reduce(tensor, op, group, async_op)                                                                                                                                                                                                           │
│                                                                                                                                                                                                                                                                        │
│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:489 in _fn                                                                                                                                                                 │
│                                                                                                                                                                                                                                                                        │
│ ❱  489 │   │   │   │   return fn(*args, **kwargs)                                                                                                                                                                                                                      │
│                                                                                                                                                                                                                                                                        │
│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/deepspeed/comm/torch.py:159 in all_reduce                                                                                                                                                              │
│                                                                                                                                                                                                                                                                        │
│ ❱ 159 │   │   return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=                                                                                                                                                                         │
│                                                                                                                                                                                                                                                                        │
│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/torch/distributed/c10d_logger.py:72 in wrapper                                                                                                                                                         │
│                                                                                                                                                                                                                                                                        │
│ ❱ 72 │   │   │   return func(*args, **kwargs)                                                                                                                                                                                                                          │
│                                                                                                                                                                                                                                                                        │
│ /home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py:1992 in all_reduce                                                                                                                                               │
│                                                                                                                                                                                                                                                                        │
│ ❱ 1992 │   work = group.allreduce([tensor], opts)                                                                                                                                                                                                                      │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: No backend type associated with device type cpu

When using ep_size=1 for the expert layers, the call to self._average_expert_grad_norms(norm_groups) is unnecessary, and commenting it out resolves the issue. This is of course not a general solution for ep_size > 1, but in my case it is sufficient to continue my work.
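The workaround described above amounts to a guard on the expert-parallel group size. A minimal sketch of that logic (the function name and the all_reduce callback are illustrative, not DeepSpeed's actual API):

```python
# Illustrative sketch of the ep_size == 1 workaround: averaging an expert
# gradient norm across the expert-parallel group only requires a collective
# when that group has more than one rank. With a single-rank group the norm
# is already its own average, so the all-reduce (the call that crashes on a
# CPU tensor) can be skipped entirely.

def average_expert_grad_norm(norm, ep_group_ranks, all_reduce=None):
    """Average an expert gradient norm across an expert-parallel group.

    ep_group_ranks: ranks in the expert-parallel group (len == ep_size).
    all_reduce: callback standing in for dist.all_reduce; returns the sum
    of `norm` across all ranks in the group.
    """
    world = len(ep_group_ranks)
    if world <= 1:
        # ep_size == 1: nothing to average, skip the collective entirely.
        return norm
    total = all_reduce(norm)
    return total / world

# Single-rank group: the norm passes through unchanged, no collective call.
print(average_expert_grad_norm(3.0, [0]))  # 3.0
```

A general fix for ep_size > 1 would still need the collective, issued on a tensor the communication backend can handle.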

To Reproduce Steps to reproduce the behavior:

  1. Create a model with a mixture of experts layer present.
  2. Use a deepspeed config with stage 2 offload, such as
    config={
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
            "overlap_comm": True,
            "contiguous_gradients": True,
            "allgather_bucket_size": 5e8,
            "reduce_bucket_size": 5e8,
            "allgather_partitions": True,
            "reduce_scatter": True,
            # "round_robin_gradients": True,
        }
    }
  3. Use the DeepSpeedCPUAdam optimizer for efficient CPU offload
  4. Train the model and perform an update step.
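Steps 2 and 3 can be expressed in a single DeepSpeed config. A hedged sketch, assuming the documented behavior that DeepSpeed substitutes its CPU Adam implementation (DeepSpeedCPUAdam) when the optimizer is offloaded to CPU; the `optimizer` section and its values are illustrative, not taken from the report:

```python
# Minimal ZeRO stage 2 + CPU offload config mirroring the snippet above.
# The "optimizer" block is an illustrative assumption: with
# offload_optimizer.device == "cpu", DeepSpeed runs the Adam-type optimizer
# as DeepSpeedCPUAdam, which is the optimizer named in this report.
ds_config = {
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
        "allgather_bucket_size": 5e8,
        "reduce_bucket_size": 5e8,
        "allgather_partitions": True,
        "reduce_scatter": True,
    },
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 1e-4},  # placeholder learning rate
    },
}
```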

Expected behavior Training should proceed without errors.

ds_report output

[2024-02-27 09:16:20,976] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/torch']
torch version .................... 2.2.0+cu121
deepspeed install path ........... ['/home/kyle/.conda/envs/llama2-chat/lib/python3.11/site-packages/deepspeed']
deepspeed info ................... 0.13.3, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.1
deepspeed wheel compiled w. ...... torch 2.2, cuda 12.1
shared memory (/dev/shm) size .... 125.77 GB

Screenshots N/A

System info (please complete the following information):

Launcher context Pytorch Lightning

Docker context Bare metal.

Additional context I use ep_size=1 for my mixture-of-experts layers, so in my case this bug is entirely avoidable by skipping the all-reduce step.

RezaYazdaniAminabadi commented 8 months ago

Hi @KyleMylonakisProtopia, please give this PR a try; hopefully it resolves the issue. Best, Reza

KyleMylonakisProtopia commented 8 months ago

That PR seems to resolve the issue. Thanks for looking at it!

RezaYazdaniAminabadi commented 8 months ago

@tjruwase, let's please close this and merge the PR :)