chequanghuy / TwinLiteNet

MIT License

Exporting Onnx in export.py #11

Closed: byungjinku closed this issue 7 months ago

byungjinku commented 7 months ago

I want to export the pretrained model to ONNX, but I get these error messages:

Missing key(s) in state_dict: "encoder.level1.conv.weight", "encoder.level1.bn.weight", "encoder.level1.bn.bias", "encoder.level1.bn.running_mean", "encoder.level1.bn.running_var", "encoder.level1.act.weight", "encoder.b1.conv.weight", "encoder.b1.bn.weight", "encoder.b1.bn.bias", "encoder.b1.bn.running_mean", "encoder.b1.bn.running_var", "encoder.b1.act.weight", "encoder.level2_0.c1.conv.weight", "encoder.level2_0.d1.conv.weight", "encoder.level2_0.d2.conv.weight", "encoder.level2_0.d4.conv.weight", "encoder.level2_0.d8.conv.weight", "encoder.level2_0.d16.conv.weight", "encoder.level2_0.bn.weight", "encoder.level2_0.bn.bias", "encoder.level2_0.bn.running_mean", "encoder.level2_0.bn.running_var", "encoder.level2_0.act.weight", "encoder.level2.0.c1.conv.weight", "encoder.level2.0.d1.conv.weight", "encoder.level2.0.d2.conv.weight", "encoder.level2.0.d4.conv.weight", "encoder.level2.0.d8.conv.weight", "encoder.level2.0.d16.conv.weight", "encoder.level2.0.bn.bn.weight", "encoder.level2.0.bn.bn.bias", "encoder.level2.0.bn.bn.running_mean", "encoder.level2.0.bn.bn.running_var", "encoder.level2.0.bn.act.weight", "encoder.level2.1.c1.conv.weight", "encoder.level2.1.d1.conv.weight", "encoder.level2.1.d2.conv.weight", "encoder.level2.1.d4.conv.weight", "encoder.level2.1.d8.conv.weight", "encoder.level2.1.d16.conv.weight", "encoder.level2.1.bn.bn.weight", "encoder.level2.1.bn.bn.bias", "encoder.level2.1.bn.bn.running_mean", "encoder.level2.1.bn.bn.running_var", "encoder.level2.1.bn.act.weight", "encoder.b2.conv.weight", "encoder.b2.bn.weight", "encoder.b2.bn.bias", "encoder.b2.bn.running_mean", "encoder.b2.bn.running_var", "encoder.b2.act.weight", "encoder.level3_0.c1.conv.weight", "encoder.level3_0.d1.conv.weight", "encoder.level3_0.d2.conv.weight", "encoder.level3_0.d4.conv.weight", "encoder.level3_0.d8.conv.weight", "encoder.level3_0.d16.conv.weight", "encoder.level3_0.bn.weight", "encoder.level3_0.bn.bias", "encoder.level3_0.bn.running_mean", 
"encoder.level3_0.bn.running_var", "encoder.level3_0.act.weight", "encoder.level3.0.c1.conv.weight", "encoder.level3.0.d1.conv.weight", "encoder.level3.0.d2.conv.weight", "encoder.level3.0.d4.conv.weight", "encoder.level3.0.d8.conv.weight", "encoder.level3.0.d16.conv.weight", "encoder.level3.0.bn.bn.weight", "encoder.level3.0.bn.bn.bias", "encoder.level3.0.bn.bn.running_mean", "encoder.level3.0.bn.bn.running_var", "encoder.level3.0.bn.act.weight", "encoder.level3.1.c1.conv.weight", "encoder.level3.1.d1.conv.weight", "encoder.level3.1.d2.conv.weight", "encoder.level3.1.d4.conv.weight", "encoder.level3.1.d8.conv.weight", "encoder.level3.1.d16.conv.weight", "encoder.level3.1.bn.bn.weight", "encoder.level3.1.bn.bn.bias", "encoder.level3.1.bn.bn.running_mean", "encoder.level3.1.bn.bn.running_var", "encoder.level3.1.bn.act.weight", "encoder.level3.2.c1.conv.weight", "encoder.level3.2.d1.conv.weight", "encoder.level3.2.d2.conv.weight", "encoder.level3.2.d4.conv.weight", "encoder.level3.2.d8.conv.weight", "encoder.level3.2.d16.conv.weight", "encoder.level3.2.bn.bn.weight", "encoder.level3.2.bn.bn.bias", "encoder.level3.2.bn.bn.running_mean", "encoder.level3.2.bn.bn.running_var", "encoder.level3.2.bn.act.weight", "encoder.b3.conv.weight", "encoder.b3.bn.weight", "encoder.b3.bn.bias", "encoder.b3.bn.running_mean", "encoder.b3.bn.running_var", "encoder.b3.act.weight", "encoder.sa.gamma", "encoder.sa.query_conv.weight", "encoder.sa.query_conv.bias", "encoder.sa.key_conv.weight", "encoder.sa.key_conv.bias", "encoder.sa.value_conv.weight", "encoder.sa.value_conv.bias", "encoder.sc.gamma", "encoder.conv_sa.conv.weight", "encoder.conv_sa.bn.weight", "encoder.conv_sa.bn.bias", "encoder.conv_sa.bn.running_mean", "encoder.conv_sa.bn.running_var", "encoder.conv_sa.act.weight", "encoder.conv_sc.conv.weight", "encoder.conv_sc.bn.weight", "encoder.conv_sc.bn.bias", "encoder.conv_sc.bn.running_mean", "encoder.conv_sc.bn.running_var", "encoder.conv_sc.act.weight", 
"encoder.classifier.conv.weight", "encoder.classifier.bn.weight", "encoder.classifier.bn.bias", "encoder.classifier.bn.running_mean", "encoder.classifier.bn.running_var", "encoder.classifier.act.weight", "up_1_1.deconv.weight", "up_1_1.bn.weight", "up_1_1.bn.bias", "up_1_1.bn.running_mean", "up_1_1.bn.running_var", "up_1_1.act.weight", "up_2_1.deconv.weight", "up_2_1.bn.weight", "up_2_1.bn.bias", "up_2_1.bn.running_mean", "up_2_1.bn.running_var", "up_2_1.act.weight", "up_1_2.deconv.weight", "up_1_2.bn.weight", "up_1_2.bn.bias", "up_1_2.bn.running_mean", "up_1_2.bn.running_var", "up_1_2.act.weight", "up_2_2.deconv.weight", "up_2_2.bn.weight", "up_2_2.bn.bias", "up_2_2.bn.running_mean", "up_2_2.bn.running_var", "up_2_2.act.weight", "classifier_1.deconv.weight", "classifier_1.bn.weight", "classifier_1.bn.bias", "classifier_1.bn.running_mean", "classifier_1.bn.running_var", "classifier_1.act.weight", "classifier_2.deconv.weight", "classifier_2.bn.weight", "classifier_2.bn.bias", "classifier_2.bn.running_mean", "classifier_2.bn.running_var", "classifier_2.act.weight". 
Unexpected key(s) in state_dict: "module.encoder.level1.conv.weight", "module.encoder.level1.bn.weight", "module.encoder.level1.bn.bias", "module.encoder.level1.bn.running_mean", "module.encoder.level1.bn.running_var", "module.encoder.level1.bn.num_batches_tracked", "module.encoder.level1.act.weight", "module.encoder.b1.conv.weight", "module.encoder.b1.bn.weight", "module.encoder.b1.bn.bias", "module.encoder.b1.bn.running_mean", "module.encoder.b1.bn.running_var", "module.encoder.b1.bn.num_batches_tracked", "module.encoder.b1.act.weight", "module.encoder.level2_0.c1.conv.weight", "module.encoder.level2_0.d1.conv.weight", "module.encoder.level2_0.d2.conv.weight", "module.encoder.level2_0.d4.conv.weight", "module.encoder.level2_0.d8.conv.weight", "module.encoder.level2_0.d16.conv.weight", "module.encoder.level2_0.bn.weight", "module.encoder.level2_0.bn.bias", "module.encoder.level2_0.bn.running_mean", "module.encoder.level2_0.bn.running_var", "module.encoder.level2_0.bn.num_batches_tracked", "module.encoder.level2_0.act.weight", "module.encoder.level2.0.c1.conv.weight", "module.encoder.level2.0.d1.conv.weight", "module.encoder.level2.0.d2.conv.weight", "module.encoder.level2.0.d4.conv.weight", "module.encoder.level2.0.d8.conv.weight", "module.encoder.level2.0.d16.conv.weight", "module.encoder.level2.0.bn.bn.weight", "module.encoder.level2.0.bn.bn.bias", "module.encoder.level2.0.bn.bn.running_mean", "module.encoder.level2.0.bn.bn.running_var", "module.encoder.level2.0.bn.bn.num_batches_tracked", "module.encoder.level2.0.bn.act.weight", "module.encoder.level2.1.c1.conv.weight", "module.encoder.level2.1.d1.conv.weight", "module.encoder.level2.1.d2.conv.weight", "module.encoder.level2.1.d4.conv.weight", "module.encoder.level2.1.d8.conv.weight", "module.encoder.level2.1.d16.conv.weight", "module.encoder.level2.1.bn.bn.weight", "module.encoder.level2.1.bn.bn.bias", "module.encoder.level2.1.bn.bn.running_mean", "module.encoder.level2.1.bn.bn.running_var", 
"module.encoder.level2.1.bn.bn.num_batches_tracked", "module.encoder.level2.1.bn.act.weight", "module.encoder.b2.conv.weight", "module.encoder.b2.bn.weight", "module.encoder.b2.bn.bias", "module.encoder.b2.bn.running_mean", "module.encoder.b2.bn.running_var", "module.encoder.b2.bn.num_batches_tracked", "module.encoder.b2.act.weight", "module.encoder.level3_0.c1.conv.weight", "module.encoder.level3_0.d1.conv.weight", "module.encoder.level3_0.d2.conv.weight", "module.encoder.level3_0.d4.conv.weight", "module.encoder.level3_0.d8.conv.weight", "module.encoder.level3_0.d16.conv.weight", "module.encoder.level3_0.bn.weight", "module.encoder.level3_0.bn.bias", "module.encoder.level3_0.bn.running_mean", "module.encoder.level3_0.bn.running_var", "module.encoder.level3_0.bn.num_batches_tracked", "module.encoder.level3_0.act.weight", "module.encoder.level3.0.c1.conv.weight", "module.encoder.level3.0.d1.conv.weight", "module.encoder.level3.0.d2.conv.weight", "module.encoder.level3.0.d4.conv.weight", "module.encoder.level3.0.d8.conv.weight", "module.encoder.level3.0.d16.conv.weight", "module.encoder.level3.0.bn.bn.weight", "module.encoder.level3.0.bn.bn.bias", "module.encoder.level3.0.bn.bn.running_mean", "module.encoder.level3.0.bn.bn.running_var", "module.encoder.level3.0.bn.bn.num_batches_tracked", "module.encoder.level3.0.bn.act.weight", "module.encoder.level3.1.c1.conv.weight", "module.encoder.level3.1.d1.conv.weight", "module.encoder.level3.1.d2.conv.weight", "module.encoder.level3.1.d4.conv.weight", "module.encoder.level3.1.d8.conv.weight", "module.encoder.level3.1.d16.conv.weight", "module.encoder.level3.1.bn.bn.weight", "module.encoder.level3.1.bn.bn.bias", "module.encoder.level3.1.bn.bn.running_mean", "module.encoder.level3.1.bn.bn.running_var", "module.encoder.level3.1.bn.bn.num_batches_tracked", "module.encoder.level3.1.bn.act.weight", "module.encoder.level3.2.c1.conv.weight", "module.encoder.level3.2.d1.conv.weight", "module.encoder.level3.2.d2.conv.weight", 
"module.encoder.level3.2.d4.conv.weight", "module.encoder.level3.2.d8.conv.weight", "module.encoder.level3.2.d16.conv.weight", "module.encoder.level3.2.bn.bn.weight", "module.encoder.level3.2.bn.bn.bias", "module.encoder.level3.2.bn.bn.running_mean", "module.encoder.level3.2.bn.bn.running_var", "module.encoder.level3.2.bn.bn.num_batches_tracked", "module.encoder.level3.2.bn.act.weight", "module.encoder.b3.conv.weight", "module.encoder.b3.bn.weight", "module.encoder.b3.bn.bias", "module.encoder.b3.bn.running_mean", "module.encoder.b3.bn.running_var", "module.encoder.b3.bn.num_batches_tracked", "module.encoder.b3.act.weight", "module.encoder.sa.gamma", "module.encoder.sa.query_conv.weight", "module.encoder.sa.query_conv.bias", "module.encoder.sa.key_conv.weight", "module.encoder.sa.key_conv.bias", "module.encoder.sa.value_conv.weight", "module.encoder.sa.value_conv.bias", "module.encoder.sc.gamma", "module.encoder.conv_sa.conv.weight", "module.encoder.conv_sa.bn.weight", "module.encoder.conv_sa.bn.bias", "module.encoder.conv_sa.bn.running_mean", "module.encoder.conv_sa.bn.running_var", "module.encoder.conv_sa.bn.num_batches_tracked", "module.encoder.conv_sa.act.weight", "module.encoder.conv_sc.conv.weight", "module.encoder.conv_sc.bn.weight", "module.encoder.conv_sc.bn.bias", "module.encoder.conv_sc.bn.running_mean", "module.encoder.conv_sc.bn.running_var", "module.encoder.conv_sc.bn.num_batches_tracked", "module.encoder.conv_sc.act.weight", "module.encoder.classifier.conv.weight", "module.encoder.classifier.bn.weight", "module.encoder.classifier.bn.bias", "module.encoder.classifier.bn.running_mean", "module.encoder.classifier.bn.running_var", "module.encoder.classifier.bn.num_batches_tracked", "module.encoder.classifier.act.weight", "module.up_1_1.deconv.weight", "module.up_1_1.bn.weight", "module.up_1_1.bn.bias", "module.up_1_1.bn.running_mean", "module.up_1_1.bn.running_var", "module.up_1_1.bn.num_batches_tracked", "module.up_1_1.act.weight", 
"module.up_2_1.deconv.weight", "module.up_2_1.bn.weight", "module.up_2_1.bn.bias", "module.up_2_1.bn.running_mean", "module.up_2_1.bn.running_var", "module.up_2_1.bn.num_batches_tracked", "module.up_2_1.act.weight", "module.up_1_2.deconv.weight", "module.up_1_2.bn.weight", "module.up_1_2.bn.bias", "module.up_1_2.bn.running_mean", "module.up_1_2.bn.running_var", "module.up_1_2.bn.num_batches_tracked", "module.up_1_2.act.weight", "module.up_2_2.deconv.weight", "module.up_2_2.bn.weight", "module.up_2_2.bn.bias", "module.up_2_2.bn.running_mean", "module.up_2_2.bn.running_var", "module.up_2_2.bn.num_batches_tracked", "module.up_2_2.act.weight", "module.classifier_1.deconv.weight", "module.classifier_1.bn.weight", "module.classifier_1.bn.bias", "module.classifier_1.bn.running_mean", "module.classifier_1.bn.running_var", "module.classifier_1.bn.num_batches_tracked", "module.classifier_1.act.weight", "module.classifier_2.deconv.weight", "module.classifier_2.bn.weight", "module.classifier_2.bn.bias", "module.classifier_2.bn.running_mean", "module.classifier_2.bn.running_var", "module.classifier_2.bn.num_batches_tracked", "module.classifier_2.act.weight".

Could somebody help me?

harrylal commented 7 months ago

I ran into a similar issue while exporting the pretrained model to ONNX. To resolve it, wrap the loading code around line 127 of export.py as follows:

    model = torch.nn.DataParallel(model)  # wrap so the keys match the "module."-prefixed checkpoint
    model.load_state_dict(torch.load(weights))
    model = model.module  # unwrap, because ONNX export doesn't work well with DataParallel

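Why the wrap works: a model saved while wrapped in `torch.nn.DataParallel` stores every parameter under a `module.` prefix, which is the main difference between the "Missing" and "Unexpected" key lists in the error above. The sketch below reproduces this with a tiny hypothetical `TinyNet` standing in for TwinLiteNet's `Net`, and an in-memory buffer in place of the real weights file; it illustrates the mechanism and is not the repository's code.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for TwinLiteNet's Net; any nn.Module behaves the same way.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(4)

    def forward(self, x):
        return self.bn(self.conv(x))


# Simulate a checkpoint saved from a DataParallel-wrapped model during training.
buf = io.BytesIO()
torch.save(nn.DataParallel(TinyNet()).state_dict(), buf)
buf.seek(0)
checkpoint = torch.load(buf)
print(sorted(checkpoint))  # every key starts with "module."

# Wrap, load, unwrap: the same pattern as the reply above.
model = nn.DataParallel(TinyNet())
model.load_state_dict(checkpoint)  # "module.conv.weight" etc. now match
model = model.module               # plain TinyNet again, suitable for ONNX export
```

Loading the same `checkpoint` directly into an unwrapped `TinyNet()` raises the same kind of "Missing key(s) / Unexpected key(s)" error shown in the issue.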
byungjinku commented 7 months ago

In reply to harrylal's comment of Feb 24, 2024:

Thank you for your response, but what do you mean by "wrap"? Do I have to change the existing code or add new code?

The original code is:

    model = Net()
    model = model.cuda()
    model.load_state_dict(torch.load(weights))
    device = select_device(device)

harrylal commented 7 months ago

After the changes, the code would be:

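    model = Net()
    model = model.cuda()
    model = torch.nn.DataParallel(model)        # wrap so the model's keys get the "module." prefix
    model.load_state_dict(torch.load(weights))  # checkpoint keys now match
    model = model.module                        # unwrap before exporting to ONNX
    device = select_device(device)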
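An alternative some users prefer (not the approach proposed in this thread) is to skip the `DataParallel` wrapper entirely and strip the `module.` prefix from the checkpoint keys before loading. This is a minimal sketch: `TinyNet` and `strip_prefix` are hypothetical names, and passing `map_location="cpu"` to `torch.load` is an assumption that lets the weights load on a machine without a GPU.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for TwinLiteNet's Net.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(4)

    def forward(self, x):
        return self.bn(self.conv(x))


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


# Simulate a DataParallel checkpoint (in export.py this would be torch.load(weights)).
buf = io.BytesIO()
torch.save(nn.DataParallel(TinyNet()).state_dict(), buf)
buf.seek(0)
checkpoint = torch.load(buf, map_location="cpu")  # load tensors onto the CPU

model = TinyNet()
model.load_state_dict(strip_prefix(checkpoint))  # no DataParallel wrapper needed
model.eval()  # put BatchNorm in inference mode before exporting
```

PyTorch also provides `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")`, which removes the prefix in place instead of returning a copy.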
byungjinku commented 7 months ago

Thank you very much!