haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0

[Question] inference after first-stage pretraining #1169

Open yinincanada opened 7 months ago

yinincanada commented 7 months ago

Question

Using the script scripts/v1_5/pretrain.sh, I ran LLaVA pretraining on some task-specific image-text pairs. Now I would like to test the pretrained model, and I am wondering how to do that.

The pretraining stage only outputs the modality aligner (mm_projector.bin), and I do not see any script that loads the base LLM (Vicuna), the vision tower, and the pretrained mm_projector.bin together and runs inference. Is there anything available for this?

Another direction could be: is there a configuration that saves the whole model after pretraining (as is done after the second-stage instruction finetuning), so that I could simply reuse the existing inference script llava/eval/model_vqa.py?

Thanks!

yinincanada commented 7 months ago

I would appreciate it if anyone could give some advice. I have been trying to modify the code to save the whole model, but it has not worked so far.

Shengcao-Cao commented 7 months ago

Hi @yinincanada ,

I happen to have the same question as you. I find that load_pretrained_model from llava/model is able to load the model with a pretrained mm_projector.bin (in my case, it is in the folder ./save/llava-v1.5-7b-pretrain). Then I use the simple CLI script to test my pretrained projector (along with the frozen LLM + vision encoder) like this:

python -m llava.serve.cli \
    --model-base lmsys/vicuna-7b-v1.5 \
    --model-path ./save/llava-v1.5-7b-pretrain \
    --image-file ./save/sample/cat-and-dog.jpg \
    --temperature 0.0

Hope this helps!
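
If you prefer to load everything in Python rather than through the CLI, a minimal sketch would look like this (the paths are placeholders, and it assumes the same load_pretrained_model signature that llava/serve/cli.py uses):

    # Minimal sketch: load base LLM + vision tower + pretrained mm_projector.bin
    # via load_pretrained_model (paths are placeholders).
    from llava.mm_utils import get_model_name_from_path
    from llava.model.builder import load_pretrained_model

    model_path = "./save/llava-v1.5-7b-pretrain"   # folder with mm_projector.bin + config.json
    model_base = "lmsys/vicuna-7b-v1.5"            # frozen base LLM

    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path,
        model_base,
        get_model_name_from_path(model_path),
        load_8bit=False,
        load_4bit=False,
        device="cuda",
    )
    print(type(model))  # the assembled LLaVA model, with the stage-1 projector weights loaded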

yinincanada commented 7 months ago

@Shengcao-Cao thanks! But from what you said, I do not see how the vision tower gets loaded for model inference. My understanding is that we need the base Vicuna, mm_projector.bin, and also the vision encoder (CLIP-ViT).

Shengcao-Cao commented 7 months ago

You may check the load_pretrained_model function: the vision tower is loaded at the end of it, and the config comes from your pretrain save directory.
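
As a quick sanity check that the vision tower is actually instantiated, you can inspect the model returned by load_pretrained_model; a small sketch (assuming the get_vision_tower() accessor exposed by the LLaVA model classes):

    # Rough sanity check on the `model` returned by load_pretrained_model
    # (assumes the LLaVA model classes expose get_vision_tower()).
    vision_tower = model.get_vision_tower()
    print(vision_tower)                                 # e.g. a CLIPVisionTower wrapper
    print(getattr(vision_tower, "is_loaded", None))     # True once the CLIP weights are loaded
    print(model.get_model().mm_projector)               # projector restored from mm_projector.bin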

Shengcao-Cao commented 7 months ago

BTW, I forgot to mention that I modified llava/model/builder.py a bit due to a bug reported here: https://github.com/haotian-liu/LLaVA/issues/1075.

My updated code between Line 96 and Line 100:

                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                from llava.model.language_model.llava_llama import LlavaConfig
                cfg_pretrained = LlavaConfig.from_pretrained(model_path)
                model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

In my case, cfg_pretrained cannot be loaded correctly from model_path when using AutoConfig, so I follow https://github.com/haotian-liu/LLaVA/commit/04fb03d4943ed212c24381e44c325525e700884d and load the config with LlavaConfig as well.
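
If you want to check whether you are hitting the same issue, a small diagnostic sketch (the path below is a placeholder) is to compare which config class AutoConfig resolves for the pretrain directory against LLaVA's own LlavaConfig:

    # Compare the config class AutoConfig picks with LLaVA's own LlavaConfig
    # ("./save/llava-v1.5-7b-pretrain" is a placeholder path).
    from transformers import AutoConfig
    from llava.model.language_model.llava_llama import LlavaConfig

    auto_cfg = AutoConfig.from_pretrained("./save/llava-v1.5-7b-pretrain")
    llava_cfg = LlavaConfig.from_pretrained("./save/llava-v1.5-7b-pretrain")

    # If type(auto_cfg) is not LlavaConfig (e.g. it resolves to a different class
    # registered for model_type "llava"), loading explicitly with LlavaConfig, as in
    # the snippet above, avoids the mismatch.
    print(type(auto_cfg), type(llava_cfg))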

yinincanada commented 7 months ago

Thanks! @Shengcao-Cao

conheaven commented 7 months ago

After finetuning with finetune_lora.sh, I merge the model with:

    python scripts/merge_lora_weights.py \
        --model-path /data1/khw/output_llava/finetune/llava-v1.5-13b-lora \
        --model-base /data1/khw/llava \
        --save-model-path /data1/khw/output_llava/merge_model-lora

Then I run:

    python -m llava.serve.cli \
        --model-base /data1/khw/llava \
        --model-path /data1/khw/output_llava/merge_model-lora \
        --image-file /data1/khw/img/1.jpg

and it shows:

    Traceback (most recent call last):
      File "/home/khw/miniconda3/envs/llava/lib/python3.10/runpy.py", line 196, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/khw/miniconda3/envs/llava/lib/python3.10/runpy.py", line 86, in _run_code
        exec(code, run_globals)
      File "/home/khw/llava/LLaVA/llava/serve/cli.py", line 126, in <module>
        main(args)
      File "/home/khw/llava/LLaVA/llava/serve/cli.py", line 32, in main
        tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
      File "/home/khw/llava/LLaVA/llava/model/builder.py", line 122, in load_pretrained_model
        model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
      File "/home/khw/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
        raise ValueError(
    ValueError: Unrecognized configuration class <class 'transformers.models.llava.configuration_llava.LlavaConfig'> for this kind of AutoModel: AutoModelForCausalLM.
    Model type should be one of BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, LlamaConfig, CodeGenConfig, CpmAntConfig, CTRLConfig, Data2VecTextConfig, ElectraConfig, ErnieConfig, FalconConfig, FuyuConfig, GemmaConfig, GitConfig, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, LlamaConfig, MarianConfig, MBartConfig, MegaConfig, MegatronBertConfig, MistralConfig, MixtralConfig, MptConfig, MusicgenConfig, MvpConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, PegasusConfig, PersimmonConfig, PhiConfig, PLBartConfig, ProphetNetConfig, QDQBertConfig, Qwen2Config, ReformerConfig, RemBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, RwkvConfig, Speech2Text2Config, StableLmConfig, TransfoXLConfig, TrOCRConfig, WhisperConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig, LlavaConfig, LlavaMptConfig, LlavaMistralConfig.

Thank you in advance!
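
For what it's worth, the frame at builder.py line 122 in that traceback appears to be on the branch load_pretrained_model takes when the model name does not contain "llava" (the merged directory here is named merge_model-lora). A hedged, untested sketch of one way around this would be to rename the merged checkpoint so its directory name contains "llava" and load it on its own, without a model base, since the LoRA weights are already merged (the renamed path below is hypothetical):

    # Hypothetical sketch, not verified: load the merged checkpoint directly.
    # The directory has been renamed so that its name contains "llava", and no
    # model_base is passed because the LoRA weights are already merged in.
    from llava.mm_utils import get_model_name_from_path
    from llava.model.builder import load_pretrained_model

    merged_path = "/data1/khw/output_llava/llava-v1.5-13b-merged-lora"  # hypothetical renamed dir
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        merged_path,
        None,                                   # no model_base for a fully merged model
        get_model_name_from_path(merged_path),
    )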