microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] apply_tensor_parallelism() is not executed in Zero3 without self.mpu #4080

Open devamanyu opened 11 months ago

devamanyu commented 11 months ago

Describe the bug

In the Hybrid Engine, apply_tensor_parallelism() is not called when the model's inference container requires tp > 1 but self.mpu is None. For example, with a large model under ZeRO-3, apply_tensor_parallelism() is never called.
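A minimal, self-contained sketch of the control flow described above. The class and method names (ToyContainer, ToyHybridEngine, run_tp_step, guard_on_mpu) are hypothetical stand-ins, not DeepSpeed's API. It only shows the logic: when the TP step is guarded by "self.mpu is not None", a requested inference_tp_size > 1 is silently ignored whenever no mpu object is passed, which is the ZeRO-3 case reported here.

```python
class ToyContainer:
    """Hypothetical stand-in for an inference container; counts TP calls."""

    def __init__(self):
        self.tp_calls = 0

    def apply_tensor_parallelism(self):
        self.tp_calls += 1


class ToyHybridEngine:
    """Hypothetical, simplified model of the guard around the TP step."""

    def __init__(self, mpu, inference_tp_size, guard_on_mpu=True):
        self.mpu = mpu
        self.inference_tp_size = inference_tp_size
        # guard_on_mpu=True mimics the reported `if self.mpu is not None:` check;
        # False mimics the workaround of disabling that check.
        self.guard_on_mpu = guard_on_mpu
        self.containers = [ToyContainer() for _ in range(2)]

    def run_tp_step(self):
        for container in self.containers:
            if self.guard_on_mpu and self.mpu is None:
                continue  # TP skipped even though inference_tp_size > 1
            if self.inference_tp_size > 1:
                container.apply_tensor_parallelism()


# ZeRO-3 style setup: no mpu object, but TP size 2 requested in the config.
engine = ToyHybridEngine(mpu=None, inference_tp_size=2)
engine.run_tp_step()
print([c.tp_calls for c in engine.containers])  # [0, 0]

# Same setup with the mpu guard disabled.
patched = ToyHybridEngine(mpu=None, inference_tp_size=2, guard_on_mpu=False)
patched.run_tp_step()
print([c.tp_calls for c in patched.containers])  # [1, 1]
```

Disabling the guard only makes the TP step run; whether the resulting TP weights are valid is a separate question.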

Log output

To Reproduce

Relevant code: https://github.com/microsoft/DeepSpeed/blob/a7fe3bcc353c072846e4f86acff5cbfd758e2ec9/deepspeed/runtime/hybrid_engine.py#L206

Call any large model (for example, Llama 30B) using ZeRO-3 + Hybrid Engine.

wang990099 commented 10 months ago

I hit the same issue, too. When I call apply_tensor_parallelism() myself, I also get the wrong attention tensor size. I found that in auto_tp.py the tensor-parallel copy goes into a tensor of the original size rather than the TP-partitioned size.

hxdtest commented 9 months ago
    import deepspeed
    from deepspeed.accelerator import get_accelerator
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # DeepSpeed config: ZeRO-3 with CPU parameter offload, fp16, and the
    # Hybrid Engine with an inference tensor-parallel size of 2.
    config = {
        "train_batch_size": 32,
        "train_micro_batch_size_per_gpu": 2,
        "steps_per_print": 10,
        "zero_optimization": {
            "stage": 3,
            "offload_param": {
                "device": "cpu"
            },
            "stage3_param_persistence_threshold": 0
        },
        "fp16": {
            "enabled": True,
            "loss_scale_window": 100
        },
        "hybrid_engine": {
            "enabled": True,
            "inference_tp_size": 2
        },
        "gradient_clipping": 1.0,
    }

    model_path = "/ossfs/workspace/opt-350m"  # local copy of OPT-350m

    deepspeed.runtime.utils.see_memory_usage('pre test', force=True)

    # Load the model in fp16 on the current accelerator device.
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().to(
        get_accelerator().device_name())

    deepspeed.runtime.utils.see_memory_usage('post test', force=True)

    # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler);
    # with "hybrid_engine" enabled, the engine is a DeepSpeedHybridEngine.
    m, _, _, _ = deepspeed.initialize(model=model, config=config)

    m.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    prompt = "Hello, I'm am conscious and"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(get_accelerator().device_name())

    # Generation runs through the Hybrid Engine's inference path, where the failure occurs.
    generated_ids = m.generate(input_ids)
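For anyone re-running the script above: the batch-size fields in its config must agree with the number of data-parallel ranks. As we understand DeepSpeed's batch-size rules, train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * data-parallel world size, and gradient_accumulation_steps is derived from that relation when it is omitted, as it is here. The sketch below solves for it; the helper name is hypothetical, and the world sizes tried are assumptions about how the script is launched (inference_tp_size=2 presumably needs at least 2 GPUs).

```python
def gradient_accumulation_steps(train_batch_size, micro_batch_per_gpu, dp_world_size):
    """Solve train_batch_size = micro_batch_per_gpu * gas * dp_world_size for gas."""
    per_step = micro_batch_per_gpu * dp_world_size
    if train_batch_size % per_step != 0:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"micro_batch_per_gpu * dp_world_size = {per_step}"
        )
    return train_batch_size // per_step


# train_batch_size=32 and train_micro_batch_size_per_gpu=2 come from the config;
# the world sizes are assumed launch configurations.
for world_size in (1, 2, 4):
    print(world_size, gradient_accumulation_steps(32, 2, world_size))
# 1 16
# 2 8
# 4 4
```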

I hit the same issue, too, after disabling the if self.mpu is not None: check.

The log output is:

    /opt/conda/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/ds_attention.py:117 in _merge_qkv

      114 │ def _merge_qkv(self):
      115 │
      116 │     qvkw = DeepSpeedSelfAttention._qkv_buffers[0]
    ❱ 117 │     qvkw[:self.hidden_size_per_partition, :] = self.attn_qw  # type: ignore
      118 │     qvkw[self.hidden_size_per_partition:2 * self.hidden_size_per_partition, :] = sel
      119 │     qvkw[2 * self.hidden_size_per_partition:, :] = self.attn_vw  # type: ignore
      120 │     if self.attn_qb is not None:

    RuntimeError: The expanded size of the tensor (1024) must match the existing size (0) at non-singleton dimension 1. Target sizes: [512, 1024]. Tensor sizes: [0]
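The shapes in this traceback can be checked with simple arithmetic. The sketch below assumes, as the _merge_qkv lines suggest, that each of Q, K and V is a [hidden_size, hidden_size] weight split by rows across TP ranks and stacked into one fused buffer as [q_shard; k_shard; v_shard]. OPT-350m has hidden_size=1024 and the config sets inference_tp_size=2, so each rank's Q slice should be [512, 1024], matching "Target sizes". The source tensor has shape [0], i.e. no elements; one plausible reading is that attn_qw still points at an empty ZeRO-3 placeholder rather than a TP shard. The helper names are hypothetical.

```python
def merged_qkv_layout(hidden_size, tp_size):
    """Row ranges and shapes a _merge_qkv-style fused buffer expects on one rank.

    Assumes Q, K, V are each [hidden_size, hidden_size], split evenly by rows
    across tp_size ranks, then stacked as [q_shard; k_shard; v_shard].
    """
    if hidden_size % tp_size != 0:
        raise ValueError("hidden_size must be divisible by tp_size")
    part = hidden_size // tp_size  # plays the role of hidden_size_per_partition
    return {
        "q": ((0, part), (part, hidden_size)),
        "k": ((part, 2 * part), (part, hidden_size)),
        "v": ((2 * part, 3 * part), (part, hidden_size)),
        "buffer": (3 * part, hidden_size),
    }


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


layout = merged_qkv_layout(hidden_size=1024, tp_size=2)
print(layout["q"][1])  # (512, 1024), the "Target sizes" in the traceback

# A source of shape [0] has no elements, so it cannot fill a 512 x 1024 slice.
print(numel((0,)), numel(layout["q"][1]))  # 0 524288
```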

LSC527 commented 8 months ago

@hxdtest The size-[0] tensor seems to be caused by the TP weights not being set properly. Besides disabling the if self.mpu is not None: check, I made the following change in split_qkv.py, and TP now works:

    def attention_qkv_mp(self, mp_replace, reversed_dim=False):
        # Only need to alter
        if self.module.attention.attn_qkvw is None:
            # Original code (commented out): `dst` is only a loop variable, so
            # assigning to it rebinds a local name and does not replace the
            # module attributes; the tensor returned by mp_replace.copy() is dropped.
            # params = [
            #     (self.module.attention.attn_qw, self.qw),
            #     (self.module.attention.attn_qb, self.qb),
            #     (self.module.attention.attn_kw, self.kw),
            #     (self.module.attention.attn_kb, self.kb),
            #     (self.module.attention.attn_vw, self.vw),
            #     (self.module.attention.attn_vb, self.vb),
            # ]
            # for dst, src in params:
            #     dst = mp_replace.copy(
            #         dst[:self.qw.shape[0] // mp_replace.mp_size], src, int8=reversed_dim,
            #         allocate_tensor=reversed_dim) if src is not None else None

            # Fix: assign each sharded copy back to the module attribute explicitly.
            self.module.attention.attn_qw = mp_replace.copy(
                self.module.attention.attn_qw[:self.qw.shape[0] // mp_replace.mp_size], self.qw,
                int8=reversed_dim, allocate_tensor=reversed_dim) if self.qw is not None else None
            self.module.attention.attn_qb = mp_replace.copy(
                self.module.attention.attn_qb[:self.qw.shape[0] // mp_replace.mp_size], self.qb,
                int8=reversed_dim, allocate_tensor=reversed_dim) if self.qb is not None else None
            self.module.attention.attn_kw = mp_replace.copy(
                self.module.attention.attn_kw[:self.qw.shape[0] // mp_replace.mp_size], self.kw,
                int8=reversed_dim, allocate_tensor=reversed_dim) if self.kw is not None else None
            self.module.attention.attn_kb = mp_replace.copy(
                self.module.attention.attn_kb[:self.qw.shape[0] // mp_replace.mp_size], self.kb,
                int8=reversed_dim, allocate_tensor=reversed_dim) if self.kb is not None else None
            self.module.attention.attn_vw = mp_replace.copy(
                self.module.attention.attn_vw[:self.qw.shape[0] // mp_replace.mp_size], self.vw,
                int8=reversed_dim, allocate_tensor=reversed_dim) if self.vw is not None else None
            self.module.attention.attn_vb = mp_replace.copy(
                self.module.attention.attn_vb[:self.qw.shape[0] // mp_replace.mp_size], self.vb,
                int8=reversed_dim, allocate_tensor=reversed_dim) if self.vb is not None else None
        else:
            super().attention_qkv_mp(mp_replace)
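The difference between the commented-out loop and the explicit assignments is plain Python name binding, which a toy example shows without DeepSpeed. ToyAttention and toy_copy are hypothetical stand-ins for self.module.attention and mp_replace.copy(); like mp_replace.copy(), toy_copy returns a new object. Assigning to the loop variable dst never changes the attribute on the module, while setattr does.

```python
class ToyAttention:
    """Hypothetical stand-in for self.module.attention."""

    def __init__(self):
        self.attn_qw = "old qw"
        self.attn_kw = "old kw"


def toy_copy(dst, src):
    """Hypothetical stand-in for mp_replace.copy(): returns a new object."""
    return f"shard of {src}"


attn = ToyAttention()

# Pattern of the commented-out original: `dst` is only a loop variable, so the
# assignment rebinds that local name and the module attribute is untouched.
for dst, src in [(attn.attn_qw, "qw"), (attn.attn_kw, "kw")]:
    dst = toy_copy(dst, src)
after_original = (attn.attn_qw, attn.attn_kw)
print(after_original)  # ('old qw', 'old kw')

# Pattern of the fix, written as a loop over attribute names with setattr.
for name, src in [("attn_qw", "qw"), ("attn_kw", "kw")]:
    setattr(attn, name, toy_copy(getattr(attn, name), src))
after_fix = (attn.attn_qw, attn.attn_kw)
print(after_fix)  # ('shard of qw', 'shard of kw')
```

A getattr/setattr loop like the second one could, in principle, replace the six explicit assignments in the patch; that is a stylistic option, not something tested here.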

With the original code, the attention weights do not persist after apply_tensor_parallelism() runs and the GatheredParameters context exits. However, I am not sure whether the TP weights are actually correct with this change.
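One way to sanity-check the TP weights, sketched under assumptions: gather each rank's attn_qw (or attn_kw, attn_vw) to a single process and check that the shards are non-empty and reassemble the full weight in rank order. The helpers below are hypothetical and use plain lists so the logic is visible; they assume an even, contiguous row-wise split, which matches the [:self.qw.shape[0] // mp_replace.mp_size] slicing in the patch but has not been verified against mp_replace.copy() itself. With real tensors, the gather and the comparison would use torch.distributed and torch tensor equality instead.

```python
def shard_rows(weight, tp_size, rank):
    """Rows owned by `rank` under an even, contiguous row-wise split."""
    rows = len(weight)
    if rows % tp_size != 0:
        raise ValueError("row count must be divisible by tp_size")
    per_rank = rows // tp_size
    return weight[rank * per_rank:(rank + 1) * per_rank]


def check_tp_shards(full_weight, shards):
    """True if every shard is non-empty and the shards rebuild the full weight."""
    if any(len(shard) == 0 for shard in shards):
        return False
    rebuilt = [row for shard in shards for row in shard]
    return rebuilt == full_weight


# Toy 4 x 3 weight split across 2 ranks.
full = [[r * 10 + c for c in range(3)] for r in range(4)]
shards = [shard_rows(full, tp_size=2, rank=r) for r in range(2)]
print(check_tp_shards(full, shards))           # True
print(check_tp_shards(full, [shards[0], []]))  # False: an empty shard, like the size-[0] tensor
```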