erprogs / GenConViT

Deepfake Video Detection Using Generative Convolutional Vision Transformer
GNU General Public License v3.0

Error in loading state_dict #13

Open vivek-metaphy opened 1 month ago

vivek-metaphy commented 1 month ago
python prediction.py --p sample_prediction_data --v --f 10
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.12 (you have 1.4.11). Upgrade using: pip install --upgrade albumentations
CONFIG
{'model': {'backbone': 'convnext_tiny', 'embedder': 'swin_tiny_patch4_window7_224', 'latent_dims': 12544}, 'batch_size': 32, 'epoch': 1, 'learning_rate': 0.0001, 'weight_decay': 0.0001, 'num_classes': 2, 'img_size': 224, 'min_val_loss': 10000}

Using genconvit

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/convnext_tiny.in12k_ft_in1k)
INFO:timm.models._hub:[timm/convnext_tiny.in12k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/swin_tiny_patch4_window7_224.ms_in1k)
INFO:timm.models._hub:[timm/swin_tiny_patch4_window7_224.ms_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/swin_tiny_patch4_window7_224.ms_in1k)
INFO:timm.models._hub:[timm/swin_tiny_patch4_window7_224.ms_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/convnext_tiny.in12k_ft_in1k)
INFO:timm.models._hub:[timm/convnext_tiny.in12k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
Traceback (most recent call last):
  File "/Users/vivekkornepalli/Developer/AIPlay/GenConViT/prediction.py", line 342, in <module>
    main()
  File "/Users/vivekkornepalli/Developer/AIPlay/GenConViT/prediction.py", line 329, in main
    else vids(ed_weight, vae_weight, path, dataset, num_frames, net, fp16)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/vivekkornepalli/Developer/AIPlay/GenConViT/prediction.py", line 20, in vids
    model = load_genconvit(config, net, ed_weight, vae_weight, fp16)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/vivekkornepalli/Developer/AIPlay/GenConViT/model/pred_func.py", line 17, in load_genconvit
    model = GenConViT(
            ^^^^^^^^^^
  File "/Users/vivekkornepalli/Developer/AIPlay/GenConViT/model/genconvit.py", line 52, in __init__
    self.model_ed.load_state_dict(self.checkpoint_ed)
  File "/opt/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GenConViTED:
        Missing key(s) in state_dict: "backbone.patch_embed.backbone.layers.3.downsample.norm.weight", "backbone.patch_embed.backbone.layers.3.downsample.norm.bias", "backbone.patch_embed.backbone.layers.3.downsample.reduction.weight", "backbone.patch_embed.backbone.head.fc.weight", "backbone.patch_embed.backbone.head.fc.bias", "embedder.layers.3.downsample.norm.weight", "embedder.layers.3.downsample.norm.bias", "embedder.layers.3.downsample.reduction.weight", "embedder.head.fc.weight", "embedder.head.fc.bias". 
        Unexpected key(s) in state_dict: "backbone.patch_embed.backbone.layers.0.downsample.norm.weight", "backbone.patch_embed.backbone.layers.0.downsample.norm.bias", "backbone.patch_embed.backbone.layers.0.downsample.reduction.weight", "backbone.patch_embed.backbone.layers.0.blocks.0.attn.relative_position_index", "backbone.patch_embed.backbone.layers.0.blocks.1.attn_mask", "backbone.patch_embed.backbone.layers.0.blocks.1.attn.relative_position_index", "backbone.patch_embed.backbone.layers.1.blocks.0.attn.relative_position_index", "backbone.patch_embed.backbone.layers.1.blocks.1.attn_mask", "backbone.patch_embed.backbone.layers.1.blocks.1.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.0.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.1.attn_mask", "backbone.patch_embed.backbone.layers.2.blocks.1.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.2.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.3.attn_mask", "backbone.patch_embed.backbone.layers.2.blocks.3.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.4.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.5.attn_mask", "backbone.patch_embed.backbone.layers.2.blocks.5.attn.relative_position_index", "backbone.patch_embed.backbone.layers.3.blocks.0.attn.relative_position_index", "backbone.patch_embed.backbone.layers.3.blocks.1.attn.relative_position_index", "backbone.patch_embed.backbone.head.weight", "backbone.patch_embed.backbone.head.bias", "embedder.layers.0.downsample.norm.weight", "embedder.layers.0.downsample.norm.bias", "embedder.layers.0.downsample.reduction.weight", "embedder.layers.0.blocks.0.attn.relative_position_index", "embedder.layers.0.blocks.1.attn_mask", "embedder.layers.0.blocks.1.attn.relative_position_index", "embedder.layers.1.blocks.0.attn.relative_position_index", "embedder.layers.1.blocks.1.attn_mask", "embedder.layers.1.blocks.1.attn.relative_position_index", "embedder.layers.2.blocks.0.attn.relative_position_index", "embedder.layers.2.blocks.1.attn_mask", "embedder.layers.2.blocks.1.attn.relative_position_index", "embedder.layers.2.blocks.2.attn.relative_position_index", "embedder.layers.2.blocks.3.attn_mask", "embedder.layers.2.blocks.3.attn.relative_position_index", "embedder.layers.2.blocks.4.attn.relative_position_index", "embedder.layers.2.blocks.5.attn_mask", "embedder.layers.2.blocks.5.attn.relative_position_index", "embedder.layers.3.blocks.0.attn.relative_position_index", "embedder.layers.3.blocks.1.attn.relative_position_index", "embedder.head.weight", "embedder.head.bias". 
        size mismatch for backbone.patch_embed.backbone.layers.1.downsample.norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
        size mismatch for backbone.patch_embed.backbone.layers.1.downsample.norm.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
        size mismatch for backbone.patch_embed.backbone.layers.1.downsample.reduction.weight: copying a param with shape torch.Size([384, 768]) from checkpoint, the shape in current model is torch.Size([192, 384]).
        size mismatch for backbone.patch_embed.backbone.layers.2.downsample.norm.weight: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for backbone.patch_embed.backbone.layers.2.downsample.norm.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for backbone.patch_embed.backbone.layers.2.downsample.reduction.weight: copying a param with shape torch.Size([768, 1536]) from checkpoint, the shape in current model is torch.Size([384, 768]).
        size mismatch for embedder.layers.1.downsample.norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
        size mismatch for embedder.layers.1.downsample.norm.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
        size mismatch for embedder.layers.1.downsample.reduction.weight: copying a param with shape torch.Size([384, 768]) from checkpoint, the shape in current model is torch.Size([192, 384]).
        size mismatch for embedder.layers.2.downsample.norm.weight: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for embedder.layers.2.downsample.norm.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for embedder.layers.2.downsample.reduction.weight: copying a param with shape torch.Size([768, 1536]) from checkpoint, the shape in current model is torch.Size([384, 768]).

It seems like there has been a change in the model architecture, or some keys are missing; I'm not sure what the issue is. Hopefully this is easily resolvable and I'm just missing something. Please help.
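
The key names in the error look like a timm version mismatch rather than a bug in the model code. The checkpoint stores the classifier as `head.weight`/`head.bias`, puts downsample modules at `layers.0`–`layers.2`, and persists the `attn_mask`/`relative_position_index` buffers, which matches the Swin layout of older timm releases. Newer timm (the Swin rework around v0.9, as far as I can tell) moved the classifier to `head.fc`, moved each downsample to the start of the following stage, and stopped persisting those buffers, which would account for every missing key, unexpected key, and size mismatch above. The quickest fix is probably installing the timm version the weights were saved with (check whether the repo pins one); failing that, the checkpoint keys can be remapped. A minimal remapping sketch, assuming exactly those three layout changes (the function and the weights filename below are mine, not from the repo):

```python
import re
import torch

def remap_swin_keys(state_dict):
    """Remap a Swin state_dict saved with an older timm to the newer layout."""
    remapped = {}
    for key, value in state_dict.items():
        # Newer timm recomputes these buffers and no longer persists them.
        if key.endswith("attn_mask") or key.endswith("relative_position_index"):
            continue
        # Downsample moved one stage later: layers.{i}.downsample -> layers.{i+1}.downsample
        key = re.sub(
            r"layers\.(\d+)\.downsample",
            lambda m: f"layers.{int(m.group(1)) + 1}.downsample",
            key,
        )
        # Classifier is now wrapped: head.{weight,bias} -> head.fc.{weight,bias}
        key = re.sub(r"(^|\.)head\.(weight|bias)$", r"\g<1>head.fc.\g<2>", key)
        remapped[key] = value
    return remapped

ckpt = torch.load("genconvit_ed_inference.pth", map_location="cpu")  # hypothetical filename
state = ckpt.get("state_dict", ckpt)  # unwrap if the file nests the weights
# model_ed.load_state_dict(remap_swin_keys(state))
```

The shift by one stage lines up with the shapes in the error: the checkpoint's `layers.1.downsample.norm` ([768]) matches the current model's `layers.2.downsample.norm`, and the checkpoint's `layers.2.downsample` tensors ([1536], [768, 1536]) match the missing `layers.3.downsample` keys.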

alilkkanyil commented 3 weeks ago

Same here
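
For anyone who wants to confirm it is the same layout mismatch before changing anything, diffing the checkpoint keys against the Swin model your installed timm builds makes it obvious. A rough sketch; the filename is hypothetical, and the `embedder.` prefix is taken from the traceback above (judging by the config, `embedder` appears to be a plain `swin_tiny_patch4_window7_224`):

```python
import timm
import torch

ckpt = torch.load("genconvit_ed_inference.pth", map_location="cpu")  # hypothetical filename
state = ckpt.get("state_dict", ckpt)
swin = timm.create_model("swin_tiny_patch4_window7_224", pretrained=False)

# Strip the GenConViT prefix so the keys are comparable to a bare Swin model.
ckpt_keys = {k[len("embedder."):] for k in state if k.startswith("embedder.")}
model_keys = set(swin.state_dict().keys())

print("only in checkpoint:", sorted(ckpt_keys - model_keys)[:10])
print("only in model:     ", sorted(model_keys - ckpt_keys)[:10])
```

If the first list shows `head.weight` and `layers.0.downsample.*` while the second shows `head.fc.weight` and `layers.3.downsample.*`, it is the timm layout change described above, and either pinning timm or remapping the keys should resolve it.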