Jungjee / RawNet

Official repository for RawNet, RawNet2, and RawNet3
MIT License

Can you share code to load the pretrained model? #18

Closed: ma7555 closed this issue 3 years ago

ma7555 commented 3 years ago

I am having difficulty loading the pretrained model. I tried both model_RawNet2_original_code and model_RawNet2.

Code with model_RawNet2:

from model_RawNet2 import RawNet2
from parser import get_args  # the repository's parser.py
import sys
import torch

# Override the notebook's command-line arguments before parsing them
sys.argv = ['RawNet-Pytorch.ipynb', '-name', 'Rawnet']
args = get_args()
args.model['nb_classes'] = 6112

model = RawNet2(args.model)
model.load_state_dict(torch.load('./Pre-trained_model/rawnet2_best_weights.pt'))
model.eval()
RuntimeError: Error(s) in loading state_dict for RawNet2:
    Missing key(s) in state_dict: "block0.0.frm.fc.weight", "block0.0.frm.fc.bias", "block1.0.frm.fc.weight", "block1.0.frm.fc.bias", "block2.0.frm.fc.weight", "block2.0.frm.fc.bias", "block3.0.frm.fc.weight", "block3.0.frm.fc.bias", "block4.0.frm.fc.weight", "block4.0.frm.fc.bias", "block5.0.frm.fc.weight", "block5.0.frm.fc.bias". 
    Unexpected key(s) in state_dict: "fc_attention0.0.weight", "fc_attention0.0.bias", "fc_attention1.0.weight", "fc_attention1.0.bias", "fc_attention2.0.weight", "fc_attention2.0.bias", "fc_attention3.0.weight", "fc_attention3.0.bias", "fc_attention4.0.weight", "fc_attention4.0.bias", "fc_attention5.0.weight", "fc_attention5.0.bias". 
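The two lists in this error line up one-to-one: each missing block{i}.0.frm.fc.* key has an unexpected fc_attention{i}.0.* counterpart, and no shape mismatches are reported. If the two model definitions were otherwise equivalent (an assumption this thread does not confirm; the maintainer's reply is to use model_RawNet2_original_code instead), the checkpoint keys could in principle be renamed before loading. The sketch below only manipulates key names; remap_attention_keys is a hypothetical helper, not part of the repository.

```python
import re
from collections import OrderedDict

# Checkpoint keys such as "fc_attention3.0.weight" (names taken from the error above).
_ATTN_KEY = re.compile(r"^fc_attention(\d+)\.0\.(weight|bias)$")


def remap_attention_keys(state_dict):
    """Hypothetical helper: return a copy of state_dict in which every
    fc_attention{i}.0.{weight,bias} key is renamed to block{i}.0.frm.fc.{weight,bias}.
    Other keys and all values are copied unchanged, in their original order."""
    remapped = OrderedDict()
    for key, value in state_dict.items():
        match = _ATTN_KEY.match(key)
        if match:
            block_idx, param = match.groups()
            key = f"block{block_idx}.0.frm.fc.{param}"
        remapped[key] = value
    return remapped


if __name__ == "__main__":
    # Stand-in for torch.load(...); real values would be tensors.
    ckpt = OrderedDict([
        ("block2.0.conv1.weight", "conv"),
        ("fc_attention0.0.weight", "attn_w"),
        ("fc_attention0.0.bias", "attn_b"),
    ])
    print(list(remap_attention_keys(ckpt)))
    # → ['block2.0.conv1.weight', 'block0.0.frm.fc.weight', 'block0.0.frm.fc.bias']
```

With PyTorch installed, usage would look like model.load_state_dict(remap_attention_keys(torch.load(path, map_location='cpu'))). Even if that loads without errors, it only shows that names and shapes line up; it does not show that model_RawNet2 computes the same function as the code the weights were trained with.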

Code with model_RawNet2_original_code:

from model_RawNet2_original_code import RawNet
model2 = RawNet(args.model, torch.device('cuda'))  # 'gpu' is not a valid torch device name; 'cuda' is
model2.load_state_dict(torch.load('./Pre-trained_model/rawnet2_best_weights.pt'))
model2.eval()

RuntimeError: Error(s) in loading state_dict for RawNet:
    Unexpected key(s) in state_dict: "block2.0.conv_downsample.weight", "block2.0.conv_downsample.bias". 
    size mismatch for block2.0.bn1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for block2.0.bn1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for block2.0.bn1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for block2.0.bn1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for block2.0.conv1.weight: copying a param with shape torch.Size([256, 128, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3]).
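Here the key names mostly match but the shapes do not. The checkpoint's block2.0.conv1.weight is [256, 128, 3] (PyTorch's Conv1d weight layout is out_channels, in_channels, kernel_size), so in the checkpoint block2 takes 128 input channels while the current model builds it with 256. That is consistent with, though not proof of, the model object having been built from a different model file or configuration than the one the weights were saved from, and with the checkpoint carrying an extra conv_downsample layer for that block. A small comparison of name → shape mappings lists every difference at once. compare_shapes is a hypothetical helper; with PyTorch, each input could be built as {k: tuple(v.shape) for k, v in sd.items()} from model.state_dict() and from the loaded checkpoint.

```python
def compare_shapes(model_shapes, ckpt_shapes):
    """Compare two {parameter name: shape tuple} mappings.

    Returns a dict with:
      missing    - keys the model expects but the checkpoint lacks
      unexpected - keys in the checkpoint that the model does not have
      mismatched - {key: (model_shape, ckpt_shape)} for shared keys whose shapes differ
    """
    model_keys, ckpt_keys = set(model_shapes), set(ckpt_shapes)
    shared = model_keys & ckpt_keys
    return {
        "missing": sorted(model_keys - ckpt_keys),
        "unexpected": sorted(ckpt_keys - model_keys),
        "mismatched": {
            k: (model_shapes[k], ckpt_shapes[k])
            for k in sorted(shared)
            if tuple(model_shapes[k]) != tuple(ckpt_shapes[k])
        },
    }


if __name__ == "__main__":
    # A subset of the shapes reported in the error message above.
    model = {"block2.0.bn1.weight": (256,), "block2.0.conv1.weight": (256, 256, 3)}
    ckpt = {"block2.0.bn1.weight": (128,), "block2.0.conv1.weight": (256, 128, 3)}
    report = compare_shapes(model, ckpt)
    for key, (m, c) in report["mismatched"].items():
        print(f"{key}: model {m} vs checkpoint {c}")
    # → block2.0.bn1.weight: model (256,) vs checkpoint (128,)
    # → block2.0.conv1.weight: model (256, 256, 3) vs checkpoint (256, 128, 3)
```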

Could you please share reproducible code for loading the model?

Jungjee commented 3 years ago

Hi, first of all, model_RawNet2_original_code is the right module to use :) I tried to reproduce your situation, but the model weights loaded without any errors.

After running cp Pre-trained_model/model_RawNet2_original_code.py ./ in python/RawNet2 (so that the model file from the Pre-trained_model directory is the one imported), I ran the code below and the model weights loaded successfully.

If I recall correctly, there is another closed issue about model loading, and at that time I confirmed that loading worked. I will look into this issue further, but as of now it seems fine.

import torch
from parser import get_args
from model_RawNet2_original_code import RawNet  # the copy from Pre-trained_model/

args = get_args()
args.model['nb_classes'] = 6112
device = torch.device('cpu')
model = RawNet(args.model, device)
# map_location lets the checkpoint load on a machine without a GPU
model.load_state_dict(torch.load('./Pre-trained_model/rawnet2_best_weights.pt', map_location=device))
print('weights loaded successfully')
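One practical difference between the two snippets: the reply calls get_args() against the real command line, while the original report overrides sys.argv because it runs inside a notebook, where the kernel process is typically launched with its own arguments (such as -f followed by a connection-file path). If get_args() is built on argparse (an assumption; parser.py is not shown in this thread), one alternative to editing sys.argv is to pass an explicit argument list, and parse_known_args tolerates options the parser does not define. The parser below is a hypothetical stand-in that defines only the -name option appearing in the report.

```python
import argparse


def build_parser():
    # Hypothetical stand-in for the repository's parser.py; only "-name" is
    # taken from the issue, the real set of options is not known here.
    parser = argparse.ArgumentParser()
    parser.add_argument("-name", type=str, default=None)
    return parser


def parse_for_notebook(argv):
    """Parse an explicit argument list and return (args, unknown), where
    unknown holds options the parser does not define (e.g. a kernel's '-f <file>')."""
    return build_parser().parse_known_args(argv)


if __name__ == "__main__":
    args, unknown = parse_for_notebook(["-name", "Rawnet", "-f", "kernel.json"])
    print(args.name)  # → Rawnet
    print("ignored:", unknown)
```

Passing the list explicitly keeps sys.argv untouched, so the same code behaves the same way in a script and in a notebook.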