apple / ml-cvnets

CVNets: A library for training computer vision networks
https://apple.github.io/ml-cvnets
Other
1.76k stars, 225 forks

ViT-tiny weights and config don't match? #92

Open · nemcekova opened this issue 1 year ago

nemcekova commented 1 year ago

Hello, I'm trying to create the ViT-tiny model mentioned here. My approach is as follows:

First, I build the model structure, as advised in the issue. I use get_training_arguments from options.opts and get_model() from cvnets with the vit-tiny.yaml config file. The model structure is saved correctly.

Second, I load the saved model structure and call load_state_dict() with the pretrained weights:

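import torch
weights = 'vit-tiny.pt'
# Load the full model object saved in the first step (built from vit-tiny.yaml)
model = torch.load('vit_structure.pt', map_location=torch.device('cpu'))
# Load the pretrained ViT-tiny weights into that model
model.load_state_dict(torch.load(weights, map_location=torch.device('cpu')))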

That's where I get the following error:

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VisionTransformer:
    Missing key(s) in state_dict: "patch_emb.0.block.conv.weight", "patch_emb.0.block.norm.weight", "patch_emb.0.block.norm.bias", "patch_emb.0.block.norm.running_mean", "patch_emb.0.block.norm.running_var", "patch_emb.1.block.conv.weight", "patch_emb.1.block.norm.weight", "patch_emb.1.block.norm.bias", "patch_emb.1.block.norm.running_mean", "patch_emb.1.block.norm.running_var", "patch_emb.2.block.conv.weight", "patch_emb.2.block.conv.bias", "post_transformer_norm.weight", "post_transformer_norm.bias", "pos_embed.pos_embed.pos_embed".
    Unexpected key(s) in state_dict: "patch_emb.block.conv.weight", "patch_emb.block.conv.bias", "transformer.12.weight", "transformer.12.bias", "pos_embed.pe".

Any idea why this happens and how to fix it?

Note: the same approach works fine with MobileViT.
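One way to narrow this down is to diff the two key sets grouped by top-level module name. Below is a minimal sketch in plain Python, assuming the keys are available as lists of strings (for example from model.state_dict().keys() and the dict returned by torch.load(weights)). diff_state_dict_keys is a hypothetical helper, not part of cvnets, and the sample keys are a subset of the ones in the error above plus one hypothetical shared key.

```python
from collections import defaultdict


def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected), each a dict grouped by top-level module.

    missing:    keys the model defines but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not define
    Keys present in both are ignored.
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)

    def group(keys):
        grouped = defaultdict(list)
        for key in sorted(keys):
            # "patch_emb.0.block.conv.weight" -> "patch_emb"
            grouped[key.split(".", 1)[0]].append(key)
        return dict(grouped)

    return group(model_keys - checkpoint_keys), group(checkpoint_keys - model_keys)


# Subset of the keys from the error message above.
model_keys = [
    "patch_emb.0.block.conv.weight",
    "patch_emb.2.block.conv.bias",
    "post_transformer_norm.weight",
    "pos_embed.pos_embed.pos_embed",
    "classifier.weight",  # hypothetical key present in both, to show it is skipped
]
checkpoint_keys = [
    "patch_emb.block.conv.weight",
    "patch_emb.block.conv.bias",
    "transformer.12.weight",
    "pos_embed.pe",
    "classifier.weight",  # hypothetical key present in both
]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
for name in sorted(set(missing) | set(unexpected)):
    print(name, "missing:", missing.get(name, []), "unexpected:", unexpected.get(name, []))
```

Grouped this way, every mismatch falls under patch_emb, pos_embed and the final norm (post_transformer_norm in the model versus transformer.12 in the checkpoint), which suggests the config and the checkpoint describe different variants of those submodules rather than a corrupted file. PyTorch's model.load_state_dict(state_dict, strict=False) also returns the missing_keys and unexpected_keys lists instead of raising.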

Tranbaber commented 11 months ago

@nemcekova Hello! I'm trying to train a MobileViT model, but I'm running into the following problem and would appreciate some help:

File "C:\Users\72344\.conda\envs\MobileViTv2\Scripts\cvnets-train.exe\__main__.py", line 4, in <module>
ModuleNotFoundError: No module named 'main_train'

I also tried to install main_train with pip, but it fails with:
ERROR: Could not find a version that satisfies the requirement main_train (from versions: none)
ERROR: No matching distribution found for main_train

Can you tell me what I can do? Thank you very much!
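The pip error shows that main_train is not a package on PyPI; the cvnets-train launcher instead expects to import it from the environment it was installed into. Presumably it is a module shipped in the ml-cvnets source tree (an assumption worth checking against the repository's setup.py entry points). A minimal standard-library sketch to check what a given interpreter can actually import; module_status is a hypothetical helper:

```python
import importlib.util


def module_status(name):
    """Report whether the top-level module `name` is importable here."""
    # find_spec returns None (without importing) when a top-level module is absent
    spec = importlib.util.find_spec(name)
    if spec is None:
        return f"{name}: not importable from this environment"
    return f"{name}: found at {spec.origin}"


# 'main_train' is the module the traceback failed to import;
# 'json' is a standard-library module used as a known-good comparison.
for name in ("main_train", "json"):
    print(module_status(name))
```

Run it with the Python from the same conda environment as cvnets-train.exe (here MobileViTv2). If main_train is reported as not importable, the launcher cannot find it either.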