ymcui / Chinese-LLaMA-Alpaca-2

Phase-2 project for Chinese LLaMA-2 & Alpaca-2 large models, plus 64K long-context models (Chinese LLaMA-2 & Alpaca-2 LLMs with 64K long context models)
Apache License 2.0

Error during model inference, not sure where the problem is #407

Closed 459737087 closed 9 months ago

459737087 commented 9 months ago

Pre-submission checklist

Issue type

Model inference

Base model

Chinese-LLaMA-2 (7B/13B)

Operating system

Linux

Detailed description of the problem

Traceback (most recent call last):
  File "/data/code/Chinese-LLaMA-Alpaca-2/scripts/inference/inference_hf.py", line 187, in <module>
    generation_output = model.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1648, in generate
    return self.sample(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2730, in sample
    outputs = self(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 820, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 708, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 424, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/data/code/Chinese-LLaMA-Alpaca-2/scripts/attn_and_long_ctx_patches.py", line 46, in xformers_forward
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[1, 41, 64, 128]' is invalid for input of size 41984
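For reference, the failing shape can be decoded arithmetically: the tensor being reshaped holds 41984 = 1 × 41 × 1024 elements, and 1024 = 8 × 128. So the `k_proj` output has only 8 key/value head-groups rather than the 64 query heads the patch's `view` assumes — the signature of grouped-query attention (GQA), which the LLaMA-2 70B architecture uses (`num_key_value_heads = 8`) while 7B/13B do not. A minimal reproduction under those assumed dimensions:

```python
import torch

# Dimensions taken from the traceback: target view [1, 41, 64, 128],
# actual tensor size 41984.
bsz, q_len, head_dim = 1, 41, 128
kv_dim = 41984 // (bsz * q_len)        # 1024: width of the k_proj output
num_kv_heads = kv_dim // head_dim      # 8 KV head-groups (GQA), not 64

k = torch.zeros(bsz, q_len, kv_dim)
k_gqa = k.view(bsz, q_len, num_kv_heads, head_dim)  # succeeds
try:
    k.view(bsz, q_len, 64, head_dim)                # reproduces the RuntimeError
except RuntimeError as e:
    print(e)
```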

Dependencies (must be provided for code-related issues)

absl-py                 1.4.0
accelerate              0.21.0
aiohttp                 3.8.5
aiosignal               1.3.1
anyio                   3.7.1
appdirs                 1.4.4
argcomplete             3.1.1
arrow                   1.2.3
async-timeout           4.0.2
attrs                   23.1.0
backoff                 2.2.1
beautifulsoup4          4.12.2
bitsandbytes            0.41.0
blinker                 1.4
cachetools              5.3.1
certifi                 2022.12.7
charset-normalizer      2.1.1
click                   8.1.6
collie-lm               1.0.3
coolname                2.2.0
cryptography            3.4.8
datasets                2.14.0
dbus-python             1.2.18
decorator               5.1.1
deepspeed               0.10.1+unknown
dill                    0.3.7
distro                  1.7.0
docker                  6.1.3
docker-pycreds          0.4.0
einops                  0.6.1
exceptiongroup          1.1.2
fastapi                 0.100.1
filelock                3.9.0
flash-attn              2.0.4
frozenlist              1.4.0
fsspec                  2023.4.0
gitdb                   4.0.10
GitPython               3.1.32
google                  3.0.0
google-auth             2.22.0
google-auth-oauthlib    1.0.0
gql                     3.4.1
graphql-core            3.2.3
grpcio                  1.56.2
hjson                   3.1.0
httplib2                0.20.2
huggingface-hub         0.16.4
idna                    3.4
importlib-metadata      6.8.0
jeepney                 0.7.1
Jinja2                  3.1.2
keyring                 23.5.0
launchpadlib            1.10.16
lazr.restfulclient      0.14.4
lazr.uri                1.0.6
lightning-utilities     0.9.0
Markdown                3.4.4
markdown-it-py          3.0.0
MarkupSafe              2.1.2
mdurl                   0.1.2
megatron-core           0.2.0
more-itertools          8.10.0
mosaicml                0.15.1
mosaicml-cli            0.4.17
mpmath                  1.2.1
multidict               6.0.4
multiprocess            0.70.15
networkx                3.0rc1
ninja                   1.11.1
numpy                   1.24.1
oauthlib                3.2.0
packaging               22.0
pandas                  2.0.3
pathtools               0.1.2
peft                    0.4.0
Pillow                  9.3.0
pip                     22.0.2
prompt-toolkit          3.0.39
protobuf                3.20.1
psutil                  5.9.5
py-cpuinfo              9.0.0
pyarrow                 12.0.1
pyasn1                  0.5.0
pyasn1-modules          0.3.0
pydantic                1.10.12
Pygments                2.15.1
PyGObject               3.42.1
PyJWT                   2.3.0
pyparsing               2.4.7
python-apt              2.4.0+ubuntu1
python-dateutil         2.8.2
pytorch-ranger          0.1.1
pytorch-triton          2.1.0+9e3e10c5ed
pytz                    2023.3
PyYAML                  6.0.1
questionary             1.10.0
regex                   2023.6.3
requests                2.28.1
requests-oauthlib       1.3.1
rich                    13.4.2
rsa                     4.9
ruamel.yaml             0.17.32
ruamel.yaml.clib        0.2.7
safetensors             0.3.1
scipy                   1.11.1
SecretStorage           3.3.1
sentencepiece           0.1.99
sentry-sdk              1.28.1
setproctitle            1.3.2
setuptools              59.6.0
six                     1.16.0
smmap                   5.0.0
sniffio                 1.3.0
soupsieve               2.4.1
starlette               0.27.0
sympy                   1.11.1
tabulate                0.9.0
tensorboard             2.13.0
tensorboard-data-server 0.7.1
tokenizers              0.13.3
torch                   2.1.0.dev20230725+cu121
torch-optimizer         0.3.0
torchaudio              2.1.0.dev20230725+cu121
torchmetrics            1.0.1
torchvision             0.16.0.dev20230725+cu121
tqdm                    4.65.0
transformers            4.31.0
typing_extensions       4.7.1
tzdata                  2023.3
urllib3                 1.26.13
validators              0.20.0
wadllib                 1.3.6
wandb                   0.15.7
wcwidth                 0.2.6
websocket-client        1.6.1
websockets              10.4
Werkzeug                2.3.6
wheel                   0.37.1
xxhash                  3.2.0
yarl                    1.9.2
zipp                    1.0.0

Run logs or screenshots

python3 scripts/inference/inference_hf.py  --base_model llama-70B --with_prompt --interactive

iMountTai commented 9 months ago

Pull the latest code, install the latest dependencies per requirements.txt, and try again.

459737087 commented 9 months ago

I don't think that's the problem, since 7B and 13B both work fine; only 70B errors out. @iMountTai

ymcui commented 9 months ago

This project does not provide a 70B model, so you'll have to debug it yourself. Judging from the last line of the error, you could work around xformers, or switch to flash_attention when launching. If you have enough GPU memory, you can also load the model without any patch (modify the code yourself).
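For anyone hitting the same wall: the xformers patch in `attn_and_long_ctx_patches.py` reshapes keys and values with `self.num_heads`, which holds for 7B/13B but not for a GQA model such as LLaMA-2 70B, where keys/values carry only `num_key_value_heads` (8) head-groups. A sketch of a GQA-aware reshape, with dimensions assumed from the standard 70B config (this roughly mirrors what stock `modeling_llama.py` does in transformers 4.31+):

```python
import torch

# Assumed LLaMA-2 70B attention dimensions: 64 query heads, 8 KV heads,
# head_dim 128 (hidden size 8192).
bsz, q_len, head_dim = 1, 41, 128
num_heads, num_key_value_heads = 64, 8

q_out = torch.zeros(bsz, q_len, num_heads * head_dim)            # q_proj output
kv_out = torch.zeros(bsz, q_len, num_key_value_heads * head_dim) # k_proj / v_proj output

# Queries: the patch's reshape is correct.
q = q_out.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
# Keys/values must be reshaped with num_key_value_heads, not num_heads, under GQA.
k = kv_out.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
```

With this split, each group of 64 / 8 = 8 query heads shares one key/value head, which is what GQA means in practice.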

459737087 commented 9 months ago

It failed. I still don't know why inference doesn't work; I'll have to find another approach.