BAAI-DCAI / Bunny

A family of lightweight multimodal models.
Apache License 2.0

Continuous Fine-tuning Bunny 1.1 4B #123

Closed · ChenFicha closed this 1 month ago

ChenFicha commented 2 months ago

I am continuously fine-tuning Bunny 1.1 4B, and I have a question about the training code.

In train.py, when model.get_model().initialize_vision_modules(model_args=model_args) is called, does it load the vision tower's weights from google/siglip-so400m-patch14-384 instead of from Bunny's weights? Since it is called after BunnyPhi3ForCausalLM.from_pretrained(), it seems it would override Bunny's vision tower weights.
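
Roughly, the order of calls I am worried about is (simplified, with most arguments elided; not the exact source):

    # Simplified sketch of the order of calls in train.py.
    model = BunnyPhi3ForCausalLM.from_pretrained(
        model_args.model_name_or_path,  # weights/Bunny-v1_1-4B/, fine-tuned vision tower included
    )
    # Called afterwards: if this reloads the tower from model_args.vision_tower
    # (weights/siglip-so400m-patch14-384/), the fine-tuned weights would be overwritten.
    model.get_model().initialize_vision_modules(model_args=model_args)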

Isaachhh commented 2 months ago

In continuous fine-tuning, I think the weights of the vision tower would be loaded from Bunny instead of the original SigLIP here.

What about printing some tensors of the vision tower before and after calling model.get_model().initialize_vision_modules(model_args=model_args)?

ChenFicha commented 1 month ago

I think this is already called by BunnyPhi3ForCausalLM.from_pretrained() here, and model.get_model().initialize_vision_modules(model_args=model_args) is then called again after BunnyPhi3ForCausalLM.from_pretrained().
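
If I read bunny/model/multimodal_encoder/siglip/siglip_encoder.py correctly, the reload would happen in SiglipVisionTower.load_model(). A rough sketch of what I suspect it does (body abbreviated, inner model class name assumed):

    class SiglipVisionTower(nn.Module):
        def load_model(self):
            # Unconditionally reloads from the original SigLIP checkpoint path,
            # overwriting whatever weights the tower already holds.
            self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
            self.is_loaded = True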

I tried printing some weights of the vision tower with this code in train.py and finetune.sh (I removed DeepSpeed so the weights are easier to access):

    # k_proj weight of the first encoder layer, captured before re-initialization
    before = model.get_model().vision_tower.vision_tower.vision_model.encoder.layers[0].self_attn.k_proj.weight
    print("Before initialize_vision_modules():", before, "\n")

    model.get_model().initialize_vision_modules(model_args=model_args)

    # the same weight after initialize_vision_modules() has run
    after = model.get_model().vision_tower.vision_tower.vision_model.encoder.layers[0].self_attn.k_proj.weight
    print("After initialize_vision_modules():", after, "\n")

    # the corresponding weight in the original (not fine-tuned) SigLIP checkpoint
    from bunny.model.multimodal_encoder.siglip.siglip_encoder import SiglipVisionTower
    siglip = SiglipVisionTower("weights/siglip-so400m-patch14-384/", None)
    original = siglip.vision_tower.vision_model.encoder.layers[0].self_attn.k_proj.weight
    print("google/siglip-so400m-patch14-384:", original, "\n")

    print("If the weights are the same before and after initialize_vision_modules(): ", torch.equal(before, after))
    print("If the weights are the same after initialize_vision_modules() and google/siglip-so400m-patch14-384: ", torch.equal(after, original))
    exit()
and in finetune.sh:

    python3 train.py \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --model_name_or_path weights/Bunny-v1_1-4B/ \
    --model_type $MODEL_TYPE \
    --version phi3 \
    --data_path datasets/example.json \
    --image_folder datasets/temp/ \
    --vision_tower weights/siglip-so400m-patch14-384/ \

And here is what I get:

Before initialize_vision_modules(): Parameter containing:
tensor([[-4.7363e-02,  1.7929e-03,  4.0283e-02,  ..., -5.6396e-02,
          5.8899e-03, -9.0942e-03],
        [-6.8848e-02,  1.6724e-02,  7.0312e-02,  ..., -5.1514e-02,
         -2.5269e-02,  1.0132e-02],
        [-2.2656e-01,  3.0640e-02,  4.0771e-02,  ...,  1.8921e-03,
          1.3351e-04,  8.9111e-03],
        ...,
        [ 7.0801e-02, -1.8799e-02,  8.9844e-02,  ...,  1.0071e-03,
         -6.1768e-02,  9.8877e-03],
        [ 8.7280e-03, -2.5024e-02, -2.4780e-02,  ..., -1.6724e-02,
         -2.1582e-01, -1.0300e-03],
        [-3.7384e-03,  2.6093e-03, -7.2021e-03,  ..., -1.3962e-03,
         -3.1982e-02, -3.3417e-03]], dtype=torch.bfloat16) 

After initialize_vision_modules(): Parameter containing:
tensor([[-4.7381e-02,  1.7923e-03,  4.0312e-02,  ..., -5.6287e-02,
          5.9030e-03, -9.0787e-03],
        [-6.8910e-02,  1.6783e-02,  7.0460e-02,  ..., -5.1399e-02,
         -2.5219e-02,  1.0102e-02],
        [-2.2686e-01,  3.0666e-02,  4.0736e-02,  ...,  1.8918e-03,
          1.3389e-04,  8.9362e-03],
        ...,
        [ 7.0685e-02, -1.8835e-02,  8.9762e-02,  ...,  1.0074e-03,
         -6.1661e-02,  9.8732e-03],
        [ 8.7005e-03, -2.4973e-02, -2.4781e-02,  ..., -1.6682e-02,
         -2.1592e-01, -1.0275e-03],
        [-3.7415e-03,  2.6055e-03, -7.1938e-03,  ..., -1.3982e-03,
         -3.1892e-02, -3.3448e-03]]) 

google/siglip-so400m-patch14-384: Parameter containing:
tensor([[-4.7381e-02,  1.7923e-03,  4.0312e-02,  ..., -5.6287e-02,
          5.9030e-03, -9.0787e-03],
        [-6.8910e-02,  1.6783e-02,  7.0460e-02,  ..., -5.1399e-02,
         -2.5219e-02,  1.0102e-02],
        [-2.2686e-01,  3.0666e-02,  4.0736e-02,  ...,  1.8918e-03,
          1.3389e-04,  8.9362e-03],
        ...,
        [ 7.0685e-02, -1.8835e-02,  8.9762e-02,  ...,  1.0074e-03,
         -6.1661e-02,  9.8732e-03],
        [ 8.7005e-03, -2.4973e-02, -2.4781e-02,  ..., -1.6682e-02,
         -2.1592e-01, -1.0275e-03],
        [-3.7415e-03,  2.6055e-03, -7.1938e-03,  ..., -1.3982e-03,
         -3.1892e-02, -3.3448e-03]]) 

If the weights are the same before and after initialize_vision_modules():  False
If the weights are the same after initialize_vision_modules() and google/siglip-so400m-patch14-384:  True

Isaachhh commented 1 month ago

What if you add

if self.is_loaded:
    return

here?
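
That is, an early-return guard at the top of SiglipVisionTower.load_model(), something like (rest of the method elided):

    def load_model(self):
        # If the tower already holds weights (e.g. the fine-tuned ones restored
        # by from_pretrained()), skip the reload so they are not overwritten.
        if self.is_loaded:
            return
        ...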

ChenFicha commented 1 month ago

That's the fix I am using now, and it no longer overrides the weights. I think it would be a good idea to add the is_loaded check to all the vision towers.

Isaachhh commented 1 month ago

Thank you so much for pointing out this bug!