bfshi / scaling_on_scales

When do we not need larger vision models?
MIT License

How to train Llava with S^2? #12

Open Chloe1997 opened 3 months ago

Chloe1997 commented 3 months ago

Hi! Your work is great. I want to start fine-tuning LLaVA with the S^2 wrapper, starting from the pretrained LLaVA at https://huggingface.co/liuhaotian/llava-v1.5-7b. However, I'm stuck on the mm_hidden_size of the pretrained LLaVA projector. When I run the snippet below, an error occurs because the S^2 wrapper sets the projector's input size to 3072, while the pretrained projector expects 1024. I have tried downloading the pretrained S^2 projector you provided and revising mm_hidden_size in config.json. Do you have any suggestions? Thank you.

model = LlavaLlamaForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    attn_implementation=attn_implementation,
    torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
    **bnb_model_from_pretrained_args
)
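The 1024 vs. 3072 mismatch follows from how S2 builds its features. As we understand it, the wrapper runs the vision encoder at every scale listed in s2_scales and concatenates the per-scale features along the channel dimension, so the projector's input width is the encoder hidden size times the number of scales. The sketch below assumes the CLIP ViT-L/14-336 vision tower used by llava-v1.5-7b (hidden size 1024); the helper `s2_mm_hidden_size` is hypothetical and not part of LLaVA or the S2 package.

```python
def s2_mm_hidden_size(vision_hidden_size: int, s2_scales: str) -> int:
    """Projector input width when features from every scale are concatenated channel-wise.

    Hypothetical helper for illustration; not an API of LLaVA or the S2 wrapper.
    """
    # s2_scales is a comma-separated string such as "336,672,1008"
    scales = [int(s) for s in s2_scales.split(",") if s.strip()]
    if not scales:
        raise ValueError("s2_scales must list at least one image size")
    # One feature map per scale, each with vision_hidden_size channels
    return vision_hidden_size * len(scales)


CLIP_L14_336_HIDDEN = 1024  # assumed hidden size of the CLIP ViT-L/14-336 vision tower

print(s2_mm_hidden_size(CLIP_L14_336_HIDDEN, "336"))           # 1024: single scale, as in plain LLaVA
print(s2_mm_hidden_size(CLIP_L14_336_HIDDEN, "336,672,1008"))  # 3072: three scales
```

This is why a projector trained for plain LLaVA (1024-wide input) cannot be loaded into a model configured for three S2 scales (3072-wide input).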
bfshi commented 3 months ago

Hi @Chloe1997,

Yes. To train LLaVA with S2, you can't use the original pretrained projector from LLaVA, because the mm_hidden_size is different. You can either train the projector yourself or use the projector from the pretrained LLaVA-S2 checkpoint. The mm_hidden_size in the config of the LLaVA-S2 checkpoint should already be 3072, so there is probably no need to change it.
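Before loading a projector, it can help to confirm that its first linear layer actually accepts mm_hidden_size-wide inputs. The sketch below works on a plain mapping from parameter names to shapes, which you might build with something like `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}`. The key `model.mm_projector.0.weight` assumes LLaVA's `mlp2x_gelu` projector layout and should be checked against your checkpoint; `projector_in_features` and `projector_matches` are hypothetical helpers. A PyTorch `nn.Linear` weight has shape `(out_features, in_features)`, so the input width is the second dimension.

```python
from typing import Mapping, Tuple

# Assumed key of the first Linear layer in an "mlp2x_gelu" projector checkpoint
FIRST_LAYER_KEY = "model.mm_projector.0.weight"


def projector_in_features(shapes: Mapping[str, Tuple[int, ...]],
                          key: str = FIRST_LAYER_KEY) -> int:
    """Input width of the projector's first Linear layer (hypothetical helper)."""
    if key not in shapes:
        raise KeyError(f"{key!r} not found; check the projector checkpoint's key names")
    _, in_features = shapes[key]  # nn.Linear weight is (out_features, in_features)
    return in_features


def projector_matches(shapes: Mapping[str, Tuple[int, ...]], mm_hidden_size: int) -> bool:
    """True if the checkpoint's projector can take mm_hidden_size-wide features."""
    return projector_in_features(shapes) == mm_hidden_size


# Illustrative shapes only: a 7B LLM with hidden size 4096 and two projector variants
llava_projector = {"model.mm_projector.0.weight": (4096, 1024)}
llava_s2_projector = {"model.mm_projector.0.weight": (4096, 3072)}

print(projector_matches(llava_projector, 3072))     # False: projector from plain LLaVA
print(projector_matches(llava_s2_projector, 3072))  # True: projector trained with S2
```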

To train LLaVA with S2, you can use the latest LLaVA repo, which has S2 integrated (see the PR here), and apply one additional change here. You can then train LLaVA with S2 just as you would train a regular LLaVA, except for two new configs: s2=True and s2_scales="336,672,1008". See the instructions for training LLaVA with S2 here.
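To get a feel for what s2_scales="336,672,1008" implies, here is a rough shape calculation. It assumes, based on our reading of the S2 wrapper's multiscale forward, that the smallest scale doubles as the crop size, that each larger image is cut into ceil(scale / 336)² crops of 336×336 that are encoded separately, and that each scale's merged feature map is resized to the token grid of the smallest scale before channel-wise concatenation. The 14-pixel patch size and 1024-dim features are those of CLIP ViT-L/14-336; `s2_plan` is a hypothetical helper, not an API of the S2 package or LLaVA.

```python
import math
from typing import Dict, List


def s2_plan(s2_scales: str, patch_size: int = 14, hidden_size: int = 1024) -> Dict[str, object]:
    """Rough per-scale shape bookkeeping for an S2-style multiscale forward (hypothetical)."""
    scales: List[int] = sorted(int(s) for s in s2_scales.split(","))
    split_size = scales[0]  # assumed: the smallest scale is also the crop size
    splits_per_side = [math.ceil(s / split_size) for s in scales]
    crops_per_scale = [n * n for n in splits_per_side]
    grid = split_size // patch_size  # tokens per side after resizing to the base scale
    return {
        "splits_per_side": splits_per_side,
        "crops_per_scale": crops_per_scale,
        "total_crops": sum(crops_per_scale),    # encoder forward passes per image
        "num_tokens": grid * grid,              # visual tokens fed to the projector
        "channels": hidden_size * len(scales),  # should equal mm_hidden_size
    }


plan = s2_plan("336,672,1008")
print(plan["crops_per_scale"])               # [1, 4, 9]
print(plan["num_tokens"], plan["channels"])  # 576 3072
```

Under these assumptions the LLM still sees 576 visual tokens per image, as in regular LLaVA-1.5; what changes is the width of each token (3072 instead of 1024), which is exactly the dimension the projector must be trained for.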