microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] RuntimeError encountered when generating tokens from a DeepSpeedHybridEngine initialized with 4-bit quantization. #5630

Status: Open. Atry opened this issue 5 months ago.

Atry commented 5 months ago

Describe the bug

Calling deepspeed_hybrid_engine.generate on a DeepSpeedHybridEngine initialized with 4-bit quantization fails with:

RuntimeError: The expanded size of the tensor (2048) must match the existing size (1179648) at non-singleton dimension 1. Target sizes: [2048, 2048]. Tensor sizes: [1179648]
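The sizes in the error message may hint at what goes wrong. The sketch below is only a hypothesis check: it assumes (this is not confirmed from DeepSpeed's source) that the flat 1179648-element tensor holds a [2048, 2048] weight stored as 4-bit values packed into 16-bit (bf16) slots, plus one scale and one minimum per group of 64 elements (asymmetric quantization, matching "symmetric": False in the config). The helper name packed_numel_bf16 is hypothetical.

```python
# Hypothesis check, not DeepSpeed code: does the flat size in the error match a
# 4-bit packed [2048, 2048] weight plus per-group (scale, min) metadata, all
# counted in 16-bit (bf16) elements?


def packed_numel_bf16(rows: int, cols: int, num_bits: int, group_size: int,
                      values_per_group: int = 2) -> int:
    """Number of 16-bit elements under the ASSUMED layout:
    packed payload + `values_per_group` metadata entries per group."""
    total = rows * cols
    if (total * num_bits) % 16 != 0:
        raise ValueError("payload does not fill a whole number of 16-bit slots")
    if total % group_size != 0:
        raise ValueError("weight does not split into whole groups")
    payload = total * num_bits // 16                      # four 4-bit values per 16-bit slot
    metadata = (total // group_size) * values_per_group   # asymmetric: scale + min per group
    return payload + metadata


expected_dense = 2048 * 2048   # Target sizes: [2048, 2048]
observed_flat = 1179648        # Tensor sizes: [1179648]

candidate = packed_numel_bf16(2048, 2048, num_bits=4, group_size=64)
print(expected_dense, candidate, candidate == observed_flat)
```

Under that assumption the count matches exactly (1048576 payload slots plus 65536 groups times 2 metadata values), which would be consistent with generate receiving the still-quantized flat buffer where it expects a dequantized [2048, 2048] weight. The match is suggestive, not proof.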

Log output

See https://gist.github.com/Atry/4ebf4e6208a2a3628f65c85a40f9c49d

To Reproduce

Steps to reproduce the behavior: run the following Python script:

from typing import cast
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from deepspeed.module_inject.containers.llama import LLAMALayerPolicy
from functools import wraps

if not getattr(LLAMALayerPolicy, "is_get_hidden_heads_patched", False):
    # Apply the monkey patch copied from https://github.com/microsoft/DeepSpeed/pull/5624

    @wraps(LLAMALayerPolicy.get_hidden_heads)
    def patched_get_hidden_heads(self: LLAMALayerPolicy) -> tuple[int, int, float, int]:
        client_module = cast(LlamaDecoderLayer, self.client_module)
        hidden_heads = (
            client_module.self_attn.q_proj.in_features,
            client_module.self_attn.num_heads,
            client_module.input_layernorm.variance_epsilon,
            client_module.mlp.gate_proj.out_features,
        )
        return hidden_heads

    LLAMALayerPolicy.get_hidden_heads = patched_get_hidden_heads
    setattr(LLAMALayerPolicy, "is_get_hidden_heads_patched", True)

from os import environ

# Emulate a single-process launch; the deepspeed launcher would normally set these variables.

rank = 0
environ["RANK"] = str(rank)

local_rank = 0
environ["LOCAL_RANK"] = str(local_rank)

world_size = 1
environ["WORLD_SIZE"] = str(world_size)

from deepspeed import DeepSpeedHybridEngine

deepspeed_config = {
    "zero_optimization": {
        "load_from_fp32_weights": False,
        "stage": 3,
        "zero_quantized_weights": True,
        "zero_quantized_nontrainable_weights": True,
    },
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "weight_quantization": {
        "quantized_initialization": {
            "num_bits": 4,
            "group_size": 64,
            "group_dim": 1,
            "symmetric": False,
        }
    },
}

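from transformers.integrations.deepspeed import HfDeepSpeedConfig

# Create HfDeepSpeedConfig (and keep a reference to it) before from_pretrained, so that
# transformers builds the model under ZeRO-3 (deepspeed.zero.Init) using this config.
hf_deepspeed_config = HfDeepSpeedConfig(deepspeed_config)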

import deepspeed.comm

deepspeed.comm.init_distributed(
    dist_backend="nccl",
    rank=rank,
    world_size=world_size,
    auto_mpi_discovery=False,
    init_method="tcp://127.0.0.1:9999",
)

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "kevin009/babyllama-v0.6",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
)

from deepspeed.runtime.config import DeepSpeedConfig

deepspeed_hybrid_engine = DeepSpeedHybridEngine(
    args={},
    model=model,
    config=deepspeed_config,
    config_class=DeepSpeedConfig(deepspeed_config),
)

from transformers import GenerationConfig

with torch.no_grad():
    deepspeed_hybrid_engine.eval()
    print(deepspeed_hybrid_engine.generate(
        torch.tensor([[1]], dtype=torch.int, device=deepspeed_hybrid_engine.device),
        synced_gpus=True,
        generation_config=GenerationConfig(max_new_tokens=20),
    ))
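The script guards its monkey patch with a class attribute so that running it twice (for example, re-executing a notebook cell) does not replace the method again. A minimal, library-free sketch of the same pattern follows; the Policy class, apply_patch helper, and the returned tuple are hypothetical stand-ins, not DeepSpeed's real LLAMALayerPolicy or the actual values for this model.

```python
from functools import wraps


class Policy:
    """Hypothetical stand-in for a third-party class such as LLAMALayerPolicy."""

    def get_hidden_heads(self):
        raise RuntimeError("original implementation (replaced by the patch)")


def apply_patch(cls) -> None:
    """Replace cls.get_hidden_heads once; later calls are no-ops."""
    if getattr(cls, "is_get_hidden_heads_patched", False):
        return  # already patched, e.g. the script or cell ran a second time

    original = cls.get_hidden_heads

    @wraps(original)  # copies __name__/__doc__ and sets __wrapped__ = original
    def patched(self):
        # Dummy values for illustration: (hidden size, heads, epsilon, intermediate size)
        return (64, 4, 1e-6, 128)

    cls.get_hidden_heads = patched
    cls.is_get_hidden_heads_patched = True


apply_patch(Policy)
first = Policy.get_hidden_heads
apply_patch(Policy)  # second call must not replace the method again
print(Policy.get_hidden_heads is first, Policy().get_hidden_heads())
```

Using functools.wraps keeps the patched method's name and docstring, so tracebacks and introspection still report get_hidden_heads, while __wrapped__ leaves the original reachable if it needs to be restored.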

Expected behavior

generate returns the generated token IDs without raising an error.

ds_report output

[2024-06-08 00:59:38,246] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.3
 [WARNING]  using untested triton version (2.3.0), only 1.0.0 is known to be compatible
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.3
 [WARNING]  using untested triton version (2.3.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/nixos/peftai/.venv/lib/python3.11/site-packages/torch']
torch version .................... 2.3.0+cu121
deepspeed install path ........... ['/home/nixos/peftai/.venv/lib/python3.11/site-packages/deepspeed']
deepspeed info ................... 0.14.2, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.2
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
shared memory (/dev/shm) size .... 15.67 GB

Screenshots

Not applicable.

System info

See the ds_report output above: Python 3.11, torch 2.3.0+cu121, CUDA 12.1, DeepSpeed 0.14.2.

Docker context

Not using Docker.

Additional context

Installed Python packages:

accelerate==0.23.0
aiofiles==23.2.1
aiohttp==3.8.6
aiohttp-cors==0.7.0
aiosignal==1.3.13
annotated-types==0.6.0
anyio==4.3.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.0
async-lru==2.0.4
async-timeout==4.0.3
asyncstdlib==3.10.9
attrs==23.1.0
autoawq==0.2.5
autoawq_kernels==0.0.6
autoflake==2.2.1
azure-cli==2.60.0
Babel==2.14.0
backcall==0.2.0
beautifulsoup4==4.12.2
bitsandbytes==0.43.0
black==24.3.0
bleach==6.1.0
cached_classproperty==1.0.1
cachetools==5.3.1
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.0
click==8.1.7
cloudpickle==3.0.0
cmake==3.29.2
colorful==0.5.6
comm==0.1.4
coverage==7.5.1
cryptography==41.0.4
datasets==2.18.0
debugpy==1.8.1
decorator==5.1.1
deepmerge==2.0b0
deepspeed==0.14.2
defusedxml==0.7.1
dill==0.3.8
diskcache==5.6.3
distlib==0.3.8
distro==1.9.0
ecdsa==0.18.0
einops==0.7.0
executing==2.0.0
fastapi==0.110.0
fastjsonschema==2.18.1
filelock==3.12.4
flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
fqdn==1.5.1
frozenlist==1.4.0
fsspec==2023.9.2
google-api-core==2.8.0
google-auth==2.29.0
googleapis-common-protos==1.56.1
gptcache==0.1.42
grpcio==1.63.0
guidance==0.0.64
h11==0.14.0
hiredis==2.2.3
hjson==3.1.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.19.4
idna==3.4
immutables==0.20
iniconfig==2.0.0
interegular==0.3.3
ipykernel==6.25.2
ipython==8.16.1
ipywidgets==8.1.2
isoduration==20.11.0
isort==5.13.2
jaraco.functools==3.9.0
jedi==0.19.1
Jinja2==3.1.2
joblib==1.3.2
json5==0.9.24
jsonpointer==2.4
jsonschema==4.19.1
jsonschema-specifications==2023.7.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.4
jupyter_client==8.4.0
jupyter_core==5.4.0
jupyter_server==2.13.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.5
jupyterlab-pygments==0.2.2
jupyterlab_server==2.25.4
jupyterlab_widgets==3.0.10
lark==1.1.9
lazy-object-proxy==1.10.0
linkify-it-py==2.0.3
llvmlite==0.42.0
lm-format-enforcer==0.9.8
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib-inline==0.1.6
mdit-py-plugins==0.4.1
mdurl==0.1.2
memray==1.12.0
mistune==3.0.2
more-itertools==9.1.0
mpmath==1.3.0
msal==1.24.1
msgpack==1.0.8
multidict==6.0.4
multiprocess==0.70.16
mypy-extensions==1.0.0
nbclient==0.8.0
nbconvert==7.9.2
nbformat==5.9.2
nbval==0.11.0
nest-asyncio==1.5.8
networkx==3.1
ninja==1.11.1.1
nodeenv==1.8.0
notebook==7.1.2
notebook_shim==0.2.4
numba==0.59.1
numpy==1.26.0
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.550.52
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105
openai==1.25.2
opencensus==0.11.4
opencensus-context==0.1.3
outlines==0.0.34
overrides==7.7.0
packaging==23.2
pandas==2.2.1
pandocfilters==1.5.0
parso==0.8.3
pathspec==0.12.1
peft==0.5.0
pexpect==4.8.0
pickleshare==0.7.5
platformdirs==3.11.0
pluggy==1.5.0
poetry==1.8.3
pre_commit==3.7.1
prometheus-fastapi-instrumentator==7.0.0
prometheus_client==0.20.0
prompt-toolkit==3.0.39
protobuf==5.26.0
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
py-cord==2.4.1
py-cpuinfo==9.0.0
py-spy==0.3.14
pyarrow==15.0.2
pyarrow-hotfix==0.6
pyasn1==0.5.0
pyasn1_modules==0.4.0
pycparser==2.21
pydantic==2.7.3
pydantic_core==2.18.4
pyflakes==3.1.0
pyflyby==1.9.2
Pygments==2.16.1
pygtrie==2.5.0
PyJWT==2.8.0
pynvml==11.5.0
pyparsing==3.1.1
pyright==1.1.359
PySide6==6.6.3
PySide6_Addons==6.6.3
PySide6_Essentials==6.6.3
pytest==8.2.0
python-dateutil==2.8.2
python-dotenv==1.0.1
python-jose==3.3.0
python-json-logger==2.0.7
python-ulid==1.1.0
pytz==2024.1
pyxll==5.8.0
pyxll_jupyter==0.5.2
PyYAML==6.0.1
pyzmq==25.1.1
qtconsole==5.5.1
QtPy==2.4.1
ray==2.23.0
redis==4.6.0
redis-om==0.3.1
referencing==0.30.2
regex==2023.10.3
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py==0.10.6
rsa==4.9
safetensors==0.4.2
scipy==1.11.3
Send2Trash==1.8.2
sentencepiece==0.2.0
shiboken6==6.6.3
six==1.16.0
smart-open==7.0.4
sniffio==1.3.1
soupsieve==2.5
stack-data==0.6.3
starlette==0.36.3
sympy==1.12
terminado==0.18.1
textual==0.65.2
tiktoken==0.6.0
tinycss2==1.2.1
tokenizers==0.19.1
toml==0.10.2
torch==2.3.0
tornado==6.3.3
tqdm==4.66.1
traitlets==5.11.2
transformers==4.40.1
triton==2.3.0
typeguard==4.1.5
types-pyOpenSSL==23.2.0.2
types-python-dateutil==2.9.0.20240316
types-redis==4.6.0.7
typing_extensions==4.8.0
tzdata==2024.1
uc-micro-py==1.0.3
uri-template==1.3.0
urllib3==2.0.6
uvicorn==0.29.0
uvloop==0.19.0
virtualenv==20.26.2
vllm==0.4.2
vllm_nccl_cu12==2.18.1.0.4.0
vulnix==1.10.2.dev0
watchfiles==0.21.0
wcwidth==0.2.8
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
websockets==12.0
widgetsnbextension==4.0.10
wrapt==1.16.0
xformers==0.0.26.post1
xxhash==3.4.1
yarl==1.9.2
zstandard==0.22.0
Atry commented 5 months ago

Note that this bug concerns the DeepSpeed Hybrid Engine (HE), not DeepSpeed-Chat. I filed it under the deepspeed-chat label because there is no deepspeed-he label.

Atry commented 5 months ago

Also note that this bug only surfaces once #5398 is fixed, so the reproduction script applies #5624 as a monkey patch.