QwenLM / Qwen-VL

The official repo of Qwen-VL (通义千问-VL) chat & pretrained large vision language model proposed by Alibaba Cloud.

[BUG] TypeError: isin() received an invalid combination of arguments. #406

Open fang-d opened 1 month ago

fang-d commented 1 month ago

Is there an existing issue / discussion for this?

Is there an existing answer for this in the FAQ?

Current Behavior

Any input in the WebUI (e.g. "Hello!") makes the program raise an error at runtime and produce no output. The error log is shown below.

(Qwen) fd@fd:~/makeover/Qwen-VL$ python web_demo_mm.py --checkpoint-path ~/.cache/huggingface/hub/models--Qwen--Qwen-VL-Chat/snapshots/f57cfbd358cb56b710d963669ad1bcfb44cdcdd8/
2024-06-03 02:46:51,324 - modelscope - INFO - PyTorch version 2.3.0 Found.
2024-06-03 02:46:51,324 - modelscope - INFO - Loading ast index from /home/fd/.cache/modelscope/ast_indexer
2024-06-03 02:46:51,335 - modelscope - INFO - Loading done! Current index file version is 1.14.0, with md5 44b8eafcb244e1a7318e2afee9a18c75 and a total number of 976 components indexed
The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Loading checkpoint shards: 100%|████████████████████████████████| 10/10 [00:02<00:00,  3.94it/s]
Running on local URL:  http://127.0.0.1:8000

To create a public link, set `share=True` in `launch()`.
User: Hello!
Traceback (most recent call last):
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/gradio/queueing.py", line 521, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/gradio/route_utils.py", line 276, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/gradio/blocks.py", line 1945, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/gradio/blocks.py", line 1525, in call_function
    prediction = await utils.async_iteration(iterator)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/gradio/utils.py", line 655, in async_iteration
    return await iterator.__anext__()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/gradio/utils.py", line 648, in __anext__
    return await anyio.to_thread.run_sync(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2144, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 851, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/gradio/utils.py", line 631, in run_sync_iterator_async
    return next(iterator)
           ^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/gradio/utils.py", line 814, in gen_wrapper
    response = next(iterator)
               ^^^^^^^^^^^^^^
  File "/home/fd/makeover/Qwen-VL/web_demo_mm.py", line 131, in predict
    for response in model.chat_stream(tokenizer, message, history=history):
  File "/home/fd/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 1021, in stream_generator
    for token in self.generate_stream(
                 ^^^^^^^^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/transformers_stream_generator/main.py", line 208, in generate
    ] = self._prepare_attention_mask_for_generation(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fd/.miniconda/envs/Qwen/lib/python3.11/site-packages/transformers/generation/utils.py", line 473, in _prepare_attention_mask_for_generation
    torch.isin(elements=inputs, test_elements=pad_token_id).any()
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: isin() received an invalid combination of arguments - got (test_elements=int, elements=Tensor, ), but expected one of:
 * (Tensor elements, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
 * (Number element, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
 * (Tensor elements, Number test_element, *, bool assume_unique, bool invert, Tensor out)
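The last frame of the traceback can be reproduced outside the demo. Below is a minimal sketch, assuming PyTorch is installed; `probe_isin` is a hypothetical helper, not part of Qwen-VL or transformers. It calls `torch.isin` with the same keyword arguments as the failing line, once with a plain int and once with a tensor.

```python
# Sketch only: mirrors the keyword call shown in the traceback.
try:
    import torch
except ImportError:  # keep the sketch importable without PyTorch
    torch = None


def probe_isin(pad_token_id=0):
    """Return (int_keyword_ok, tensor_keyword_ok), or None if torch is missing."""
    if torch is None:
        return None
    inputs = torch.tensor([[1, 2, 0]])

    # Same call shape as transformers/generation/utils.py in the traceback:
    # keyword `test_elements` with a plain Python int.
    try:
        torch.isin(elements=inputs, test_elements=pad_token_id)
        int_keyword_ok = True
    except TypeError:
        int_keyword_ok = False

    # Same keywords, but with the pad token id wrapped in a tensor.
    tensor_keyword_ok = bool(
        torch.isin(elements=inputs, test_elements=torch.tensor(pad_token_id)).any()
    )
    return int_keyword_ok, tensor_keyword_ok


print(probe_isin())
```

With torch 2.3.0 as in this report, the int call should fail as in the traceback: the overload that accepts a Number names its parameter `test_element`, while the call passes `test_elements`.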


Expected Behavior

The WebUI should return a normal response without raising this error.

Steps To Reproduce

Install the environment with conda. The environment.yaml file I put together is listed below:

name: Qwen

channels:
  - pytorch
  - nvidia
  - conda-forge

dependencies:
  - python=3.11

  # PyTorch
  - pytorch>=2.3
  - torchvision
  - pytorch-cuda>=12.1

  # Other Packages
  - accelerate
  - transformers>=4.32
  - tiktoken
  - einops
  - scipy
  - matplotlib
  - tqdm

  # [Optional] Int-8 Quantization
  - optimum

  # [Optional] Web Demo
  - gradio

  # [Optional] OpenAI API
  - openai
  - fastapi
  - uvicorn
  - pydantic

  # PyPi Packages
  - pip
  - pip:
    - transformers_stream_generator>=0.0.4
    - sse_starlette   # [Optional] OpenAI API
    - huggingface-cli # [Optional] Hugging Face CLI
    - auto-gptq       # [Optional] Int-8 Quantization
    - modelscope

  # [Optional] Code Quality
  - black
  - mypy

Run web_demo_mm.py and send any message.

Environment

The other installed packages are listed below:

Package                       Version      |  Package                       Version
----------------------------- -----------  |  ----------------------------- -----------
accelerate                    0.30.1       |  nvitop                        1.3.2
addict                        2.4.0        |  openai                        1.30.5
aiofiles                      23.2.1       |  optimum                       1.12.0
aiohttp                       3.9.5        |  orjson                        3.10.3
aiosignal                     1.3.1        |  oss2                          2.18.5
aliyun-python-sdk-core        2.15.1       |  packaging                     24.0
aliyun-python-sdk-kms         2.16.3       |  pandas                        2.2.2
altair                        5.3.0        |  parso                         0.8.3
annotated-types               0.7.0        |  pathspec                      0.12.1
anyio                         4.3.0        |  peft                          0.11.1
asttokens                     2.0.5        |  pexpect                       4.8.0
attrs                         23.2.0       |  pillow                        10.3.0
auto_gptq                     0.7.1        |  Pillow-SIMD                   9.0.0.post1
black                         24.4.2       |  pip                           24.0
Brotli                        1.1.0        |  pkgutil_resolve_name          1.3.10
cachetools                    5.3.3        |  platformdirs                  4.2.2
certifi                       2024.2.2     |  ply                           3.11
cffi                          1.16.0       |  prompt-toolkit                3.0.43
charset-normalizer            3.3.2        |  protobuf                      4.25.3
click                         8.1.7        |  psutil                        5.9.8
colorama                      0.4.6        |  ptyprocess                    0.7.0
coloredlogs                   15.0.1       |  pure-eval                     0.2.2
comm                          0.2.1        |  pyarrow                       16.1.0
contourpy                     1.2.1        |  pyarrow-hotfix                0.6
crcmod                        1.7          |  pycparser                     2.22
cryptography                  42.0.7       |  pycryptodome                  3.20.0
cycler                        0.12.1       |  pydantic                      2.7.2
datasets                      2.18.0       |  pydantic_core                 2.18.3
debugpy                       1.6.7        |  pydub                         0.25.1
decorator                     5.1.1        |  Pygments                      2.18.0
dill                          0.3.8        |  pyparsing                     3.1.2
distro                        1.9.0        |  PyQt5                         5.15.9
dnspython                     2.6.1        |  PyQt5-sip                     12.12.2
einops                        0.8.0        |  PySocks                       1.7.1
email_validator               2.1.1        |  python-dateutil               2.9.0
exceptiongroup                1.2.0        |  python-multipart              0.0.9
executing                     0.8.3        |  pytz                          2024.1
fastapi                       0.111.0      |  PyYAML                        6.0.1
fastapi-cli                   0.0.4        |  pyzmq                         25.1.2
ffmpy                         0.3.0        |  referencing                   0.35.1
filelock                      3.14.0       |  regex                         2024.5.15
fonttools                     4.53.0       |  requests                      2.32.3
frozenlist                    1.4.1        |  rich                          13.7.1
fsspec                        2024.2.0     |  rouge                         1.0.1
gast                          0.5.4        |  rpds-py                       0.18.1
gekko                         1.1.1        |  ruff                          0.4.7
gmpy2                         2.1.5        |  safetensors                   0.4.3
gradio                        4.32.2       |  scipy                         1.13.1
gradio_client                 0.17.0       |  semantic-version              2.10.0
h11                           0.14.0       |  sentencepiece                 0.1.99
h2                            4.1.0        |  setuptools                    70.0.0
hpack                         4.0.0        |  shellingham                   1.5.4
httpcore                      1.0.5        |  simplejson                    3.19.2
httpx                         0.27.0       |  sip                           6.7.12
huggingface-cli               0.1          |  six                           1.16.0
huggingface_hub               0.23.2       |  sniffio                       1.3.1
humanfriendly                 10.0         |  sortedcontainers              2.4.0
hyperframe                    6.0.1        |  sse-starlette                 2.1.0
idna                          3.7          |  stack-data                    0.2.0
importlib_metadata            7.1.0        |  starlette                     0.37.2
importlib_resources           6.4.0        |  sympy                         1.12
ipykernel                     6.28.0       |  termcolor                     2.4.0
ipython                       8.20.0       |  tiktoken                      0.7.0
ipywidgets                    8.1.2        |  tokenizers                    0.19.1
jedi                          0.18.1       |  toml                          0.10.2
Jinja2                        3.1.4        |  tomli                         2.0.1
jmespath                      0.10.0       |  tomlkit                       0.12.0
jsonschema                    4.22.0       |  toolz                         0.12.1
jsonschema-specifications     2023.12.1    |  torch                         2.3.0
jupyter_client                8.6.0        |  torchvision                   0.18.0
jupyter_core                  5.5.0        |  tornado                       6.4
jupyterlab-widgets            3.0.10       |  tqdm                          4.66.4
kiwisolver                    1.4.5        |  traitlets                     5.7.1
latex2mathml                  3.77.0       |  transformers                  4.41.2
Markdown                      3.6          |  transformers-stream-generator 0.0.5
markdown-it-py                3.0.0        |  triton                        2.3.0
MarkupSafe                    2.1.5        |  typer                         0.12.3
matplotlib                    3.8.4        |  typer-slim                    0.12.3
matplotlib-inline             0.1.6        |  typing_extensions             4.12.1
mdtex2html                    1.3.0        |  tzdata                        2024.1
mdurl                         0.1.2        |  ujson                         5.10.0
modelscope                    1.14.0       |  urllib3                       2.2.1
mpmath                        1.3.0        |  uvicorn                       0.30.0
multidict                     6.0.5        |  wcwidth                       0.2.5
multiprocess                  0.70.16      |  websockets                    11.0.3
munkres                       1.1.4        |  wheel                         0.43.0
mypy                          1.10.0       |  widgetsnbextension            4.0.10
mypy-extensions               1.0.0        |  xxhash                        3.4.1
nest-asyncio                  1.6.0        |  yapf                          0.40.2
networkx                      3.3          |  yarl                          1.9.4
numpy                         1.26.4       |  zipp                          3.17.0
nvidia-ml-py                  12.535.161   |
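Only a few of the packages above appear in the traceback. As a convenience for anyone comparing environments, here is a small sketch that prints just those versions using the standard library; the `PACKAGES` tuple and the `installed_versions` helper are my own names, and the selection of packages is an assumption based on the frames in the traceback.

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions whose code shows up in the traceback (an assumption, adjust as needed).
PACKAGES = ("torch", "transformers", "transformers-stream-generator", "gradio")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:32s}{ver or 'not installed'}")
```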

Preliminary Solution

Adding the lines marked with + below to web_demo_mm.py works around the problem for now. If the community considers this approach acceptable, I am happy to open a pull request.

def _load_model_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True, revision='master',
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "cuda"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
        revision='master',
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True, revision='master',
    )

+   if model.generation_config.pad_token_id is not None:
+       model.generation_config.pad_token_id = torch.tensor(
+           [model.generation_config.pad_token_id], device=model.device
+       )
+   if model.generation_config.eos_token_id is not None:
+       model.generation_config.eos_token_id = torch.tensor(
+           [model.generation_config.eos_token_id], device=model.device
+       )

    return model, tokenizer
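The same idea can be written as a small reusable helper. This is a sketch under assumptions, not the official fix: `wrap_special_token_ids` and `make_tensor` are hypothetical names, and the token ids in the demo are placeholders. Unlike the patch above, it also handles an `eos_token_id` that is already a list, and it leaves values alone if they were converted earlier.

```python
from types import SimpleNamespace


def wrap_special_token_ids(generation_config, make_tensor):
    """Replace int / list pad and eos token ids with make_tensor(list_of_ids).

    make_tensor is injected so the helper itself does not depend on torch;
    in the web demo it would build a torch tensor on the model's device.
    """
    for name in ("pad_token_id", "eos_token_id"):
        value = getattr(generation_config, name, None)
        if isinstance(value, int):
            value = [value]
        if not isinstance(value, list):
            # None, or already converted by an earlier call: leave untouched.
            continue
        setattr(generation_config, name, make_tensor(value))
    return generation_config


# Demo with a stand-in config and `tuple` as a stand-in tensor factory.
# The ids below are placeholders, not Qwen-VL's real special token ids.
demo_config = SimpleNamespace(pad_token_id=0, eos_token_id=[1, 2])
wrap_special_token_ids(demo_config, make_tensor=tuple)
print(demo_config)
```

In web_demo_mm.py the call would look like `wrap_special_token_ids(model.generation_config, lambda ids: torch.tensor(ids, device=model.device))`, which assumes `torch` is imported there. I have not checked whether other code paths still expect plain ints in the generation config.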