haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0
19.6k stars 2.16k forks

[Question] Finetune with chinese-clip #943

Open chenchun0629 opened 9 months ago

chenchun0629 commented 9 months ago

Question

Motivation:

I am trying to replace CLIP with Chinese-CLIP as the vision tower.

Environment

$ uname -a 
Linux localhost.localdomain 3.10.0-1160.80.1.el7.x86_64 #1 SMP Tue Nov 8 15:48:59 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux

$ pip
llava                         1.1.3  /data/jupyter/user/cc/LLaVA
torch                         2.0.1
accelerate                    0.21.0
deepspeed                     0.12.6
flash-attn                    2.3.6

$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0

Step 1: Pretrain

#!/bin/bash

deepspeed llava/train/train_mem.py \
    --deepspeed ./scripts/zero2.json \
    --model_name_or_path /data/llm_models/vicuna-7b-v1.5 \
    --version plain \
    --data_path /data/llm_datasets/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \
    --image_folder /data/llm_datasets/LLaVA-Pretrain \
    --vision_tower OFA-Sys/chinese-clip-vit-large-patch14-336px \
    --mm_projector_type mlp2x_gelu \
    --tune_mm_mlp_adapter True \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 True \
    --output_dir ./checkpoints/vicuna-v1.5-7b-chinese-clip-pretrain \
    --num_train_epochs 1 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 24000 \
    --save_total_limit 1 \
    --learning_rate 1e-3 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb

This pretraining step runs successfully.

Step 2: Visual Instruction Tuning

deepspeed llava/train/train_mem.py \
    --deepspeed ./scripts/zero3_offload_yi.json \
    --model_name_or_path /data/llm_models/vicuna-7b-v1.5 \
    --version v1 \
    --data_path /data/llm_datasets/LLaVA-Visual-Instruction-Tuning/llava_v1_5_mix665k.coco.json \
    --image_folder /data/llm_datasets/LLaVA-Visual-Instruction-Tuning \
    --vision_tower OFA-Sys/chinese-clip-vit-large-patch14-336px \
    --pretrain_mm_mlp_adapter ./checkpoints/vicuna-v1.5-7b-chinese-clip-pretrain/mm_projector.bin \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir ./checkpoints/vicuna-v1.5-7b-chinese-clip-finetune \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb

I run into a problem at this step:

Traceback (most recent call last):
  File "/data/jupyter/user/cc/LLaVA-cc/llava/train/train_mem.py", line 17, in <module>
    train()
  File "/data/jupyter/user/cc/LLaVA/llava/train/train.py", line 965, in train
    trainer.train()
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 2654, in training_step
    loss = self.compute_loss(model, inputs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 2679, in compute_loss
    outputs = model(**inputs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1833, in forward
    loss = self.module(*inputs, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/data/jupyter/user/cc/LLaVA/llava/model/language_model/llava_llama.py", line 79, in forward
    ) = self.prepare_inputs_labels_for_multimodal(
  File "/data/jupyter/user/cc/LLaVA/llava/model/llava_arch.py", line 121, in prepare_inputs_labels_for_multimodal
    image_features = self.encode_images(images).to(self.device)
  File "/data/jupyter/user/cc/LLaVA/llava/model/llava_arch.py", line 95, in encode_images
    image_features = self.get_model().get_vision_tower()(images)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/data/jupyter/user/cc/LLaVA/llava/model/multimodal_encoder/chinese_clip_encoder.py", line 48, in forward
    image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/models/chinese_clip/modeling_chinese_clip.py", line 1340, in forward
    return self.vision_model(
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/models/chinese_clip/modeling_chinese_clip.py", line 1084, in forward
    hidden_states = self.embeddings(pixel_values)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/models/chinese_clip/modeling_chinese_clip.py", line 204, in forward
    embeddings = embeddings + self.position_embedding(self.position_ids)
RuntimeError: The size of tensor a (257) must match the size of tensor b (577) at non-singleton dimension 1

Are there any suggestions on how I can solve this problem?

Thanks!

chenchun0629 commented 9 months ago

Some relevant code:

# builder.py

import os
from .clip_encoder import CLIPVisionTower
from .chinese_clip_encoder import ChineseCLIPVisionTower

def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    is_absolute_path_exists = os.path.exists(vision_tower)
    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif 'chinese-clip' in vision_tower:
        return ChineseCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')

# chinese_clip_encoder.py

import torch
import torch.nn as nn

from transformers import ChineseCLIPVisionModel, ChineseCLIPImageProcessor, ChineseCLIPVisionConfig

class ChineseCLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = ChineseCLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        self.image_processor = ChineseCLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = ChineseCLIPVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
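
For a quick standalone sanity check of the tower above, here is a minimal sketch (the SimpleNamespace is a hypothetical stand-in for the real training args; it assumes the same OFA-Sys/chinese-clip-vit-large-patch14-336px checkpoint and the chinese_clip_encoder module defined above):

# sanity_check_tower.py -- minimal sketch, run separately from training
from types import SimpleNamespace

import torch

from llava.model.multimodal_encoder.chinese_clip_encoder import ChineseCLIPVisionTower

# Hypothetical stand-in for the real training args object.
args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature='patch')

tower = ChineseCLIPVisionTower('OFA-Sys/chinese-clip-vit-large-patch14-336px', args=args)

# A 336x336 input should give (336 // 14) ** 2 = 576 patch tokens after dropping the CLS token.
dummy = torch.zeros(1, 3, 336, 336)
print(tower(dummy).shape)  # expected: torch.Size([1, 576, 1024]) for the ViT-L/14 tower

# A 224x224 input would instead give 256 patch tokens (257 with CLS), which is
# exactly the "tensor a (257)" that shows up in the traceback above.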

chenchun0629 commented 9 months ago

After debugging, I found that the image size drops from the original 336 to 224 a few steps into training (a processor config check that might explain this is sketched after the traceback below):

LlavaLlamaForCausalLM.images.shape: torch.Size([1, 3, 336, 336])
LlavaMetaForCausalLM0.images.shape: torch.Size([1, 3, 336, 336])
LlavaMetaForCausalLM.images.shape: torch.Size([1, 3, 336, 336])
ChineseCLIPVisionTower.images.shape: torch.Size([1, 3, 336, 336])
LlavaLlamaForCausalLM.images.shape: torch.Size([1, 3, 336, 336])
LlavaMetaForCausalLM0.images.shape: torch.Size([1, 3, 336, 336])
LlavaMetaForCausalLM.images.shape: torch.Size([1, 3, 336, 336])
ChineseCLIPVisionTower.images.shape: torch.Size([1, 3, 336, 336])
{'loss': 2.6496, 'learning_rate': 3.1826973105233184e-06, 'epoch': 0.0}                                                                                                                       
  0%|                                                                                                                                                  | 4/202394 
LlavaLlamaForCausalLM.images.shape: torch.Size([1, 3, 224, 224])
LlavaMetaForCausalLM0.images.shape: torch.Size([1, 3, 224, 224])
LlavaMetaForCausalLM.images.shape: torch.Size([1, 3, 224, 224])
ChineseCLIPVisionTower.images.shape: torch.Size([1, 3, 224, 224])
LlavaLlamaForCausalLM.images.shape: torch.Size([1, 3, 224, 224])
LlavaMetaForCausalLM0.images.shape: torch.Size([1, 3, 224, 224])
LlavaMetaForCausalLM.images.shape: torch.Size([1, 3, 224, 224])
ChineseCLIPVisionTower.images.shape: torch.Size([1, 3, 224, 224])
Traceback (most recent call last):
  File "/data/jupyter/user/cc/LLaVA-cc/llava/train/train_mem.py", line 17, in <module>
    train()
  File "/data/jupyter/user/cc/LLaVA/llava/train/train.py", line 965, in train
    trainer.train()
  File "/data/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
...
...
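
One thing that may be worth checking at this point (a sketch; it assumes the 224 comes from the image processor's configuration rather than from the data itself) is what the Chinese-CLIP processor actually reports for its sizes, since ChineseCLIPImageProcessor, like CLIPImageProcessor, falls back to a 224 crop_size when the checkpoint does not override it:

# inspect_processor.py -- minimal sketch to see where a 224 could sneak in
from transformers import ChineseCLIPImageProcessor

proc = ChineseCLIPImageProcessor.from_pretrained('OFA-Sys/chinese-clip-vit-large-patch14-336px')

# If `size` reflects the 336px checkpoint but `crop_size` is still the class
# default of 224, anything built from `crop_size` will come out as 224x224.
print('size          :', proc.size)
print('crop_size     :', proc.crop_size)
print('do_center_crop:', proc.do_center_crop)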

chenchun0629 commented 9 months ago

# 336*336 in LazySupervisedDataset and DataCollatorForSupervisedDataset
LazySupervisedDataset.images: torch.Size([3, 336, 336])
DataCollatorForSupervisedDataset.images: torch.Size([3, 336, 336])
# 224*224 in transformers.trainer
trainer0.images: torch.Size([1, 3, 224, 224])
trainer1.images: torch.Size([1, 3, 224, 224])
# transformers.trainer
            for step, inputs in enumerate(epoch_iterator):
                print("trainer0.images:", inputs["images"].shape, flush=True)
                total_batched_samples += 1
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                with self.accelerator.accumulate(model):
                    print("trainer1.images:", inputs["images"].shape, flush=True)
                    tr_loss_step = self.training_step(model, inputs)

chenchun0629 commented 9 months ago

When I fine-tune with the blip_laion_cc_sbu_558k.json dataset, it works.

When I use the llava_v1_5_mix665k.json dataset, I encounter this issue.
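
A hedged guess at the cause, not verified against the checkpoint: blip_laion_cc_sbu_558k.json has an image for every sample, while llava_v1_5_mix665k.json also contains text-only conversations. For those samples, LLaVA's LazySupervisedDataset builds a dummy image sized from image_processor.crop_size, so if the Chinese-CLIP processor reports the default 224 crop_size alongside a 336 input size, those samples arrive as 224x224 and fail in the position embedding. If that is what the processor check above shows, one possible workaround is to pin crop_size when loading the tower, roughly like this (a sketch against the ChineseCLIPVisionTower.load_model defined earlier):

    # In ChineseCLIPVisionTower (sketch): keep crop_size in sync with the 336px tower,
    # assuming the 224x224 tensors really come from the processor's default crop_size.
    def load_model(self):
        self.image_processor = ChineseCLIPImageProcessor.from_pretrained(self.vision_tower_name)
        # LazySupervisedDataset sizes the dummy image for text-only samples from
        # image_processor.crop_size; a 224 default here yields 224x224 batches.
        self.image_processor.crop_size = {'height': 336, 'width': 336}

        self.vision_tower = ChineseCLIPVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True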

GITMrzk commented 8 months ago

Your dimension problem looks like the feature dimensions from a 224 CLIP and a 336 CLIP not matching up: aren't 257 and 577 exactly the image token counts produced by 224 and 336 CLIP with patch size 14? The extra 1 is the CLS token.
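
Numerically that checks out; a quick sketch of the token-count arithmetic:

# Tokens for a ViT with patch size 14: (image_size // patch_size) ** 2 patches + 1 CLS token.
def num_tokens(image_size: int, patch_size: int = 14) -> int:
    return (image_size // patch_size) ** 2 + 1

print(num_tokens(224))  # 257 -> "tensor a" in the error
print(num_tokens(336))  # 577 -> "tensor b", what the 336px tower expects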

ScottishFold007 commented 6 months ago

May I ask, have you solved this problem?