microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] CUDA error with INT8 inference #1788

Open gsujankumar opened 2 years ago

gsujankumar commented 2 years ago

Describe the bug
I am trying to get started with INT8 inference on DeepSpeed, but I am running into `RuntimeError: CUDA error: an illegal memory access was encountered`.

To Reproduce

Code:

I am interested in implementing INT8 inference with GPT2-style models. The code I am running is the following:

```python
import os
import torch
import deepspeed
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Rank and world size are set by the deepspeed launcher.
local_rank = int(os.getenv('LOCAL_RANK', '-1'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
model = GPT2LMHeadModel.from_pretrained('gpt2-large').to(local_rank)

# Wrap the model with DeepSpeed's INT8 inference engine.
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    dtype=torch.int8,
    replace_method='auto',
    quantization_setting=2,
    replace_with_kernel_inject=True
)

input_ids = tokenizer.encode('Example context for testing ', return_tensors='pt')
input_ids = input_ids.to(local_rank)
outputs = model(input_ids)
print(outputs)
```

I am running this with `CUDA_VISIBLE_DEVICES=1 CUDA_LAUNCH_BLOCKING=1 deepspeed gpt_example.py` (full output under Outputs below).

I noticed a few bugs blocking INT8 inference and made the following changes to the source code:

  1. Around line 132 in deepspeed/runtime/weight_quantizer.py, since `is_mlp` was not defined:

```python
            for key in range(len(keys)):
                #if self.mlp_extra_grouping and is_mlp(keys[key]): # line removed
                if self.mlp_extra_grouping and self.is_mlp(keys[key]) >= 2: # line added
```

  2. Around line 161 in deepspeed/runtime/weight_quantizer.py:

```python
        else:
            for plcy in replace_policies:
                _ = plcy(None) # line added
                policy.update({plcy._orig_layer_class: (quantize_fn, plcy)})
```

  3. At line 282 in deepspeed/ops/transformer/inference/transformer_inference.py, passing the input mask through to compute_attention:

```python
            # context_layer, key_layer, value_layer = compute_attention(qkv_out) # line removed
            context_layer, key_layer, value_layer = compute_attention(qkv_out, input_mask) # line added
```

Expected behavior
Output meaningful logits.

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
sparse_attn ............ [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/torch']
torch version .................... 1.10.2+cu113
torch cuda version ............... 11.3
nvcc version ..................... 11.3
deepspeed install path ........... ['/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.5.9, unknown, unknown
deepspeed wheel compiled w. ...... torch 1.9, cuda 11.1

Outputs
While the code runs error-free with `dtype=torch.float` and `dtype=torch.half`, I run into errors with `dtype=torch.int8`.

Running `CUDA_VISIBLE_DEVICES=1 CUDA_LAUNCH_BLOCKING=1 deepspeed gpt_example.py` results in the following output:

[2022-02-23 16:00:20,231] [WARNING] [runner.py:132:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
Detected CUDA_VISIBLE_DEVICES=1: setting --include=localhost:1
[2022-02-23 16:00:20,308] [INFO] [runner.py:398:main] cmd = /home/ec2-user/anaconda3/envs/pretrain-vector/bin/python3.8 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMV19 --master_addr=127.0.0.1 --master_port=29500 gpt_example.py
[2022-02-23 16:00:21,307] [INFO] [launch.py:80:main] WORLD INFO DICT: {'localhost': [1]}
[2022-02-23 16:00:21,308] [INFO] [launch.py:86:main] nnodes=1, num_local_procs=1, node_rank=0
[2022-02-23 16:00:21,308] [INFO] [launch.py:99:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
[2022-02-23 16:00:21,308] [INFO] [launch.py:100:main] dist_world_size=1
[2022-02-23 16:00:21,308] [INFO] [launch.py:102:main] Setting CUDA_VISIBLE_DEVICES=1
[2022-02-23 16:00:35,424] [INFO] [logging.py:69:log_dist] [Rank -1] DeepSpeed info: version=0.5.9, git-hash=unknown, git-branch=unknown
[2022-02-23 16:00:35,424] [INFO] [engine.py:127:_init_quantization_setting] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 2
Using /home/ec2-user/.cache/torch_extensions/py38_cu113 as PyTorch extensions root...
/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/torch/utils/cpp_extension.py:295: UserWarning: 

                               !! WARNING !!

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler (c++) is not compatible with the compiler Pytorch was
built with for this platform, which is g++ on linux. Please
use g++ to to compile your extension. Alternatively, you may
compile PyTorch from source using c++, and then you can also use
c++ to compile your extension.

See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
with compiling PyTorch from source.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

                              !! WARNING !!

  warnings.warn(WRONG_COMPILER_WARNING.format(
Detected CUDA files, patching ldflags
Emitting ninja build file /home/ec2-user/.cache/torch_extensions/py38_cu113/transformer_inference/build.ninja...
Building extension module transformer_inference...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module transformer_inference...
Time to load transformer_inference op: 0.3160672187805176 seconds
DeepSpeed Transformer Inference config is  {'layer_id': 0, 'hidden_size': 1280, 'intermediate_size': 5120, 'heads': 20, 'num_hidden_layers': -1, 'fp16': False, 'pre_layer_norm': True, 'local_rank': -1, 'stochastic_mode': False, 'epsilon': 1e-12, 'mp_size': 1, 'q_int8': True, 'scale_attention': True, 'specialized_mode': False, 'triangular_masking': True, 'local_attention': False, 'window_size': 1, 'return_tuple': True}
[... the identical config line repeats for layer_id 1 through 35; only 'layer_id' changes ...]
[2022-02-23 16:00:36,481] [INFO] [engine.py:91:__init__] Place model to device: 0

!!!! kernel execution error. (batch: 20, m: 64, n: 5, k: 5, error: 13) 
Traceback (most recent call last):
  File "DeepSpeedExperiment.py", line 24, in <module>
    outputs = model(input_ids)
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/deepspeed/inference/engine.py", line 246, in forward
    outputs = self.module(*inputs, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1120, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1047, in forward
    transformer_outputs = self.transformer(
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 890, in forward
    outputs = block(
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py", line 587, in forward
    attention_output = self.attention(input,
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py", line 355, in forward
    output = DeepSpeedSelfAttentionFunction.apply(
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py", line 291, in forward
    output, key_layer, value_layer, context_layer = selfAttention_int8()
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py", line 282, in selfAttention_int8
    context_layer, key_layer, value_layer = compute_attention(qkv_out, input_mask)
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py", line 236, in compute_attention
    context_layer = _transpose_for_context(context_layer)
  File "/home/ec2-user/anaconda3/envs/pretrain-vector/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py", line 144, in _transpose_for_context
    x = x.permute(0, 2, 1, 3).contiguous()
RuntimeError: CUDA error: an illegal memory access was encountered
[2022-02-23 16:00:37,334] [INFO] [launch.py:131:sigkill_handler] Killing subprocess 52992
[2022-02-23 16:00:37,334] [ERROR] [launch.py:137:sigkill_handler] ['/home/ec2-user/anaconda3/envs/pretrain-vector/bin/python3.8', '-u', 'DeepSpeedExperiment.py', '--local_rank=0'] exits with return code = 1

Launcher context
Using the deepspeed launcher.

RezaYazdaniAminabadi commented 2 years ago

Hi @gsujankumar

Thanks for pointing out this issue. I will look into it and send a fix soon.

Best, Reza

gsujankumar commented 2 years ago

Hey Reza,

Thanks for looking into this. Can you give me a rough ETA for the fix?

RezaYazdaniAminabadi commented 2 years ago

Sorry for the delay, @gsujankumar. I will try to make it work by the end of this week or early next week.

gsujankumar commented 2 years ago

@RezaYazdaniAminabadi Did you find what is causing the issues?

I tried to implement INT8 inference with an MoQ-trained BERT model. I noticed that INT8 inference did not seem to work outside of deepspeed>=0.4.0,<0.4.3 and transformers==5.5.2; with versions beyond these I run into either errors or accuracy issues.

RezaYazdaniAminabadi commented 2 years ago

Hi @gsujankumar,

Sorry for the delay on this. I was busy with some internal projects. Yes, you are right that MoQ was mainly targeted at the older version of transformers. The error you are seeing above comes from the GeMM, and it may also be related to the quantization that happens before this operation. I will let you know as soon as this is fixed.

Thanks, Reza

gsujankumar commented 2 years ago

Hey @RezaYazdaniAminabadi, I was able to resolve the issue by casting the FP16 inputs to FP32 in the compute_attention method in deepspeed/ops/transformer/inference/transformer_inference.py, as follows:

Around line 154, add the following lines:

```python
            if config.q_int8:
                qkv_out = qkv_out.float()
```

and cast the outputs of the function back to FP16 before returning by adding the following lines:

```python
            if config.q_int8:
                context_layer = context_layer.half()
                key_layer = key_layer.half()
                value_layer = value_layer.half()

            return context_layer, key_layer, value_layer
```

It looks like the method was using FP32 kernels but still passing them FP16 data; computing with FP32 inputs resolved the NaNs, which in turn resolved the CUDA errors.
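For reference, here is a minimal sketch of how the two casts sit inside compute_attention. This is not the full method: the attention math is elided, and the signature and config.q_int8 flag are assumed from the snippets and traceback above.

```python
def compute_attention(qkv_out, input_mask):
    if config.q_int8:
        # Upcast: the kernels used on this path compute in FP32, so feeding
        # them FP16 data produced the NaNs and the illegal memory access.
        qkv_out = qkv_out.float()

    # ... existing attention computation, producing
    # context_layer, key_layer, value_layer ...

    if config.q_int8:
        # Downcast so downstream kernels see FP16 again.
        context_layer = context_layer.half()
        key_layer = key_layer.half()
        value_layer = value_layer.half()

    return context_layer, key_layer, value_layer
```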

I am noticing good accuracy with quantization_setting=1, but with any higher group count the accuracy drops. This is counterintuitive (a toy illustration is below). Is there something that might be amiss?
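For reference, a toy round trip (not DeepSpeed code; the helper below is purely illustrative) of symmetric per-group INT8 quantization, showing why a finer per-group scale is expected to lower reconstruction error rather than raise it:

```python
import torch

def int8_round_trip(w: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Symmetric per-group INT8 quantize/dequantize; each group gets its own scale."""
    groups = w.reshape(num_groups, -1)
    scale = groups.abs().max(dim=1, keepdim=True).values / 127.0
    q = torch.clamp((groups / scale).round(), -127, 127)
    return (q * scale).reshape_as(w)

w = torch.randn(1280, 1280)
for g in (1, 2, 4, 8):
    err = (w - int8_round_trip(w, g)).abs().mean().item()
    print(f"groups={g}: mean abs reconstruction error = {err:.6f}")
# The error should shrink (or at worst stay flat) as the group count grows,
# since each scale then has to cover a narrower range of values.
```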

RezaYazdaniAminabadi commented 2 years ago

Hi @gsujankumar,

I am happy you could resolve the issue. Can you please make a PR with the fix? For a larger number of groups, yes, you are right that you should get better accuracy, not worse. Are you using our quantizer kernels when quantizing this model? If so, I can double-check whether there is an accuracy issue with them.

Thanks, Reza

gsujankumar commented 2 years ago

Sure, I will create a PR soon.

Yes, we are using the quantizer kernels from DeepSpeed. Can you check if there are any issues with groups?