haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0
19.39k stars, 2.13k forks

[Usage] v1.5 task LoRA inference is broken? #1408

benihime91 closed this issue 5 months ago

benihime91 commented 5 months ago

Describe the issue

Issue:

Hi, I trained a task LoRA on a custom dataset using the script scripts/v1_5/finetune_task_lora.sh. Training completed, but I am now unable to run any inference with the resulting model.

Command: The following is the final script I used to train the task LoRA:

#!/bin/bash
export TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S")
export WANDB_PROJECT=...

deepspeed --include localhost:1,2 /mnt/data1/ayushman/projects/dash-diffusion/LLaVA/llava/train/train_mem.py \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --deepspeed  ./scripts/zero3.json \
    --model_name_or_path liuhaotian/llava-v1.5-13b \
    --version v1 \
    --data_path ... \
    --image_folder ... \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir ./checkpoints/$TIMESTAMP \
    --num_train_epochs 3 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-4 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb
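Before serving, it can help to confirm that the output directory holds the artifacts a LoRA loader expects. This sketch assumes, from my reading of the LoRA branch of LLaVA's `llava/train/train.py`, that training writes `config.json`, the PEFT adapter (`adapter_config.json` plus `adapter_model.safetensors` or `adapter_model.bin`, depending on the PEFT version) and `non_lora_trainables.bin`. The helper name `missing_lora_files` is hypothetical.

```python
from pathlib import Path

# File names assumed from LLaVA's LoRA save path (train.py) and PEFT's
# save_pretrained; newer PEFT versions write .safetensors, older ones .bin.
REQUIRED_FILES = ["config.json", "adapter_config.json", "non_lora_trainables.bin"]
ADAPTER_WEIGHT_FILES = ["adapter_model.safetensors", "adapter_model.bin"]


def missing_lora_files(output_dir):
    """Return the expected LoRA artifacts that are absent from output_dir."""
    root = Path(output_dir)
    missing = [name for name in REQUIRED_FILES if not (root / name).is_file()]
    # Either adapter weight format is acceptable, so only report if both are absent.
    if not any((root / name).is_file() for name in ADAPTER_WEIGHT_FILES):
        missing.append(" or ".join(ADAPTER_WEIGHT_FILES))
    return missing
```

For a complete run, `missing_lora_files("./checkpoints/<timestamp>")` would be expected to return an empty list.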

The following is the command I use to launch the model worker:

CUDA_VISIBLE_DEVICES=1 CUDA_LAUNCH_BLOCKING=1 python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path ${MODEL_PATH} --model-base lmsys/vicuna-13b-v1.5
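One detail worth checking with this command: when no model name is given (the logged Namespace shows `model_name=None`), the worker uses the last component of `--model-path` as the model name, and LLaVA's loader picks its code path by substring checks on that name. The sketch below restates those checks as I understand them from LLaVA 1.2.x (`llava/mm_utils.py` and `llava/model/builder.py`). It is a simplified approximation with hypothetical function names, not the library code, so verify it against your checkout.

```python
def model_name_from_path(model_path):
    """Approximation of how a model name is derived from --model-path."""
    parts = model_path.strip("/").split("/")
    # Intermediate checkpoints get their parent directory prepended.
    if parts[-1].startswith("checkpoint-"):
        return parts[-2] + "_" + parts[-1]
    return parts[-1]


def loading_branch(model_name, model_base):
    """Simplified view of the name-based branch choice in the LLaVA loader."""
    name = model_name.lower()
    if "llava" in name:
        if "lora" in name and model_base is not None:
            return "llava-lora"  # LLaVA model + non_lora_trainables + LoRA merge
        return "llava-other"  # full or projector-only LLaVA checkpoints
    if model_base is not None:
        return "plain-lm-peft"  # plain causal LM + PEFT merge, no vision tower
    return "plain-lm"


name = model_name_from_path("/checkpoints/llava-v1.5-13b-task-lora/2024-04-16_01-38-57/")
print(name, loading_branch(name, "lmsys/vicuna-13b-v1.5"))
```

With the path from the worker log, the name is `2024-04-16_01-38-57`, which contains neither `llava` nor `lora`, so this approximation lands in the plain-language-model branch. As far as I can tell, that branch is the one that prints `Loading LoRA weights from ...`, `Merging weights` and `Convert to FP16...`, which appears to match the worker log in this issue.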

Here are the model_worker logs:

Log:

2024-04-16 02:52:43 | INFO | model_worker | args: Namespace(host='0.0.0.0', port=40000, worker_address='http://localhost:40000', controller_address='http://localhost:10000', model_path='/checkpoints/llava-v1.5-13b-task-lora/2024-04-16_01-38-57/', model_base='lmsys/vicuna-13b-v1.5', model_name=None, device='cuda', multi_modal=False, limit_model_concurrency=5, stream_interval=1, no_register=False, load_8bit=False, load_4bit=False, use_flash_attn=False)
2024-04-16 02:52:43 | INFO | model_worker | Loading the model 2024-04-16_01-38-57 on worker dcf422 ...
2024-04-16 02:52:44 | ERROR | stderr | 
Loading checkpoint shards:   0%|                                                                                        | 0/3 [00:00<?, ?it/s]
2024-04-16 02:52:44 | ERROR | stderr | /home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
2024-04-16 02:52:44 | ERROR | stderr |   return self.fget.__get__(instance, owner)()
2024-04-16 02:52:46 | ERROR | stderr | 
Loading checkpoint shards:  33%|██████████████████████████▋                                                     | 1/3 [00:01<00:03,  1.92s/it]
2024-04-16 02:52:47 | ERROR | stderr | 
Loading checkpoint shards:  67%|█████████████████████████████████████████████████████▎                          | 2/3 [00:03<00:01,  1.50s/it]
2024-04-16 02:52:48 | ERROR | stderr | 
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.14s/it]
2024-04-16 02:52:48 | ERROR | stderr | 
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.28s/it]
2024-04-16 02:52:48 | ERROR | stderr | 
2024-04-16 02:52:48 | ERROR | stderr | /home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
2024-04-16 02:52:48 | ERROR | stderr |   warnings.warn(
2024-04-16 02:52:48 | ERROR | stderr | /home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
2024-04-16 02:52:48 | ERROR | stderr |   warnings.warn(
2024-04-16 02:52:48 | INFO | stdout | Loading LoRA weights from /mnt/data1/ayushman/projects/dash-diffusion/checkpoints/llava-v1.5-13b-task-lora/2024-04-16_01-38-57
2024-04-16 02:52:52 | INFO | stdout | Merging weights
2024-04-16 02:52:53 | INFO | stdout | Convert to FP16...
2024-04-16 02:52:53 | INFO | model_worker | Register to controller
2024-04-16 02:52:53 | ERROR | stderr | INFO:     Started server process [3105957]
2024-04-16 02:52:53 | ERROR | stderr | INFO:     Waiting for application startup.
2024-04-16 02:52:53 | ERROR | stderr | INFO:     Application startup complete.
2024-04-16 02:52:53 | ERROR | stderr | INFO:     Uvicorn running on http://0.0.0.0:40000 (Press CTRL+C to quit)
2024-04-16 02:52:53 | INFO | stdout | INFO:     127.0.0.1:56118 - "POST /worker_get_status HTTP/1.1" 200 OK
2024-04-16 02:53:08 | INFO | model_worker | Send heart beat. Models: ['2024-04-16_01-38-57']. Semaphore: None. global_counter: 0
2024-04-16 02:53:18 | INFO | model_worker | Send heart beat. Models: ['2024-04-16_01-38-57']. Semaphore: Semaphore(value=4, locked=False). global_counter: 1
2024-04-16 02:53:18 | INFO | stdout | INFO:     127.0.0.1:39104 - "POST /worker_generate_stream HTTP/1.1" 200 OK
2024-04-16 02:53:18 | ERROR | stderr | /home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py:1295: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use and modify the model generation configuration (see https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )
2024-04-16 02:53:18 | ERROR | stderr |   warnings.warn(
2024-04-16 02:53:19 | ERROR | stderr | Exception in thread Thread-4 (generate):
2024-04-16 02:53:19 | ERROR | stderr | Traceback (most recent call last):
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
2024-04-16 02:53:19 | ERROR | stderr |     self.run()
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/threading.py", line 953, in run
2024-04-16 02:53:19 | ERROR | stderr |     self._target(*self._args, **self._kwargs)
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
2024-04-16 02:53:19 | ERROR | stderr |     return func(*args, **kwargs)
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py", line 1525, in generate
2024-04-16 02:53:19 | ERROR | stderr |     return self.sample(
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py", line 2622, in sample
2024-04-16 02:53:19 | ERROR | stderr |     outputs = self(
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
2024-04-16 02:53:19 | ERROR | stderr |     return self._call_impl(*args, **kwargs)
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
2024-04-16 02:53:19 | ERROR | stderr |     return forward_call(*args, **kwargs)
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
2024-04-16 02:53:19 | ERROR | stderr |     outputs = self.model(
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
2024-04-16 02:53:19 | ERROR | stderr |     return self._call_impl(*args, **kwargs)
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
2024-04-16 02:53:19 | ERROR | stderr |     return forward_call(*args, **kwargs)
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1027, in forward
2024-04-16 02:53:19 | ERROR | stderr |     inputs_embeds = self.embed_tokens(input_ids)
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
2024-04-16 02:53:19 | ERROR | stderr |     return self._call_impl(*args, **kwargs)
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
2024-04-16 02:53:19 | ERROR | stderr |     return forward_call(*args, **kwargs)
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
2024-04-16 02:53:19 | ERROR | stderr |     return F.embedding(
2024-04-16 02:53:19 | ERROR | stderr |   File "/home/ayushman/miniforge3/envs/llava/lib/python3.10/site-packages/torch/nn/functional.py", line 2233, in embedding
2024-04-16 02:53:19 | ERROR | stderr |     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
2024-04-16 02:53:19 | ERROR | stderr | RuntimeError: CUDA error: device-side assert triggered
2024-04-16 02:53:19 | ERROR | stderr | Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
2024-04-16 02:53:19 | ERROR | stderr | 
2024-04-16 02:53:23 | INFO | model_worker | Send heart beat. Models: ['2024-04-16_01-38-57']. Semaphore: Semaphore(value=4, locked=False). global_counter: 1
2024-04-16 02:53:33 | INFO | stdout | Caught Unknown Error
2024-04-16 02:53:33 | INFO | model_worker | Send heart beat. Models: ['2024-04-16_01-38-57']. Semaphore: Semaphore(value=5, locked=False). global_counter: 1
2024-04-16 02:53:38 | INFO | model_worker | Send heart beat. Models: ['2024-04-16_01-38-57']. Semaphore: Semaphore(value=5, locked=False). global_counter: 1
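The traceback ends in `torch.embedding`, called from `transformers/models/llama/modeling_llama.py` rather than from a LLaVA model class. A device-side assert in an embedding lookup usually means an input id outside `[0, num_embeddings)`. LLaVA's `tokenizer_image_token` inserts the placeholder `IMAGE_TOKEN_INDEX` (`-200` in `llava/constants.py`) wherever `<image>` appears; a LLaVA model replaces those positions with image features before the embedding, whereas a plain LLaMA model would pass `-200` straight to `embed_tokens`. A minimal CPU-side check, assuming you can get `input_ids` as a Python list before calling `generate` (the helper name and the example ids are hypothetical):

```python
# Placeholder id LLaVA uses for <image> (IMAGE_TOKEN_INDEX in llava/constants.py).
IMAGE_TOKEN_INDEX = -200


def out_of_range_positions(input_ids, vocab_size):
    """Return positions whose id is not a valid row of a vocab_size embedding table.

    Valid ids are 0 .. vocab_size - 1; anything else makes an embedding lookup
    fail (on GPU this surfaces as a device-side assert).
    """
    return [i for i, tok in enumerate(input_ids) if not 0 <= tok < vocab_size]


# Illustrative ids only: a short prompt containing one <image> placeholder.
example_ids = [1, 3148, IMAGE_TOKEN_INDEX, 29871]
print(out_of_range_positions(example_ids, vocab_size=32000))  # [2]
```

In practice `vocab_size` could come from `model.get_input_embeddings().num_embeddings`. Running the same forward pass on CPU also tends to produce a readable `IndexError` instead of an opaque CUDA assert.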

I am also including my environment (installed packages) for reference:

Package                   Version     Editable project location
------------------------- ----------- -------------------------------------------------
accelerate                0.21.0
aiofiles                  23.2.1
albumentations            1.4.3
altair                    5.3.0
annotated-types           0.6.0
anyio                     4.3.0
appdirs                   1.4.4
asttokens                 2.4.1
attrs                     23.2.0
bitsandbytes              0.43.1
blis                      0.7.11
braceexpand               0.1.7
cachetools                5.3.3
catalogue                 2.0.10
certifi                   2024.2.2
charset-normalizer        3.3.2
click                     8.1.7
cloudpathlib              0.16.0
colorama                  0.4.6
comm                      0.2.2
confection                0.1.4
contourpy                 1.2.1
cycler                    0.12.1
cymem                     2.0.8
dataclasses               0.6
debugpy                   1.8.1
decorator                 5.1.1
deepspeed                 0.12.6
docker-pycreds            0.4.0
einops                    0.6.1
einops-exts               0.0.4
exceptiongroup            1.2.0
executing                 2.0.1
ExifRead-nocycle          3.0.1
fastai                    2.7.14
fastapi                   0.110.1
fastcore                  1.5.29
fastdownload              0.0.7
fastprogress              1.0.3
ffmpy                     0.3.2
filelock                  3.13.4
fire                      0.5.0
flash-attn                2.5.7
fonttools                 4.51.0
fsspec                    2024.3.1
gitdb                     4.0.11
GitPython                 3.1.43
gradio                    4.16.0
gradio_client             0.8.1
h11                       0.14.0
hjson                     3.1.0
httpcore                  0.17.3
httpx                     0.24.0
huggingface-hub           0.22.2
idna                      3.7
imageio                   2.34.0
img2dataset               1.45.0
importlib_metadata        7.1.0
importlib_resources       6.4.0
ipykernel                 6.29.4
ipython                   8.23.0
jedi                      0.19.1
Jinja2                    3.1.3
joblib                    1.4.0
jsonschema                4.21.1
jsonschema-specifications 2023.12.1
jupyter_client            8.6.1
jupyter_core              5.7.2
kiwisolver                1.4.5
langcodes                 3.3.0
lazy_loader               0.4
llava                     1.2.2.post1 /mnt/data1/ayushman/projects/dash-diffusion/LLaVA
markdown-it-py            3.0.0
markdown2                 2.4.13
MarkupSafe                2.1.5
matplotlib                3.8.4
matplotlib-inline         0.1.6
mdurl                     0.1.2
mpmath                    1.3.0
murmurhash                1.0.10
nest_asyncio              1.6.0
networkx                  3.3
ninja                     1.11.1.1
numpy                     1.26.4
nvidia-cublas-cu12        12.1.3.1
nvidia-cuda-cupti-cu12    12.1.105
nvidia-cuda-nvrtc-cu12    12.1.105
nvidia-cuda-runtime-cu12  12.1.105
nvidia-cudnn-cu12         8.9.2.26
nvidia-cufft-cu12         11.0.2.54
nvidia-curand-cu12        10.3.2.106
nvidia-cusolver-cu12      11.4.5.107
nvidia-cusparse-cu12      12.1.0.106
nvidia-ml-py              12.535.133
nvidia-nccl-cu12          2.18.1
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu12          12.1.105
nvitop                    1.3.2
opencv-python-headless    4.9.0.80
orjson                    3.10.1
packaging                 24.0
pandas                    2.2.2
parso                     0.8.4
peft                      0.10.0
pexpect                   4.9.0
pickleshare               0.7.5
pillow                    10.3.0
pip                       24.0
platformdirs              4.2.0
preshed                   3.0.9
prompt-toolkit            3.0.43
protobuf                  4.25.3
psutil                    5.9.8
ptyprocess                0.7.0
pure-eval                 0.2.2
py-cpuinfo                9.0.0
pyarrow                   15.0.2
pydantic                  2.7.0
pydantic_core             2.18.1
pydub                     0.25.1
Pygments                  2.17.2
pynvml                    11.5.0
pyparsing                 3.1.2
python-dateutil           2.9.0.post0
python-multipart          0.0.9
pytz                      2024.1
PyYAML                    6.0.1
pyzmq                     26.0.0
referencing               0.34.0
regex                     2023.12.25
requests                  2.31.0
rich                      13.7.1
rpds-py                   0.18.0
ruff                      0.3.7
safetensors               0.4.3
scikit-image              0.23.1
scikit-learn              1.2.2
scipy                     1.13.0
semantic-version          2.10.0
sentencepiece             0.1.99
sentry-sdk                1.45.0
setproctitle              1.3.3
setuptools                69.5.1
shellingham               1.5.4
shortuuid                 1.0.13
six                       1.16.0
smart-open                6.4.0
smmap                     5.0.1
sniffio                   1.3.1
spacy                     3.7.4
spacy-legacy              3.0.12
spacy-loggers             1.0.5
srsly                     2.4.8
stack-data                0.6.3
starlette                 0.37.2
svgwrite                  1.4.3
sympy                     1.12
termcolor                 2.4.0
thinc                     8.2.3
threadpoolctl             3.4.0
tifffile                  2024.2.12
timm                      0.6.13
tokenizers                0.15.1
tomlkit                   0.12.0
toolz                     0.12.1
torch                     2.1.2
torchvision               0.16.2
tornado                   6.4
tqdm                      4.66.2
traitlets                 5.14.2
transformers              4.37.2
triton                    2.1.0
typer                     0.9.4
typing_extensions         4.11.0
tzdata                    2024.1
urllib3                   2.2.1
uvicorn                   0.29.0
wandb                     0.16.6
wasabi                    1.1.2
wavedrom                  2.0.3.post3
wcwidth                   0.2.13
weasel                    0.3.4
webdataset                0.2.86
websockets                11.0.3
wheel                     0.43.0
zipp                      3.17.0
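The full `pip list` is useful, but a short report of the packages most involved in loading a LoRA checkpoint is easier to compare across machines. A small sketch using the standard library's `importlib.metadata`; the package selection and the helper name are my own choices, not something LLaVA provides:

```python
from importlib import metadata

# Distributions most involved in loading a LLaVA LoRA checkpoint (my selection).
KEY_PACKAGES = ("llava", "torch", "transformers", "tokenizers", "peft", "accelerate", "deepspeed")


def installed_versions(names=KEY_PACKAGES):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for pkg, ver in installed_versions().items():
    print(f"{pkg:<14} {ver or 'not installed'}")
```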

feiyangsuo commented 2 months ago

Hi, did you find a solution, or any other way to load the trained LoRA for inference?
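One possible cause, consistent with the worker log (`Loading the model 2024-04-16_01-38-57`, followed by `Loading LoRA weights from`, `Merging weights` and `Convert to FP16...`), is LLaVA's name-based branch selection. As far as I can tell, LLaVA 1.2.x takes its LLaVA LoRA path only when the model name contains both `llava` and `lora`, and the worker derives that name from the last component of `--model-path`. Under that assumption, one workaround is to expose the checkpoint under a directory name that satisfies both checks. A minimal sketch with hypothetical helper names:

```python
from pathlib import Path


def has_llava_lora_name(model_path):
    """True if the last path component contains both 'llava' and 'lora'."""
    name = Path(model_path).name.lower()
    return "llava" in name and "lora" in name


def link_with_llava_lora_name(checkpoint_dir, alias_path):
    """Create alias_path as a symlink to checkpoint_dir (hypothetical helper)."""
    if not has_llava_lora_name(alias_path):
        raise ValueError(f"{alias_path!r}: last component needs 'llava' and 'lora'")
    alias = Path(alias_path)
    # Leave an existing file, directory or symlink untouched.
    if not (alias.exists() or alias.is_symlink()):
        alias.symlink_to(Path(checkpoint_dir).resolve(), target_is_directory=True)
    return alias
```

The worker would then be launched with `--model-path` pointing at the alias. Alternatively, the worker's `model_name` argument (shown as `model_name=None` in the logged Namespace; I believe the flag is `--model-name`) may let you supply such a name directly. Since the task LoRA was trained on top of `liuhaotian/llava-v1.5-13b` (`--model_name_or_path` in the training script), that checkpoint is presumably the matching `--model-base`, rather than `lmsys/vicuna-13b-v1.5`.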