BAAI-DCAI / Bunny

A family of lightweight multimodal models.
Apache License 2.0

S2-Wrapper Strategy Training Resulting in Tensor Shape Mismatch #110

Closed dingtine closed 1 month ago

dingtine commented 1 month ago

When combining the model parameters and fixing the model per https://github.com/BAAI-DCAI/Bunny/issues/39#issuecomment-2021981770, the following error is reported at inference time:

```
RuntimeError: mat1 and mat2 shapes cannot be multiplied (729x1152 and 3456x3584)
```

However, when training without the S2-Wrapper strategy, inference works correctly.

How can I fix this? Thanks.
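For reference, the mismatch is consistent with the projector having been trained on S2 features: the `mm_projector` expects 3456 = 3 × 1152 input channels (three SigLIP scales concatenated channel-wise), while the plain tower emits 1152. A minimal sketch reproducing the shape arithmetic (dimensions taken from the error message; the random tensors are illustrative only):

```python
import torch

tokens, hidden, num_scales = 729, 1152, 3   # 27x27 SigLIP patches; 3456 / 1152 = 3 scales
proj = torch.nn.Linear(hidden * num_scales, 3584)  # same in/out shapes as the trained mm_projector

# Plain (non-S2) tower output: a single scale, 1152 channels -> shape mismatch
feats_single = torch.randn(tokens, hidden)
try:
    proj(feats_single)
    raised = False
except RuntimeError:
    # mat1 and mat2 shapes cannot be multiplied (729x1152 and 3456x3584)
    raised = True

# S2-style input: features from all scales concatenated along the channel dim
feats_s2 = torch.cat([feats_single] * num_scales, dim=-1)  # (729, 3456)
out = proj(feats_s2)                                       # (729, 3584), as the LLM expects
```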

dingtine commented 1 month ago

My model structure:

```
BunnyQwen2ForCausalLM(
  (model): BunnyQwen2Model(
    (embed_tokens): Embedding(151646, 3584)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
    (vision_tower): SigLipVisionTower(
      (vision_tower): SigLipVisionModel(
        (vision_model): SigLipVisionTransformer(
          (embeddings): SigLipVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
            (position_embedding): Embedding(729, 1152)
          )
          (encoder): SigLipEncoder(
            (layers): ModuleList(
              (0-25): 26 x SigLipEncoderLayer(
                (self_attn): SigLipAttention(
                  (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
                )
                (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
                (mlp): SigLipMLP(
                  (activation_fn): PytorchGELUTanh()
                  (fc1): Linear(in_features=1152, out_features=4304, bias=True)
                  (fc2): Linear(in_features=4304, out_features=1152, bias=True)
                )
                (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
              )
            )
          )
          (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          (head): Identity()
        )
      )
    )
    (mm_projector): Sequential(
      (0): Linear(in_features=3456, out_features=3584, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=3584, out_features=3584, bias=True)
    )
  )
  (lm_head): Linear(in_features=3584, out_features=151646, bias=False)
)
```

Isaachhh commented 1 month ago

You need to modify modeling_bunny_qwen2.py to support S2. You can refer to modeling_bunny_llama.py in Bunny-v1.1-Llama-3-8B-V, which supports S2 in the quick-start snippet.
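For anyone reading along, the modification the vision tower needs is, in essence, a multiscale forward: resize the image to several scales, split each larger scale into base-size crops, run the tower on every crop, stitch the crop features back into a full feature map, pool each scale down to the base token grid, and concatenate the scales channel-wise. That channel concatenation is where the projector's 3456 = 3 × 1152 input width comes from. Below is a minimal, self-contained sketch of that idea; it is not the actual `s2wrapper`/Bunny code, and the toy `forward_fn`, `base` size, and scales are illustrative assumptions:

```python
import math
import torch
import torch.nn.functional as F

def s2_multiscale_forward(forward_fn, x, scales=(1, 2, 3), base=32):
    """Run forward_fn at several image scales; concat features channel-wise.

    forward_fn: maps (B, C, base, base) images -> (B, N, D) patch features.
    Larger scales are split into scale*scale crops of size base, processed
    independently, reassembled, then avg-pooled back to the base token grid.
    """
    b, c = x.shape[0], x.shape[1]
    outs = []
    for s in scales:
        xs = F.interpolate(x, size=(base * s, base * s), mode='bilinear',
                           align_corners=False)
        # Split the upscaled image into s*s non-overlapping base-size crops
        crops = xs.unfold(2, base, base).unfold(3, base, base)  # (B,C,s,s,base,base)
        crops = crops.permute(0, 2, 3, 1, 4, 5).reshape(b * s * s, c, base, base)
        feats = forward_fn(crops)                               # (B*s*s, N, D)
        n, d = feats.shape[1], feats.shape[2]
        g = math.isqrt(n)                                       # tokens per crop side
        # Reassemble crop features into one (s*g) x (s*g) feature map
        feats = feats.reshape(b, s, s, g, g, d).permute(0, 5, 1, 3, 2, 4)
        feats = feats.reshape(b, d, s * g, s * g)
        feats = F.adaptive_avg_pool2d(feats, (g, g))            # back to base grid
        outs.append(feats.flatten(2).transpose(1, 2))           # (B, N, D)
    return torch.cat(outs, dim=-1)                              # (B, N, D*len(scales))

# Toy tower: an 8x8 patch embedding standing in for SigLIP (hypothetical sizes)
conv = torch.nn.Conv2d(3, 8, kernel_size=8, stride=8)
def toy_tower(imgs):
    return conv(imgs).flatten(2).transpose(1, 2)  # (B, 16, 8) for 32x32 input

x = torch.randn(2, 3, 32, 32)
y = s2_multiscale_forward(toy_tower, x)  # (2, 16, 24): 3 scales x 8 channels
```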

dingtine commented 1 month ago

Yes, following your method, I have successfully made the modifications. Thanks!

Part of the code is as follows:

```python
def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    return SigLipVisionTowerS2(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
```