Open devamanyu opened 11 months ago
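I hit the same issue. When I call `apply_tensor_parallelism()`, I also get the wrong attention tensor size. Looking at `auto_tp.py`, I found that the tensor-parallel copy targets the original (unpartitioned) size rather than the TP-partitioned tensor size. My config and reproduction script: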
# Reproduction script: DeepSpeed ZeRO stage 3 + Hybrid Engine
# (inference_tp_size=2) running generation on a local OPT-350m checkpoint.
import argparse
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import deepspeed
from deepspeed.accelerator import get_accelerator

# DeepSpeed configuration, passed as a Python dict (note `True`, not JSON `true`).
config = {
    "train_batch_size": 32,
    "train_micro_batch_size_per_gpu": 2,
    "steps_per_print": 10,
    # ZeRO stage 3 with parameters offloaded to CPU.
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "cpu",
        },
        "stage3_param_persistence_threshold": 0,
    },
    "fp16": {
        "enabled": True,
        "loss_scale_window": 100,
    },
    # Hybrid Engine with tensor-parallel size 2 for inference.
    "hybrid_engine": {
        "enabled": True,
        "inference_tp_size": 2,
    },
    "gradient_clipping": 1.0,
}

# Load the model in half precision on the current accelerator, logging memory
# usage before and after the load.
deepspeed.runtime.utils.see_memory_usage('pre test', force=True)
model = (
    AutoModelForCausalLM.from_pretrained("/ossfs/workspace/opt-350m", trust_remote_code=True)
    .half()
    .to(get_accelerator().device_name())
)
deepspeed.runtime.utils.see_memory_usage('post test', force=True)

# Wrap the model in a DeepSpeed engine; only the engine (first return value) is used.
kwargs = {"config": config, "model": model}
m, _, _, _ = deepspeed.initialize(**kwargs)
m.eval()

# Tokenize a short prompt and run generation through the engine.
tokenizer = AutoTokenizer.from_pretrained("/ossfs/workspace/opt-350m", use_fast=False)
prompt = "Hello, I'm am conscious and"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
generated_ids = m.generate(input_ids)
I hit the same issue, too. I disabled the `if self.mpu is not None:` check.
Log output is
│ /opt/conda/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/ds_attention.py:117 │ │ in _merge_qkv │ │ │ │ 114 │ def _merge_qkv(self): │ │ 115 │ │ │ │ 116 │ │ qvkw = DeepSpeedSelfAttention._qkv_buffers[0] │ │ ❱ 117 │ │ qvkw[:self.hidden_size_per_partition, :] = self.attn_qw # type: ignore │ │ 118 │ │ qvkw[self.hidden_size_per_partition:2 * self.hidden_size_per_partition, :] = sel │ │ 119 │ │ qvkw[2 * self.hidden_size_per_partition:, :] = self.attn_vw # type: ignore │ │ 120 │ │ if self.attn_qb is not None: │ ╰─────────────────────────────────────────────────────────────────────────────────────────────────╯ RuntimeError: The expanded size of the tensor (1024) must match the existing size (0) at non-singleton dimension 1. Target sizes: [512, 1024]. Tensor sizes: [0]
@hxdtest The size-`[0]` tensor seems to be caused by the TP weights not being set properly.
Besides disabling the `if self.mpu is not None:` check, I made the following change in `split_qkv.py`, and TP works fine now:
def attention_qkv_mp(self, mp_replace, reversed_dim=False):
    """Copy the separate Q/K/V weights and biases into the attention module.

    When the attention module holds a fused ``attn_qkvw`` parameter, the work
    is delegated to the parent implementation. Otherwise each of the six
    ``attn_{qw,qb,kw,kb,vw,vb}`` attributes on ``self.module.attention`` is
    replaced by the result of ``mp_replace.copy`` applied to the matching
    source parameter on ``self``; a source that is ``None`` sets the
    attribute to ``None``.

    Args:
        mp_replace: Object exposing ``mp_size`` and
            ``copy(dst, src, int8=..., allocate_tensor=...)``.
        reversed_dim: Forwarded to ``copy`` as both ``int8`` and
            ``allocate_tensor``.
    """
    attention = self.module.attention

    # Fused QKV weight present: nothing to alter here, defer to the parent.
    if attention.attn_qkvw is not None:
        super().attention_qkv_mp(mp_replace)
        return

    # The result of copy() must be assigned back onto the attention module.
    # Rebinding a loop variable (as an earlier loop-based version did) would
    # leave the module attributes untouched, hence the explicit setattr.
    for suffix in ("qw", "qb", "kw", "kb", "vw", "vb"):
        attr_name = f"attn_{suffix}"
        src = getattr(self, suffix)
        if src is None:
            setattr(attention, attr_name, None)
            continue

        # Every destination is sliced to the first qw.shape[0] // mp_size
        # entries along dim 0.
        # NOTE(review): this uses the Q weight's size for K/V as well, which
        # presumably assumes Q, K and V share the same leading dimension.
        # Confirm for models where they differ.
        rows_per_rank = self.qw.shape[0] // mp_replace.mp_size
        dst = getattr(attention, attr_name)
        setattr(
            attention,
            attr_name,
            mp_replace.copy(
                dst[:rows_per_rank],
                src,
                int8=reversed_dim,
                allocate_tensor=reversed_dim,
            ),
        )
With the original code, the attention weights do not persist after `apply_tensor_parallelism` returns and the `GatheredParameters` context exits.
However, I am not sure whether the TP weights are actually correct with this change.
Describe the bug
In Hybrid Engine, `apply_tensor_parallelism()` is not called when the model's inference container requires tp > 1 but `self.mpu` is None. For example, for a large model under ZeRO-3, `apply_tensor_parallelism()` is never called.
Log output
To Reproduce https://github.com/microsoft/DeepSpeed/blob/a7fe3bcc353c072846e4f86acff5cbfd758e2ec9/deepspeed/runtime/hybrid_engine.py#L206
To reproduce, run any large model (e.g. Llama 30B) with ZeRO-3 + Hybrid Engine.