Efficient-Large-Model / VILA

VILA - a multi-image visual language model with training, inference and evaluation recipe, deployable from cloud to edge (Jetson Orin and laptops)
Apache License 2.0

Llama-3-VILA1.5-8B Inference error #39

Open joebradly opened 1 month ago

joebradly commented 1 month ago

Hello! Thanks for sharing such a nice project. I have set up the environment following the instructions in the README. When I run the inference example as follows (I copied run_vila.py from llava/eval/ to the project root):

```bash
python run_vila.py \
    --model-path Efficient-Large-Model/Llama-3-VILA1.5-8B \
    --conv-mode vicuna_v1 \
    --query "\n Please describe the traffic condition." \
    --image-file "./demo_images/av.png"
```

I encounter the following error:

```
['./demo_images/av.png']
Loading checkpoint shards: 100%|██████████| 4/4 [05:13<00:00, 78.34s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.
Setting pad_token_id to eos_token_id:128001 for open-end generation.
input: \n Please describe the traffic condition.
[WARNING] the auto inferred conversation mode is llava_v0, while --conv-mode is vicuna_v1, using vicuna_v1
torch.Size([1, 3, 384, 384])
Traceback (most recent call last):
  File "/home/deping.zhang/code/llm/VILA/run_vila.py", line 153, in <module>
    eval_model(args)
  File "/home/deping.zhang/code/llm/VILA/run_vila.py", line 115, in eval_model
    output_ids = model.generate(
  File "/home/deping.zhang/.conda/envs/vila/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/deping.zhang/code/llm/VILA/llava/model/language_model/llava_llama.py", line 171, in generate
    outputs = self.llm.generate(
  File "/home/deping.zhang/.conda/envs/vila/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/deping.zhang/.conda/envs/vila/lib/python3.10/site-packages/transformers/generation/utils.py", line 1764, in generate
    return self.sample(
  File "/home/deping.zhang/.conda/envs/vila/lib/python3.10/site-packages/transformers/generation/utils.py", line 2924, in sample
    if stopping_criteria(input_ids, scores):
  File "/home/deping.zhang/.conda/envs/vila/lib/python3.10/site-packages/transformers/generation/stopping_criteria.py", line 132, in __call__
    return any(criteria(input_ids, scores) for criteria in self)
  File "/home/deping.zhang/.conda/envs/vila/lib/python3.10/site-packages/transformers/generation/stopping_criteria.py", line 132, in <genexpr>
    return any(criteria(input_ids, scores) for criteria in self)
  File "/home/deping.zhang/code/llm/VILA/llava/mm_utils.py", line 287, in __call__
    outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
  File "/home/deping.zhang/code/llm/VILA/llava/mm_utils.py", line 272, in call_for_batch
    if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
```
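The error can be read directly off the failing check on line 272 of llava/mm_utils.py: `output_ids[0, -keyword_id.shape[0] :]` is meant to take the last k tokens, but when fewer than k tokens have been generated the slice simply returns all of them, so a 2-token tail ends up compared element-wise against a 3-token stop keyword. The plain-Python sketch below reproduces that shape logic with made-up token ids; the helper names (`naive_match`, `ends_with_keyword`) are hypothetical and this is an illustration, not VILA's code or an official fix. The stop keyword comes from the conversation template, so which keyword is used (and how many tokens it has) depends on `--conv-mode`.

```python
def naive_match(output_ids, keyword_ids):
    """Emulate (output_ids[-k:] == keyword_id).all() on 1-D sequences.

    Raises ValueError on a length mismatch, standing in for the RuntimeError
    PyTorch raises when two 1-D tensors of different lengths are compared.
    """
    k = len(keyword_ids)
    tail = list(output_ids[-k:])  # returns *all* tokens when len(output_ids) < k
    if len(tail) != k:
        raise ValueError(
            f"The size of tensor a ({len(tail)}) must match the size of "
            f"tensor b ({k}) at non-singleton dimension 0"
        )
    return all(a == b for a, b in zip(tail, keyword_ids))


def ends_with_keyword(output_ids, keyword_ids):
    """Hypothetical length-guarded suffix check: never raises on short outputs."""
    k = len(keyword_ids)
    if len(output_ids) < k:
        return False  # not enough generated tokens to contain the keyword yet
    return list(output_ids[-k:]) == list(keyword_ids)


if __name__ == "__main__":
    keyword = [1004, 82, 29]  # hypothetical 3-token stop keyword
    generated = [791, 1004]   # only 2 tokens generated so far
    try:
        naive_match(generated, keyword)
    except ValueError as err:
        print("naive check fails:", err)
    print(ends_with_keyword(generated, keyword))             # False
    print(ends_with_keyword(generated + [82, 29], keyword))  # True
```

Guarding on length before slicing keeps the intended "does the output end with this keyword" meaning while avoiding the shape error.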

Lyken17 commented 1 month ago

Could @joebradly @SeanCraven314 share your environment? The code runs without error on my side.

SeanCraven314 commented 1 month ago

Hi, this is a dump of my environment.

I am launching from the CLI:

python llava/eval/run_vila.py \
  --model-path=Efficient-Large-Model/Llama-3-VILA1.5-8B \
  --image-file=test.jpg \
  --query "What is this?"

Pip list

accelerate                0.27.2
aiofiles                  23.2.1
aiohttp                   3.9.5
aiosignal                 1.3.1
altair                    5.3.0
anyio                     4.3.0
async-timeout             4.0.3
attrs                     23.2.0
av                        12.0.0
bitsandbytes              0.41.0
braceexpand               0.1.7
certifi                   2024.2.2
charset-normalizer        3.3.2
click                     8.1.7
cmake                     3.29.2
contourpy                 1.2.1
cycler                    0.12.1
datasets                  2.16.1
decord                    0.6.0
dill                      0.3.7
distro                    1.9.0
dnspython                 2.6.1
einops                    0.6.1
einops-exts               0.0.4
email_validator           2.1.1
et-xmlfile                1.1.0
exceptiongroup            1.2.1
fastapi                   0.111.0
fastapi-cli               0.0.2
ffmpy                     0.3.2
filelock                  3.14.0
flash-attn                2.4.2
fonttools                 4.51.0
frozenlist                1.4.1
fsspec                    2023.10.0
fvcore                    0.1.5.post20221221
gradio                    3.35.2
gradio_client             0.2.9
h11                       0.14.0
hf_transfer               0.1.6
httpcore                  0.17.3
httptools                 0.6.1
httpx                     0.24.0
huggingface-hub           0.23.0
idna                      3.7
iopath                    0.1.10
Jinja2                    3.1.4
joblib                    1.4.2
jsonschema                4.22.0
jsonschema-specifications 2023.12.1
kiwisolver                1.4.5
linkify-it-py             2.0.3
lit                       18.1.4
markdown-it-py            2.2.0
markdown2                 2.4.13
MarkupSafe                2.1.5
matplotlib                3.8.4
mdit-py-plugins           0.3.3
mdurl                     0.1.2
mpmath                    1.3.0
multidict                 6.0.5
multiprocess              0.70.15
networkx                  3.3
ninja                     1.11.1.1
nltk                      3.3
numpy                     1.26.4
nvidia-cublas-cu11        11.10.3.66
nvidia-cublas-cu12        12.1.3.1
nvidia-cuda-cupti-cu11    11.7.101
nvidia-cuda-cupti-cu12    12.1.105
nvidia-cuda-nvrtc-cu11    11.7.99
nvidia-cuda-nvrtc-cu12    12.1.105
nvidia-cuda-runtime-cu11  11.7.99
nvidia-cuda-runtime-cu12  12.1.105
nvidia-cudnn-cu11         8.5.0.96
nvidia-cudnn-cu12         8.9.2.26
nvidia-cufft-cu11         10.9.0.58
nvidia-cufft-cu12         11.0.2.54
nvidia-curand-cu11        10.2.10.91
nvidia-curand-cu12        10.3.2.106
nvidia-cusolver-cu11      11.4.0.1
nvidia-cusolver-cu12      11.4.5.107
nvidia-cusparse-cu11      11.7.4.91
nvidia-cusparse-cu12      12.1.0.106
nvidia-nccl-cu11          2.14.3
nvidia-nccl-cu12          2.20.5
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu11          11.7.91
nvidia-nvtx-cu12          12.1.105
openai                    1.8.0
opencv-python             4.8.0.74
openpyxl                  3.1.2
orjson                    3.10.3
packaging                 24.0
pandas                    2.2.2
parameterized             0.9.0
peft                      0.5.0
pillow                    10.3.0
pip                       24.0
portalocker               2.8.2
psutil                    5.9.8
pyarrow                   16.0.0
pyarrow-hotfix            0.6
pydantic                  1.10.15
pydub                     0.25.1
Pygments                  2.18.0
pyparsing                 3.1.2
python-dateutil           2.9.0.post0
python-dotenv             1.0.1
python-multipart          0.0.9
pytorchvideo              0.1.5
pytz                      2024.1
pywsd                     1.2.4
PyYAML                    6.0.1
referencing               0.35.1
regex                     2024.4.28
requests                  2.31.0
rich                      13.7.1
rpds-py                   0.18.1
s2wrapper                 0.1
safetensors               0.4.3
scikit-learn              1.2.2
scipy                     1.13.0
semantic-version          2.10.0
sentencepiece             0.1.99
setuptools                68.2.2
shellingham               1.5.4
shortuuid                 1.0.13
six                       1.16.0
sniffio                   1.3.1
starlette                 0.37.2
svgwrite                  1.4.3
sympy                     1.12
tabulate                  0.9.0
termcolor                 2.4.0
threadpoolctl             3.5.0
timm                      0.9.12
tokenizers                0.15.2
tomli                     2.0.1
toolz                     0.12.1
torch                     2.0.1
torchvision               0.15.2
tqdm                      4.66.4
transformers              4.36.2
triton                    2.0.0
typer                     0.12.3
typing_extensions         4.11.0
tzdata                    2024.1
uc-micro-py               1.0.3
ujson                     5.9.0
urllib3                   2.2.1
uvicorn                   0.29.0
uvloop                    0.19.0
vila                      1.0.0              /home/sean-craven/VILA
watchfiles                0.21.0
wavedrom                  2.0.3.post3
webdataset                0.2.86
websockets                12.0
wheel                     0.43.0
wn                        0.9.5
xxhash                    3.4.1
yacs                      0.1.8
yarl                      1.9.4

Running on an Intel CPU with an A100 GPU under Ubuntu 22.04.

joebradly commented 1 month ago

> Could @joebradly @SeanCraven314 share your environment? The code runs without error on my side.

Package                   Version            Editable project location
accelerate                0.27.2
aiofiles                  23.2.1
aiohttp                   3.9.5
aiosignal                 1.3.1
altair                    5.3.0
anyio                     4.3.0
async-timeout             4.0.3
attrs                     23.2.0
av                        12.0.0
bitsandbytes              0.41.0
braceexpand               0.1.7
certifi                   2024.2.2
charset-normalizer        3.3.2
click                     8.1.7
cmake                     3.29.2
contourpy                 1.2.1
cycler                    0.12.1
datasets                  2.16.1
decord                    0.6.0
dill                      0.3.7
distro                    1.9.0
dnspython                 2.6.1
einops                    0.6.1
einops-exts               0.0.4
email_validator           2.1.1
et-xmlfile                1.1.0
exceptiongroup            1.2.1
fastapi                   0.111.0
fastapi-cli               0.0.2
ffmpy                     0.3.2
filelock                  3.14.0
flash-attn                2.4.2
fonttools                 4.51.0
frozenlist                1.4.1
fsspec                    2023.10.0
fvcore                    0.1.5.post20221221
gradio                    3.35.2
gradio_client             0.2.9
h11                       0.14.0
httpcore                  0.17.3
httptools                 0.6.1
httpx                     0.24.0
huggingface-hub           0.23.0
idna                      3.7
iopath                    0.1.10
Jinja2                    3.1.4
joblib                    1.4.2
jsonschema                4.22.0
jsonschema-specifications 2023.12.1
kiwisolver                1.4.5
linkify-it-py             2.0.3
lit                       18.1.4
markdown-it-py            2.2.0
markdown2                 2.4.13
MarkupSafe                2.1.5
matplotlib                3.8.4
mdit-py-plugins           0.3.3
mdurl                     0.1.2
mpmath                    1.3.0
multidict                 6.0.5
multiprocess              0.70.15
networkx                  3.3
ninja                     1.11.1.1
nltk                      3.3
numpy                     1.26.4
nvidia-cublas-cu11        11.10.3.66
nvidia-cublas-cu12        12.1.3.1
nvidia-cuda-cupti-cu11    11.7.101
nvidia-cuda-cupti-cu12    12.1.105
nvidia-cuda-nvrtc-cu11    11.7.99
nvidia-cuda-nvrtc-cu12    12.1.105
nvidia-cuda-runtime-cu11  11.7.99
nvidia-cuda-runtime-cu12  12.1.105
nvidia-cudnn-cu11         8.5.0.96
nvidia-cudnn-cu12         8.9.2.26
nvidia-cufft-cu11         10.9.0.58
nvidia-cufft-cu12         11.0.2.54
nvidia-curand-cu11        10.2.10.91
nvidia-curand-cu12        10.3.2.106
nvidia-cusolver-cu11      11.4.0.1
nvidia-cusolver-cu12      11.4.5.107
nvidia-cusparse-cu11      11.7.4.91
nvidia-cusparse-cu12      12.1.0.106
nvidia-nccl-cu11          2.14.3
nvidia-nccl-cu12          2.20.5
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu11          11.7.91
nvidia-nvtx-cu12          12.1.105
openai                    1.8.0
opencv-python             4.8.0.74
openpyxl                  3.1.2
orjson                    3.10.3
packaging                 24.0
pandas                    2.2.2
parameterized             0.9.0
peft                      0.5.0
pillow                    10.3.0
pip                       24.0
portalocker               2.8.2
psutil                    5.9.8
pyarrow                   16.0.0
pyarrow-hotfix            0.6
pydantic                  1.10.15
pydub                     0.25.1
Pygments                  2.18.0
pyparsing                 3.1.2
python-dateutil           2.9.0.post0
python-dotenv             1.0.1
python-multipart          0.0.9
pytorchvideo              0.1.5
pytz                      2024.1
pywsd                     1.2.4
PyYAML                    6.0.1
referencing               0.35.1
regex                     2024.4.28
requests                  2.31.0
rich                      13.7.1
rpds-py                   0.18.0
s2wrapper                 0.1
safetensors               0.4.3
scikit-learn              1.2.2
scipy                     1.13.0
semantic-version          2.10.0
sentencepiece             0.1.99
setuptools                69.5.1
shellingham               1.5.4
shortuuid                 1.0.13
six                       1.16.0
sniffio                   1.3.1
starlette                 0.37.2
svgwrite                  1.4.3
sympy                     1.12
tabulate                  0.9.0
termcolor                 2.4.0
threadpoolctl             3.5.0
timm                      0.9.12
tokenizers                0.15.2
tomli                     2.0.1
toolz                     0.12.1
torch                     2.0.1
torchvision               0.15.2
tqdm                      4.66.4
transformers              4.36.2
triton                    2.0.0
typer                     0.12.3
typing_extensions         4.11.0
tzdata                    2024.1
uc-micro-py               1.0.3
ujson                     5.9.0
urllib3                   2.2.1
uvicorn                   0.29.0
uvloop                    0.19.0
vila                      1.0.0              /home/deping.zhang/code/llm/VILA
watchfiles                0.21.0
wavedrom                  2.0.3.post3
webdataset                0.2.86
websockets                12.0
wheel                     0.43.0
wn                        0.9.5
xxhash                    3.4.1
yacs                      0.1.8
yarl                      1.9.4
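When comparing environments as this thread does, diffing the two `pip list` dumps is quicker than reading them side by side. Checking the two listings shared here, they differ only in hf_transfer (present in one environment only), rpds-py (0.18.1 vs 0.18.0) and setuptools (68.2.2 vs 69.5.1); torch 2.0.1 and transformers 4.36.2 match, which suggests that package versions are not what separates the failing setups from a working one. A minimal sketch, assuming the one-package-per-line layout that `pip list` prints (name, version, optional editable location); `parse_pip_list` and `diff_envs` are hypothetical helper names, and the sample data is abbreviated.

```python
def parse_pip_list(text):
    """Parse `pip list` output into {lowercased name: version}.

    Skips the header and dashed separator rows; the optional third column
    (editable project location) is ignored.
    """
    packages = {}
    for line in text.splitlines():
        fields = line.split()
        if len(fields) < 2:
            continue  # blank or malformed line
        name, version = fields[0], fields[1]
        if name == "Package" or set(name) == {"-"}:
            continue  # header row or separator row
        packages[name.lower()] = version
    return packages


def diff_envs(a, b):
    """Return sorted (name, version_in_a, version_in_b) for every mismatch.

    A package missing from one environment is reported as None.
    """
    names = sorted(set(a) | set(b))
    return [(n, a.get(n), b.get(n)) for n in names if a.get(n) != b.get(n)]


if __name__ == "__main__":
    env_a = """Package     Version
----------- -------
setuptools  68.2.2
torch       2.0.1
hf_transfer 0.1.6
"""
    env_b = """Package     Version
----------- -------
setuptools  69.5.1
torch       2.0.1
"""
    for name, va, vb in diff_envs(parse_pip_list(env_a), parse_pip_list(env_b)):
        print(f"{name}: {va} -> {vb}")
```

Names are only lowercased here; pip also treats `-` and `_` as equivalent, which this sketch does not normalize.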

joebradly commented 1 month ago

I changed line 272 (in llava/mm_utils.py) to the following:

    if (output_ids[0, -keyword_id.shape[0] :, None] == keyword_id).all():
        return True

Then the inference runs through.
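It is worth checking what this workaround computes, assuming `keyword_id` is the 1-D tensor sliced on line 272. Indexing with `None` turns the tail into shape (n, 1), which broadcasts against the (k,) keyword into an (n, k) table comparing every generated token with every keyword token. `.all()` is then True only if all of those pairs are equal, so for a keyword made of distinct tokens it can never fire: the shape error goes away, but that stop keyword is presumably no longer detected, and generation would run until another stop condition (such as EOS or the token limit) is hit. The plain-Python emulation below contrasts the two checks with made-up token ids; `workaround_check` and `suffix_equal` are hypothetical names, and this is a reading of the broadcasting semantics, not behavior confirmed in the thread.

```python
def workaround_check(output_ids, keyword):
    """Emulate (output_ids[-k:, None] == keyword).all() on 1-D sequences.

    Broadcasting an (n, 1) column against a (k,) row compares every tail
    token with every keyword token; all() requires every pair to match.
    """
    tail = output_ids[-len(keyword):]
    return all(t == k for t in tail for k in keyword)


def suffix_equal(output_ids, keyword):
    """Intended check: do the last len(keyword) tokens equal the keyword?"""
    k = len(keyword)
    return len(output_ids) >= k and list(output_ids[-k:]) == list(keyword)


if __name__ == "__main__":
    keyword = [1004, 82, 29]         # hypothetical multi-token stop keyword
    generated = [791, 1004, 82, 29]  # output that really ends with it
    print(suffix_equal(generated, keyword))       # True
    print(workaround_check(generated, keyword))   # False: keyword never detected
    print(workaround_check([3, 7], [7]))          # True: single-token keywords still work
```

For a single-token keyword the two checks agree, which is one reason the workaround can look harmless.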

gaodianzhuo commented 1 month ago

Thanks to the commenter above.

Efficient-Large-Language-Model commented 1 month ago

Will verify and fix.

By the way, you need to use --conv-mode=llama_3 with the Llama 3 model.

Efficient-Large-Language-Model commented 1 month ago

It seems when using the correct conv mode, there is no issue. Therefore, no code change is needed.

SeanCraven314 commented 1 month ago

Thanks very much for this. Sorry for the hassle.

hkunzhe commented 1 month ago

> It seems when using the correct conv mode, there is no issue. Therefore, no code change is needed.

Hi, running run_vila.py with VILA1.5-40B (not Llama 3) hits the same issue. Applying the workaround from @joebradly fixes it.

Efficient-Large-Language-Model commented 1 month ago

For VILA1.5-40B, you should use --conv-mode hermes-2
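Putting the maintainers' two answers together, `--conv-mode` has to match the model's language backbone rather than being left at vicuna_v1. A minimal sketch that assembles the run_vila.py command line with a matching mode; the table holds only the two pairings stated in this thread (Llama-3-VILA1.5-8B with llama_3, VILA1.5-40B with hermes-2), and `CONV_MODES` and `build_run_vila_cmd` are hypothetical names, not part of VILA. Unknown models raise an error instead of silently falling back to a default mode.

```python
import shlex

# Only the pairings stated by the maintainers in this thread.
CONV_MODES = {
    "Llama-3-VILA1.5-8B": "llama_3",
    "VILA1.5-40B": "hermes-2",
}


def build_run_vila_cmd(model_path, image_file, query):
    """Return argv for llava/eval/run_vila.py with a matching --conv-mode."""
    model_name = model_path.rstrip("/").split("/")[-1]
    try:
        conv_mode = CONV_MODES[model_name]
    except KeyError:
        raise ValueError(f"no known --conv-mode for {model_name!r}") from None
    return [
        "python", "llava/eval/run_vila.py",
        "--model-path", model_path,
        "--conv-mode", conv_mode,
        "--query", query,
        "--image-file", image_file,
    ]


if __name__ == "__main__":
    cmd = build_run_vila_cmd(
        "Efficient-Large-Model/Llama-3-VILA1.5-8B",
        "demo_images/av.png",
        # Literal backslash-n, matching what the shell passes for "\n" in double quotes.
        r"<image>\n Please describe the traffic condition.",
    )
    print(shlex.join(cmd))  # shell-quoted command line (Python 3.8+)
```

Failing loudly on an unlisted model avoids exactly the silent vicuna_v1 mismatch reported in this issue.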

tp-nan commented 1 month ago

Hi, with the new version, running

python3 -W ignore llava/eval/run_vila.py \
  --model-path Efficient-Large-Model/VILA1.5-3B \
  --conv-mode vicuna_v1 \
  --query "<image>\n Please describe the traffic condition." \
  --image-file "demo_images/av.png"

gives:

ValueError: Keyword tensor should have 2 or 3 dimensions, got 1

How can I fix it?