microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] Fail to inference with 8bit quantized bloom-3b model, shape mismatch error #2865

Open Oliver-ss opened 1 year ago

Oliver-ss commented 1 year ago

Describe the bug Hi, I am trying to run a bloom-3b model with 8-bit quantization on two 2080Ti GPUs.

Because there is no pre-quantized version, I simply change the dtype when calling deepspeed.init_inference, but a shape mismatch error always occurs.

I tried to read the code and found that the int8 weight seems to need an extra permutation, according to the ds_qkv_gemm CUDA implementation:

https://github.com/microsoft/DeepSpeed/blob/fd1449c766f8dc0b0d77ef6389934d094b60a889/csrc/transformer/inference/csrc/pt_binding.cpp#L877
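As a rough sanity check of the numbers in the error below (assuming bloom-3b's usual config of hidden size 2560 and 32 heads, split over mp_size=2, so 16 heads per partition and 3 * head_dim = 240): the attention code tries to view the QKV output as [1, 7, 16, 240], i.e. 26,880 elements, but the tensor coming out of the int8 path only holds 17,920 = 7 * 2,560 elements. A minimal stand-alone snippet reproducing just that view failure (the tensor here is a stand-in, not DeepSpeed's actual buffer):

import torch

# 7 tokens, but only hidden_size (2560) values per token instead of the
# 16 * 240 = 3840 values per partition that the attention view expects.
qkv_out = torch.empty(1, 7, 2560)

try:
    qkv_out.view(1, 7, 16, 240)
except RuntimeError as e:
    print(e)  # shape '[1, 7, 16, 240]' is invalid for input of size 17920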

Here is the error:

/data/deepspeed_playground/bloom3b/run_deepspeed.py:34 in <module>                      │
│                                                                                          │
│   31 # Test model                                                                        │
│   32 example = "DeepSpeed is a machine learning framework"                               │
│   33 input_ids = tokenizer(example, return_tensors="pt").input_ids.to(model.device)      │
│ ❱ 34 logits = ds_model.generate(input_ids, do_sample=True, max_length=100)               │
│   35 print(f"prediction: \n \n {tokenizer.decode(logits[0].tolist())}")                  │
│   36                                                                                     │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/deepspeed/inference/engine.py:561 in   │
│ _generate                                                                                │
│                                                                                          │
│   558 │   │   │   │   "add your request to: https://github.com/microsoft/DeepSpeed/issue │
│   559 │   │   │   )                                                                      │
│   560 │   │                                                                              │
│ ❱ 561 │   │   return self.module.generate(*inputs, **kwargs)                             │
│   562                                                                                    │
│                                                                                          │
/home/xxxx/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27 in      │
│ decorate_context                                                                         │
│                                                                                          │
│    24 │   │   @functools.wraps(func)                                                     │
│    25 │   │   def decorate_context(*args, **kwargs):                                     │
│    26 │   │   │   with self.clone():                                                     │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                                       │
│    28 │   │   return cast(F, decorate_context)                                           │
│    29 │                                                                                  │
│    30 │   def _wrap_generator(self, func):                                               │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/transformers/generation/utils.py:1437  │
│ in generate                                                                              │
│                                                                                          │
│   1434 │   │   │   )                                                                     │
│   1435 │   │   │                                                                         │
│   1436 │   │   │   # 13. run sample                                                      │
│ ❱ 1437 │   │   │   return self.sample(                                                   │
│   1438 │   │   │   │   input_ids,                                                        │
│   1439 │   │   │   │   logits_processor=logits_processor,                                │
│   1440 │   │   │   │   logits_warper=logits_warper,                                      │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/transformers/generation/utils.py:2443  │
│ in sample                                                                                │
│                                                                                          │
│   2440 │   │   │   model_inputs = self.prepare_inputs_for_generation(input_ids, **model_ │
│   2441 │   │   │                                                                         │
│   2442 │   │   │   # forward pass to get next token                                      │
│ ❱ 2443 │   │   │   outputs = self(                                                       │
│   2444 │   │   │   │   **model_inputs,                                                   │
│   2445 │   │   │   │   return_dict=True,                                                 │
│   2446 │   │   │   │   output_attentions=output_attentions,                              │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1110 in     │
│ _call_impl                                                                               │
│                                                                                          │
│   1107 │   │   # this function, and just call forward.                                   │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):           │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                                 │
│   1111 │   │   # Do not call functions when jit is used                                  │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []                     │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:                        │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/transformers/models/bloom/modeling_blo │
│ om.py:900 in forward                                                                     │
│                                                                                          │
│    897 │   │                                                                             │
│    898 │   │   return_dict = return_dict if return_dict is not None else self.config.use │
│    899 │   │                                                                             │
│ ❱  900 │   │   transformer_outputs = self.transformer(                                   │
│    901 │   │   │   input_ids,                                                            │
│    902 │   │   │   past_key_values=past_key_values,                                      │
│    903 │   │   │   attention_mask=attention_mask,                                        │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1110 in     │
│ _call_impl                                                                               │
│                                                                                          │
│   1107 │   │   # this function, and just call forward.                                   │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):           │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                                 │
│   1111 │   │   # Do not call functions when jit is used                                  │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []                     │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:                        │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/transformers/models/bloom/modeling_blo │
│ om.py:782 in forward                                                                     │
│                                                                                          │
│    779 │   │   │   │   │   head_mask[i],                                                 │
│    780 │   │   │   │   )                                                                 │
│    781 │   │   │   else:                                                                 │
│ ❱  782 │   │   │   │   outputs = block(                                                  │
│    783 │   │   │   │   │   hidden_states,                                                │
│    784 │   │   │   │   │   layer_past=layer_past,                                        │
│    785 │   │   │   │   │   attention_mask=causal_mask,                                   │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1110 in     │
│ _call_impl                                                                               │
│                                                                                          │
│   1107 │   │   # this function, and just call forward.                                   │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):           │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                                 │
│   1111 │   │   # Do not call functions when jit is used                                  │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []                     │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:                        │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/deepspeed/model_implementations/transf │
│ ormers/ds_transformer.py:157 in forward                                                  │
│                                                                                          │
│   154 │   │   │   input = input.half()                                                   │
│   155 │   │   with torch.no_grad():                                                      │
│   156 │   │   │   attention_output, key, value, context_outputtn_ctx, inp_norm = \       │
│ ❱ 157 │   │   │   │   │   │   │   │   │    self.attention(input,                         │
│   158 │   │   │   │   │   │   │   │   │   │   │     input_mask,                          │
│   159 │   │   │   │   │   │   │   │   │   │   │     head_mask,                           │
│   160 │   │   │   │   │   │   │   │   │   │   │     layer_past,                          │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1110 in     │
│ _call_impl                                                                               │
│                                                                                          │
│   1107 │   │   # this function, and just call forward.                                   │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):           │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                                 │
│   1111 │   │   # Do not call functions when jit is used                                  │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []                     │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:                        │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/ds │
│ _attention.py:128 in forward                                                             │
│                                                                                          │
│   125 │   │   │   │   num_layers=DeepSpeedSelfAttention.num_layers,                      │
│   126 │   │   │   │   num_heads=self.num_attention_heads_per_partition)                  │
│   127 │   │                                                                              │
│ ❱ 128 │   │   context_layer, key_layer, value_layer = self.compute_attention(            │
│   129 │   │   │   qkv_out=qkv_out,                                                       │
│   130 │   │   │   input_mask=input_mask,                                                 │
│   131 │   │   │   layer_past=layer_past,                                                 │
│                                                                                          │
│ /home/xxxx/.local/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/ds │
│ _attention.py:202 in compute_attention                                                   │
│                                                                                          │
│   199 │   │   new_tensor_shape = mixed_x_layer.size()[:-1] + (                           │
│   200 │   │   │   self.num_attention_heads_per_partition,                                │
│   201 │   │   │   3 * head_dim)                                                          │
│ ❱ 202 │   │   mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)                      │
│   203 │   │                                                                              │
│   204 │   │   query_layer, key_layer, value_layer = self._split_tensor_along_last_dim(mi │
│   205                                                                                    │
╰──────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: shape '[1, 7, 16, 240]' is invalid for input of size 17920

Here is my code:

import os
import torch
import numpy as np
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import deepspeed

world_size = int(os.getenv('WORLD_SIZE', '1'))

# hide generation warnings
transformers.logging.set_verbosity_error()

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-3b")
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-3b", torch_dtype=torch.float16
)

# init deepspeed inference engine
ds_model = deepspeed.init_inference(
    model=model,  # the Transformers model
    mp_size=world_size,  # number of GPUs used for tensor parallelism
    # dtype=torch.float16,  # dtype of the weights (fp16)
    dtype=torch.int8,  # dtype of the weights (int8)
    replace_method="auto",  # lets DS automatically identify the layers to replace
    replace_with_kernel_inject=True,  # replace the model with the kernel injector
)
print(f"model is loaded on device {ds_model.module.device}")

# Test model
example = "DeepSpeed is a machine learning framework"
input_ids = tokenizer(example, return_tensors="pt").input_ids.to(model.device)
logits = ds_model.generate(input_ids, do_sample=True, max_length=100)
print(f"prediction: \n \n {tokenizer.decode(logits[0].tolist())}")

To Reproduce deepspeed --num_gpu 2 run_deepspeed.py

ds_report output
torch version .................... 1.11.0+cu113
deepspeed info ................... 0.8.2+8be8c012, 8be8c012, master
torch cuda version ............... 11.3
torch hip version ................ None
nvcc version ..................... 11.4
deepspeed wheel compiled w. ...... torch 1.11, cuda 11.3


crazycth commented 1 year ago

I am running into the same problem.

Tracin commented 1 year ago

I am running into the same problem too. I wonder whether we are using DeepSpeed the wrong way (I mean, loading an fp16 checkpoint and setting dtype=int8 in init_inference) or whether these are just bugs in it?

trianxy commented 1 year ago

Hey @Oliver-ss (cc @molly-smith) - were you able to debug this problem?

Similarly to @Tracin, I am also wondering whether we are missing some insight into how we are supposed to quantize models to int8 and run inference on them with DeepSpeed.

Tracin commented 1 year ago

> Hey @Oliver-ss (cc @molly-smith) - were you able to debug this problem?
>
> Similarly to @Tracin, I am also wondering whether we are missing some insight into how we are supposed to quantize models to int8 and run inference on them with DeepSpeed.

I figured it out a few days ago. There are a few bugs in the code: the weight should be transposed before being quantized, and the quantizer group size cannot be too large. Not sure whether they are fixed in a newer version.
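For context, a minimal sketch of what "transpose the weight before quantizing, with a bounded group size" could look like for a QKV weight. This is plain PyTorch, not DeepSpeed's actual quantizer; the function name and the group size of 512 are assumptions for illustration only.

import torch

def quantize_int8_grouped(weight, group_size=512):
    # Symmetric per-group int8 quantization of a 2-D weight.
    # Transpose first so the groups run along the dimension the inference
    # GEMM consumes; keep group_size small so each group gets its own scale.
    w = weight.t().contiguous()                       # [in_features, 3 * hidden]
    flat = w.reshape(-1, group_size)                  # quantization groups
    scales = flat.abs().max(dim=1, keepdim=True).values / 127.0
    scales = scales.clamp(min=1e-8)                   # avoid division by zero
    q = torch.clamp(torch.round(flat / scales), -128, 127).to(torch.int8)
    return q.reshape(w.shape), scales

# example with a bloom-3b-sized QKV weight (3 * 2560 outputs, 2560 inputs)
w = torch.randn(3 * 2560, 2560)
q, scales = quantize_int8_grouped(w)
print(q.shape, scales.shape)  # torch.Size([2560, 7680]) torch.Size([38400, 1])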

trianxy commented 1 year ago

Thanks for the info @Tracin - can you share what % of speed up you saw (at bloom-3b, I guess?) when using int8 instead of fp16?

Tracin commented 1 year ago

@trianxy int8 actually costs more time, since the weights have to be dequantized.

trianxy commented 1 year ago

> @trianxy int8 actually costs more time, since the weights have to be dequantized.

Maybe you are referring to something different? I was thinking about inference speed (milliseconds per generated token) of the quantized model. In that case, as far as I understand, no dequantization is needed.

Tracin commented 1 year ago

@trianxy I mean that, for now, int8 costs more time in DeepSpeed on BLOOM because of dequantization. Besides, I have seen int8 kernel inference with no dequantization on 176B models, which runs at 0.5x - 1.0x the speed of fp16, since quantizing the inputs also costs time.

molly-smith commented 1 year ago

This issue is also being looked into here: https://github.com/microsoft/DeepSpeed/issues/2876

molly-smith commented 1 year ago

https://github.com/microsoft/DeepSpeed/issues/2923