rzamarefat opened this issue 4 months ago
Hi, I'm also experiencing this issue. Do you have a solution now? 😔
timm==0.6.5, torchvision==0.15.2+cu118
Thanks for your reply. I've reinstalled both versions, but it still doesn't work after I retrain the model.
Hello @rzamarefat and @Takagi0202, can you share the command you used to run it? Thanks.
Just like in the closed issue, I ran the training script and got my model, but the model can't be used for prediction, and I get the error. This is the prediction command:
python prediction.py --p sample_prediction_data --n ed --f 10
It seems that my trained model is different from your uploaded model. Can you tell me the reason? Thanks.
Hello @rzamarefat, @Takagi0202, I cloned the repo into another folder and followed the instructions in the README, and it works.
The culprit might be the timm library. The model uses the timm library from before it was ported to Hugging Face. Can you downgrade your timm to timm==0.6.5?
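For reference, the downgrade itself is just pip install timm==0.6.5. A quick sanity check that the environment actually picked up the pinned versions (a minimal sketch):

```python
# Verify the interpreter sees the pinned versions discussed above.
import timm
import torchvision

print(timm.__version__)         # expect 0.6.5
print(torchvision.__version__)  # e.g. 0.15.2+cu118, as reported above
```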
Update:
I see where the problem is now. When you train a new model, the current repo doesn't let you load the new weights (my fault :/). I hadn't updated it from the dev version.
I have made some updates to the repo; can you pull the latest changes?
The files affected by the update are prediction.py, model/gencovit.py, and model/pred_func.py.
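For context, the fix has to resolve a user-supplied checkpoint name instead of only the bundled weights. A minimal sketch of that idea (a hypothetical helper, not the actual code in model/pred_func.py):

```python
import torch

def load_trained_weights(model, name):
    # Hypothetical helper: `name` is the checkpoint filename under weight/
    # without the .pth extension, matching the CLI examples below.
    checkpoint = torch.load(f"weight/{name}.pth", map_location="cpu")
    # Some training loops save {"state_dict": ...}; unwrap if present.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    model.load_state_dict(checkpoint)
    return model.eval()
```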
Then you can run the following to test your newly trained model:
Example usage:
python prediction.py --p DeepfakeTIMIT --d timit --f 10

To use the VAE or ED variant:

VAE:
python prediction.py --p sample_prediction_data --v --f 10

ED:
python prediction.py --p sample_prediction_data --e --f 10

VAE test on the DeepfakeTIMIT dataset:
python prediction.py --p DeepfakeTIMIT --v --d timit --f 10

Run both VAE and ED (GenConViT); this uses the provided weights by default:
python prediction.py --p sample_prediction_data --e --v --f 10
Testing a new model:
If you have trained a new model (e.g., weight/genconvit_vae_May_16_2024_09_34_21.pth) and want to test it, use the following:

VAE:
python prediction.py --p sample_prediction_data --v genconvit_vae_May_16_2024_09_34_21 --f 10

ED:
python prediction.py --p sample_prediction_data --e genconvit_ed_May_16_2024_10_18_09 --f 10

Both VAE and ED (GenConViT):
python prediction.py --p sample_prediction_data --e genconvit_ed_May_16_2024_10_18_09 --v genconvit_vae_May_16_2024_09_34_21 --f 10
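(In these examples, the value passed to --e / --v is the checkpoint filename inside weight/ without the .pth extension.)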
I have updated my repo, and this time the model runs. Thanks so much!
Great! Thank you.
Hello, could you please share the accuracy (ACC) of your test?
Hi, thank you for open-sourcing your project. I have downloaded the provided checkpoints for both ED and VAE and placed them inside the weight folder. However, I get the following error:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GenConViTED:
    Missing key(s) in state_dict: "backbone.patch_embed.backbone.layers.3.downsample.norm.weight", "backbone.patch_embed.backbone.layers.3.downsample.norm.bias", "backbone.patch_embed.backbone.layers.3.downsample.reduction.weight", "backbone.patch_embed.backbone.head.fc.weight", "backbone.patch_embed.backbone.head.fc.bias", "embedder.layers.3.downsample.norm.weight", "embedder.layers.3.downsample.norm.bias", "embedder.layers.3.downsample.reduction.weight", "embedder.head.fc.weight", "embedder.head.fc.bias".
    Unexpected key(s) in state_dict: "backbone.patch_embed.backbone.layers.0.downsample.norm.weight", "backbone.patch_embed.backbone.layers.0.downsample.norm.bias", "backbone.patch_embed.backbone.layers.0.downsample.reduction.weight", "backbone.patch_embed.backbone.layers.0.blocks.0.attn.relative_position_index", "backbone.patch_embed.backbone.layers.0.blocks.1.attn_mask", "backbone.patch_embed.backbone.layers.0.blocks.1.attn.relative_position_index", "backbone.patch_embed.backbone.layers.1.blocks.0.attn.relative_position_index", "backbone.patch_embed.backbone.layers.1.blocks.1.attn_mask", "backbone.patch_embed.backbone.layers.1.blocks.1.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.0.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.1.attn_mask", "backbone.patch_embed.backbone.layers.2.blocks.1.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.2.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.3.attn_mask", "backbone.patch_embed.backbone.layers.2.blocks.3.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.4.attn.relative_position_index", "backbone.patch_embed.backbone.layers.2.blocks.5.attn_mask", "backbone.patch_embed.backbone.layers.2.blocks.5.attn.relative_position_index", "backbone.patch_embed.backbone.layers.3.blocks.0.attn.relative_position_index", "backbone.patch_embed.backbone.layers.3.blocks.1.attn.relative_position_index", "backbone.patch_embed.backbone.head.weight", "backbone.patch_embed.backbone.head.bias", "embedder.layers.0.downsample.norm.weight", "embedder.layers.0.downsample.norm.bias", "embedder.layers.0.downsample.reduction.weight", "embedder.layers.0.blocks.0.attn.relative_position_index", "embedder.layers.0.blocks.1.attn_mask", "embedder.layers.0.blocks.1.attn.relative_position_index", "embedder.layers.1.blocks.0.attn.relative_position_index", "embedder.layers.1.blocks.1.attn_mask", "embedder.layers.1.blocks.1.attn.relative_position_index", "embedder.layers.2.blocks.0.attn.relative_position_index", "embedder.layers.2.blocks.1.attn_mask", "embedder.layers.2.blocks.1.attn.relative_position_index", "embedder.layers.2.blocks.2.attn.relative_position_index", "embedder.layers.2.blocks.3.attn_mask", "embedder.layers.2.blocks.3.attn.relative_position_index", "embedder.layers.2.blocks.4.attn.relative_position_index", "embedder.layers.2.blocks.5.attn_mask", "embedder.layers.2.blocks.5.attn.relative_position_index", "embedder.layers.3.blocks.0.attn.relative_position_index", "embedder.layers.3.blocks.1.attn.relative_position_index", "embedder.head.weight", "embedder.head.bias".
    size mismatch for backbone.patch_embed.backbone.layers.1.downsample.norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for backbone.patch_embed.backbone.layers.1.downsample.norm.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for backbone.patch_embed.backbone.layers.1.downsample.reduction.weight: copying a param with shape torch.Size([384, 768]) from checkpoint, the shape in current model is torch.Size([192, 384]).
    size mismatch for backbone.patch_embed.backbone.layers.2.downsample.norm.weight: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for backbone.patch_embed.backbone.layers.2.downsample.norm.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for backbone.patch_embed.backbone.layers.2.downsample.reduction.weight: copying a param with shape torch.Size([768, 1536]) from checkpoint, the shape in current model is torch.Size([384, 768]).
    size mismatch for embedder.layers.1.downsample.norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for embedder.layers.1.downsample.norm.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for embedder.layers.1.downsample.reduction.weight: copying a param with shape torch.Size([384, 768]) from checkpoint, the shape in current model is torch.Size([192, 384]).
    size mismatch for embedder.layers.2.downsample.norm.weight: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for embedder.layers.2.downsample.norm.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for embedder.layers.2.downsample.reduction.weight: copying a param with shape torch.Size([768, 1536]) from checkpoint, the shape in current model is torch.Size([384, 768]).
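For what it's worth, this key pattern (head vs. head.fc, downsample keys shifted by one layer, extra attn_mask / relative_position_index buffers) matches the Swin Transformer refactor in newer timm releases, which is consistent with the timm==0.6.5 suggestion above. A minimal way to check, assuming a Swin embedder (the exact model name in the repo's config may differ):

```python
import timm

# "swin_tiny_patch4_window7_224" is an assumption about the embedder used
# here; check the repo's config for the exact model name.
model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=False)

# timm==0.6.5 exposes the classifier as "head.weight"/"head.bias"; newer timm
# renames it to "head.fc.weight"/"head.fc.bias" and relocates the downsample
# keys, which is exactly the missing/unexpected-key pattern reported above.
print(sorted(k for k in model.state_dict() if k.startswith("head")))
```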