microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] bs=1 mp_size=8 OPT/LLaMA inference CUDA error: an illegal memory access was encountered #3758

Closed: chhzh123 closed this issue 1 year ago.

chhzh123 commented 1 year ago

Describe the bug

It seems the latest master branch has problems with the inference kernels when running on multiple GPUs. I tested several models, including OPT, LLaMA, and Vicuna, and none of them runs correctly with batch size = 1.

```text
File "/opt/conda/envs/pt20/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    "args": f"{args}, {kwargs}",
  File "/opt/conda/envs/pt20/lib/python3.9/site-packages/torch/_tensor.py", line 426, in __repr__
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/opt/conda/envs/pt20/lib/python3.9/site-packages/torch/_tensor_str.py", line 567, in _str_intern
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/DeepSpeed/deepspeed/model_implementations/transformers/ds_transformer.py", line 157, in forward
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
    layer_outputs = decoder_layer(      File "/opt/conda/envs/pt20/lib/python3.9/site-packages/torch/_tensor_str.py", line 636, in _str

return forward_call(*args, **kwargs)  File "/opt/conda/envs/pt20/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl

  File "/home/ubuntu/DeepSpeed/deepspeed/ops/transformer/inference/ds_attention.py", line 166, in forward
    self.attention(input,
  File "/opt/conda/envs/pt20/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    tensor_str = _tensor_str(self, indent)
  File "/opt/conda/envs/pt20/lib/python3.9/site-packages/torch/_tensor_str.py", line 309, in _tensor_str
        return forward_call(*args, **kwargs)return forward_call(*args, **kwargs)

      File "/home/ubuntu/DeepSpeed/deepspeed/ops/transformer/inference/ds_attention.py", line 166, in forward
  File "/home/ubuntu/DeepSpeed/deepspeed/ops/transformer/inference/ds_attention.py", line 166, in forward
dist.all_reduce(output, group=self.mp_group)
  File "/home/ubuntu/DeepSpeed/deepspeed/comm/comm.py", line 116, in log_wrapper
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/DeepSpeed/deepspeed/model_implementations/transformers/ds_transformer.py", line 157, in forward
        return _str_intern(self, tensor_contents=tensor_contents)        self = self.float()dist.all_reduce(output, group=self.mp_group)
dist.all_reduce(output, group=self.mp_group)    

  File "/opt/conda/envs/pt20/lib/python3.9/site-packages/torch/_tensor_str.py", line 567, in _str_intern

return func(*args, **kwargs)RuntimeError
  File "/home/ubuntu/DeepSpeed/deepspeed/comm/comm.py", line 116, in log_wrapper
      File "/home/ubuntu/DeepSpeed/deepspeed/comm/comm.py", line 116, in log_wrapper
:   File "/home/ubuntu/DeepSpeed/deepspeed/comm/comm.py", line 480, in all_reduce
self.attention(input,CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

To Reproduce

Steps to reproduce the behavior:

1. Simple inference script to reproduce:

```python
import os

import torch
import torch.distributed as dist
import deepspeed
from transformers import OPTModel, AutoConfig

dist.init_process_group("nccl", world_size=int(os.environ["WORLD_SIZE"]))

config = AutoConfig.from_pretrained("facebook/opt-1.3b")
config.use_cache = False
mod = OPTModel(config)
mod.eval()
mod.to(torch.float16)
bs, seq_len = 1, 2048
input_ids = torch.ones(
    bs, seq_len, dtype=torch.long, device=f"cuda:{dist.get_rank()}"
)

# Initialize the DeepSpeed-Inference engine
ds_engine = deepspeed.init_inference(
    mod,
    mp_size=dist.get_world_size(),
    dtype=torch.float16,
    checkpoint=None,
    replace_with_kernel_inject=True,
)
mod = ds_engine.module
mod(input_ids)
```


2. What packages are required and their versions:

```text
PyTorch: 2.0.1
Transformers: 4.28.1
DeepSpeed: 3f5e49
```

3. How to run the script:

```bash
deepspeed --num_gpus 8 test_ds_bug.py
```

Expected behavior

The script should run without errors.

ds_report output

```text
[2023-06-15 22:47:46,589] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-06-15 22:47:46,963] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/conda/envs/pt20/lib/python3.9/site-packages/torch']
torch version .................... 2.0.1+cu117
deepspeed install path ........... ['/home/ubuntu/DeepSpeed/deepspeed']
deepspeed info ................... 0.9.4+4ebc22fb, 4ebc22fb, pipe
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 11.7
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.7
```


chhzh123 commented 1 year ago

Actually, it seems the problem comes from long sequence lengths. After I changed the sequence length to 1024, the script ran without errors.

RezaYazdaniAminabadi commented 1 year ago

> Actually, it seems the problem comes from long sequence lengths. After I changed the sequence length to 1024, the script ran without errors.

Hi @chhzh123, yes: 1024 is the default `max_out_tokens`, which is the number of tokens we reserve for the KV-cache. If you want to produce more tokens, you need to increase it, which you can do simply by passing `max_out_tokens=2048` to the `init_inference` call. Here is the config for this parameter: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/inference/config.py#L246

Thanks, Reza
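Applied to the reproduction script, the fix is one extra keyword argument to `deepspeed.init_inference`. The sketch below gathers the script's keyword arguments in a small helper so that the KV-cache budget is never smaller than the input sequence length. The helper name `ds_inference_kwargs` is hypothetical, and sizing `max_out_tokens` as `max(1024, seq_len)` is an assumption drawn from this thread (default of 1024, failure at 2048), not a documented DeepSpeed rule.

```python
# Hypothetical helper (not part of DeepSpeed): builds the keyword arguments
# used in the reproduction script, plus a max_out_tokens value that covers
# the input. Assumption: the KV-cache reserved via max_out_tokens must hold
# at least seq_len tokens, and the library default is 1024.
DEFAULT_MAX_OUT_TOKENS = 1024


def ds_inference_kwargs(world_size: int, seq_len: int) -> dict:
    """Return keyword arguments intended for deepspeed.init_inference."""
    return {
        "mp_size": world_size,
        "checkpoint": None,
        "replace_with_kernel_inject": True,
        # Keep the default when it suffices; grow it for longer inputs.
        "max_out_tokens": max(DEFAULT_MAX_OUT_TOKENS, seq_len),
    }


if __name__ == "__main__":
    print(ds_inference_kwargs(world_size=8, seq_len=2048))
```

In the script, the call would then read `deepspeed.init_inference(mod, dtype=torch.float16, **ds_inference_kwargs(dist.get_world_size(), seq_len))`. `dtype` is kept out of the helper only so that the sketch does not need `torch`.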

chhzh123 commented 1 year ago

@RezaYazdaniAminabadi Thanks for your quick reply! It worked after I set `max_out_tokens` to 2048, but I think it would be better to raise an error in this case. Otherwise, users have no way of knowing what the actual issue is.

Also, what would be the correct way to load large models for inference? I tried the following code, but it failed when initializing the inference engine.

```python
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    mod = OPTModel(config)
```

```text
  File "/home/ubuntu/DeepSpeed/deepspeed/module_inject/replace_module.py", line 500, in _replace_module
    _replace_module(child, name, class_name)
  File "/home/ubuntu/DeepSpeed/deepspeed/module_inject/replace_module.py", line 496, in _replace_module
    setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
  File "/home/ubuntu/DeepSpeed/deepspeed/module_inject/replace_module.py", line 401, in _replace
    data = mp_replace.copy(new_weight, child.weight.data)
  File "/home/ubuntu/DeepSpeed/deepspeed/module_inject/replace_module.py", line 96, in copy
    assert not dst.data.is_meta  # the torch.Tensor.copy_ method used below will silently fail on meta tensors
AssertionError
```

RezaYazdaniAminabadi commented 1 year ago

Hi @chhzh123,

Thanks for the suggestion. Yes, we need to make this error more informative. Also, for loading large models, I suggest using this inference script from DeepSpeedExamples (DSE): https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/text-generation/inference-test.py

Please let me know whether you can load your model with it.

Thanks, Reza
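Until the library reports an oversized input itself, a caller can fail fast before the inference kernels run. The check below is a minimal sketch, not DeepSpeed code: `check_kv_cache_budget` is a hypothetical name, and it assumes (as the 1024 versus 2048 results in this thread suggest) that an input sequence longer than `max_out_tokens` is what overruns the reserved KV-cache.

```python
# Hypothetical user-side guard, not part of DeepSpeed. Assumption: an input
# sequence longer than max_out_tokens overruns the KV-cache DeepSpeed
# reserves, which surfaced as "illegal memory access" in this issue.
from typing import Sequence

DEFAULT_MAX_OUT_TOKENS = 1024  # default value discussed in this thread


def check_kv_cache_budget(input_shape: Sequence[int],
                          max_out_tokens: int = DEFAULT_MAX_OUT_TOKENS) -> None:
    """Raise ValueError if a (batch, seq_len) input exceeds max_out_tokens."""
    if len(input_shape) != 2:
        raise ValueError(f"expected a (batch, seq_len) shape, got {tuple(input_shape)}")
    seq_len = input_shape[1]
    if seq_len > max_out_tokens:
        raise ValueError(
            f"sequence length {seq_len} exceeds max_out_tokens={max_out_tokens}; "
            f"pass max_out_tokens>={seq_len} to deepspeed.init_inference"
        )


if __name__ == "__main__":
    check_kv_cache_budget((1, 1024))  # fits the default budget, no error
    try:
        check_kv_cache_budget((1, 2048))
    except ValueError as err:
        print(err)
```

In the reproduction script it would be called as `check_kv_cache_budget(input_ids.shape)` just before `mod(input_ids)`; this works because `torch.Size` is a subclass of `tuple`.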