Francis-Rings / ILA


Error(s) in loading state_dict for XCLIP #1

Open · poorfriend opened this issue 1 year ago

poorfriend commented 1 year ago

    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for XCLIP:
        size mismatch for visual.transformer.resblocks.0.message_attn.interactive_block.1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for visual.transformer.resblocks.1.message_attn.interactive_block.1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for visual.transformer.resblocks.2.message_attn.interactive_block.1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256]).

In model/align, class ILA defines the interactive block as:

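    import torch.nn as nn

    class ILA(nn.Module):
        def __init__(self, T=8, d_model=768, patch_size=16, input_resolution=224, is_training=True):
            super().__init__()
            self.T = T  # number of frames
            self.W = input_resolution // patch_size  # patches per side (224 // 16 = 14)
            self.d_model = d_model
            self.is_training = is_training

            # Input is two d_model-channel feature maps concatenated along channels (d_model * 2).
            # Parameterized children: 0 = Conv2d, 1 = BatchNorm2d, 3 = Conv2d, 4 = BatchNorm2d.
            self.interactive_block = nn.Sequential(
                nn.Conv2d(self.d_model * 2, 256, 3, padding=1),
                nn.BatchNorm2d(256),  # index 1: its weight has shape [256]
                nn.ReLU(),
                nn.Conv2d(256, 256, 3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
            )
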
Francis-Rings commented 1 year ago

Thanks for your question! We suggest loading the original CLIP weights into the model with strict=False in load_state_dict. The ILA modules, which the CLIP checkpoint does not contain, then start from their initial values and are learned during fine-tuning.
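A minimal sketch of what strict=False does, using two toy modules in place of CLIP and XCLIP (the names Backbone and BackboneWithAdapter are hypothetical, for illustration only). Note that strict=False only tolerates missing and unexpected keys: a key present in both the checkpoint and the model with different shapes still raises, so this route applies to the original CLIP weights, not to a checkpoint whose ILA layout differs from the code.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Hypothetical stand-in for the pretrained CLIP encoder."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)


class BackboneWithAdapter(Backbone):
    """Hypothetical stand-in for the full model: backbone plus a new, ILA-like module."""

    def __init__(self):
        super().__init__()
        self.adapter = nn.Conv2d(4, 4, 3, padding=1)


pretrained = Backbone().state_dict()  # plays the role of the original CLIP weights
model = BackboneWithAdapter()

# strict=False: keys the checkpoint lacks (the new module) keep their initial values.
result = model.load_state_dict(pretrained, strict=False)
print(result.missing_keys)     # ['adapter.weight', 'adapter.bias']
print(result.unexpected_keys)  # []

# Shape mismatches are not covered by strict=False and still raise RuntimeError.
bad = {"proj.weight": torch.zeros(4, 4), "proj.bias": torch.zeros(8)}
try:
    model.load_state_dict(bad, strict=False)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = "size mismatch for proj.weight" in str(err)
print(mismatch_raised)  # True
```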

qinghuannn commented 1 year ago

Same issue here. I printed the interactive_block parameter shapes from the released checkpoint and the module structure built by the code. The architecture of the released model does not seem to match the architecture in the code. @Francis-Rings could you help me solve this problem?

code:

    # state_dict: the released checkpoint; model: the XCLIP model built from the repo code
    prefix = "visual.transformer.resblocks.0.message_attn.interactive_block"
    for key in state_dict:
        if key.startswith(prefix):
            print(key, state_dict[key].shape)
    print("-" * 50)
    print(model.visual.transformer.resblocks[0].message_attn.interactive_block)

log:

visual.transformer.resblocks.0.message_attn.interactive_block.0.depthwise_conv.0.weight torch.Size([1536, 1, 3, 3])
visual.transformer.resblocks.0.message_attn.interactive_block.0.depthwise_conv.0.bias torch.Size([1536])
visual.transformer.resblocks.0.message_attn.interactive_block.0.depthwise_conv.1.weight torch.Size([1536])
visual.transformer.resblocks.0.message_attn.interactive_block.0.depthwise_conv.1.bias torch.Size([1536])
visual.transformer.resblocks.0.message_attn.interactive_block.0.depthwise_conv.1.running_mean torch.Size([1536])
visual.transformer.resblocks.0.message_attn.interactive_block.0.depthwise_conv.1.running_var torch.Size([1536])
visual.transformer.resblocks.0.message_attn.interactive_block.0.depthwise_conv.1.num_batches_tracked torch.Size([])
visual.transformer.resblocks.0.message_attn.interactive_block.0.pointwise_conv.0.weight torch.Size([256, 1536, 1, 1])
visual.transformer.resblocks.0.message_attn.interactive_block.0.pointwise_conv.0.bias torch.Size([256])
visual.transformer.resblocks.0.message_attn.interactive_block.0.pointwise_conv.1.weight torch.Size([256])
visual.transformer.resblocks.0.message_attn.interactive_block.0.pointwise_conv.1.bias torch.Size([256])
visual.transformer.resblocks.0.message_attn.interactive_block.0.pointwise_conv.1.running_mean torch.Size([256])
visual.transformer.resblocks.0.message_attn.interactive_block.0.pointwise_conv.1.running_var torch.Size([256])
visual.transformer.resblocks.0.message_attn.interactive_block.0.pointwise_conv.1.num_batches_tracked torch.Size([])
visual.transformer.resblocks.0.message_attn.interactive_block.1.weight torch.Size([256, 256, 3, 3])
visual.transformer.resblocks.0.message_attn.interactive_block.1.bias torch.Size([256])
visual.transformer.resblocks.0.message_attn.interactive_block.2.weight torch.Size([256])
visual.transformer.resblocks.0.message_attn.interactive_block.2.bias torch.Size([256])
visual.transformer.resblocks.0.message_attn.interactive_block.2.running_mean torch.Size([256])
visual.transformer.resblocks.0.message_attn.interactive_block.2.running_var torch.Size([256])
visual.transformer.resblocks.0.message_attn.interactive_block.2.num_batches_tracked torch.Size([])
--------------------------------------------------
Sequential(
  (0): Conv2d(1536, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
)
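The printout shows the cause: in the checkpoint, index 1 of interactive_block is a 3x3 Conv2d, while in the code it is a BatchNorm2d, so the indices no longer line up. The same comparison can be run over every key at once. A minimal sketch; the helper diff_state_dict is hypothetical, and the two toy Sequential modules only mimic the index shift seen here.

```python
import torch.nn as nn


def diff_state_dict(model: nn.Module, state_dict: dict, prefix: str = ""):
    """Return (missing, unexpected, mismatched) keys between a checkpoint and a model.

    Only keys starting with `prefix` are compared. `mismatched` holds
    (key, checkpoint_shape, model_shape) tuples.
    """
    model_sd = model.state_dict()
    ckpt_keys = {k for k in state_dict if k.startswith(prefix)}
    model_keys = {k for k in model_sd if k.startswith(prefix)}
    missing = sorted(model_keys - ckpt_keys)      # in the model, absent from the checkpoint
    unexpected = sorted(ckpt_keys - model_keys)   # in the checkpoint, absent from the model
    mismatched = sorted(
        (k, tuple(state_dict[k].shape), tuple(model_sd[k].shape))
        for k in ckpt_keys & model_keys
        if state_dict[k].shape != model_sd[k].shape
    )
    return missing, unexpected, mismatched


# Toy "checkpoint" layout: conv, conv, BN (index 1 is a conv).
ckpt = nn.Sequential(nn.Conv2d(8, 4, 1), nn.Conv2d(4, 4, 3, padding=1), nn.BatchNorm2d(4))
# Toy "code" layout: conv, BN, ReLU, conv, BN, ReLU (index 1 is a BN).
code = nn.Sequential(
    nn.Conv2d(8, 4, 3, padding=1), nn.BatchNorm2d(4), nn.ReLU(),
    nn.Conv2d(4, 4, 3, padding=1), nn.BatchNorm2d(4), nn.ReLU(),
)

missing, unexpected, mismatched = diff_state_dict(code, ckpt.state_dict())
for row in mismatched:
    print(row)
# ('0.weight', (4, 8, 1, 1), (4, 8, 3, 3))
# ('1.weight', (4, 4, 3, 3), (4,))
```

On the real model, a call such as diff_state_dict(model, state_dict, prefix="visual.transformer.") would list every mismatched ILA parameter in one pass, assuming the checkpoint keys carry no extra prefix.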
Francis-Rings commented 1 year ago

Thanks for your questions! The current code corresponds to the architecture for Something-Something v2 (SSV2), which uses stronger convolution-style components for temporal modeling. If you want to run the K400 architecture instead, you need to modify class ILA: replace its original interactive_block with the following code:

    self.interactive_block = nn.Sequential(
        Depth_Separable_Convolution(self.d_model * 2, 256),
        nn.Conv2d(256, 256, 3, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(),
    )
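The thread does not show Depth_Separable_Convolution itself. The sketch below is a hypothetical reconstruction inferred only from the key names and shapes in the checkpoint log (a depthwise 3x3 conv plus BatchNorm, then a pointwise 1x1 conv plus BatchNorm); the ReLU layers are a guess, since parameter-free layers leave no trace in a state_dict. In practice, use the repository's own class; the point here is to check that the K400 layout reproduces the checkpoint shapes.

```python
import torch
import torch.nn as nn


class Depth_Separable_Convolution(nn.Module):
    """Hypothetical reconstruction from checkpoint key names/shapes, not the repo's code."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Depthwise 3x3: groups=in_channels gives one filter per channel -> weight [C_in, 1, 3, 3].
        self.depthwise_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),  # assumption: not verifiable from the checkpoint
        )
        # Pointwise 1x1: mixes channels -> weight [C_out, C_in, 1, 1].
        self.pointwise_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),  # assumption: not verifiable from the checkpoint
        )

    def forward(self, x):
        return self.pointwise_conv(self.depthwise_conv(x))


d_model = 768
interactive_block = nn.Sequential(
    Depth_Separable_Convolution(d_model * 2, 256),
    nn.Conv2d(256, 256, 3, padding=1),
    nn.BatchNorm2d(256),
    nn.ReLU(),
)

# Compare against the shapes printed from the released checkpoint.
shapes = {k: tuple(v.shape) for k, v in interactive_block.state_dict().items()}
print(shapes["0.depthwise_conv.0.weight"])  # (1536, 1, 3, 3)
print(shapes["0.pointwise_conv.0.weight"])  # (256, 1536, 1, 1)
print(shapes["1.weight"])                   # (256, 256, 3, 3)

x = torch.randn(2, d_model * 2, 14, 14)  # batch of 2, 14x14 patch grid
out = interactive_block(x)
print(out.shape)  # torch.Size([2, 256, 14, 14])
```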

Francis-Rings commented 1 year ago

I have already updated the code. The master branch now corresponds to the K400 architecture, and the SSV2 branch contains the SSV2 architecture.
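To decide which branch matches a given checkpoint, one can inspect the first interactive_block key. The helper below is a hypothetical heuristic built only from the two layouts discussed in this thread, and it assumes the checkpoint keys carry no extra prefix (such as module.) beyond what the log shows.

```python
import torch
import torch.nn as nn

PREFIX = "visual.transformer.resblocks.0.message_attn.interactive_block."


def guess_ila_variant(state_dict: dict) -> str:
    """Hypothetical heuristic: guess which ILA layout produced a checkpoint."""
    # K400 layout (master branch): element 0 is a depthwise-separable conv module.
    if PREFIX + "0.depthwise_conv.0.weight" in state_dict:
        return "K400 (master branch)"
    # SSV2 layout (SSV2 branch): element 0 is a plain Conv2d with a 4-D weight.
    weight = state_dict.get(PREFIX + "0.weight")
    if weight is not None and weight.dim() == 4:
        return "SSV2 (SSV2 branch)"
    return "unknown"


# Fake checkpoints, for illustration only.
k400_sd = {PREFIX + "0.depthwise_conv.0.weight": torch.zeros(1536, 1, 3, 3)}
ssv2_block = nn.Sequential(nn.Conv2d(1536, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU())
ssv2_sd = {PREFIX + k: v for k, v in ssv2_block.state_dict().items()}

print(guess_ila_variant(k400_sd))  # K400 (master branch)
print(guess_ila_variant(ssv2_sd))  # SSV2 (SSV2 branch)
print(guess_ila_variant({}))       # unknown
```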