soeaver / Parsing-R-CNN

Parsing R-CNN for Instance-Level Human Analysis
MIT License

incompatible weights #21

Open · melih-unsal opened this issue 3 years ago

melih-unsal commented 3 years ago

When I run the test code, I get the following errors because the checkpoint's weight shapes do not match the model:

```
size mismatch for RPN.anchor_generator.cell_anchors.0: copying a param with shape torch.Size([3, 4]) from checkpoint, the shape in current model is torch.Size([15, 4]).
size mismatch for RPN.head.conv.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([2048, 2048, 3, 3]).
size mismatch for RPN.head.conv.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([2048]).
size mismatch for RPN.head.cls_logits.weight: copying a param with shape torch.Size([3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([15, 2048, 1, 1]).
size mismatch for RPN.head.cls_logits.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([15]).
size mismatch for RPN.head.bbox_pred.weight: copying a param with shape torch.Size([12, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([60, 2048, 1, 1]).
size mismatch for RPN.head.bbox_pred.bias: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([60]).
```
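
All of the mismatched shapes follow from two numbers: the RPN head's input channel count and the number of anchors per location. The sketch below reproduces every shape in the error from those two values. It assumes an RPN head with a shared 3x3 conv, one objectness logit per anchor, and four box-regression deltas per anchor. Every shape in the message is consistent with that layout, but `rpn_head_shapes` is a hypothetical helper and is not part of Parsing-R-CNN. Read this way, the checkpoint's head has 256 channels and 3 anchors per location, which is typical of an FPN model. The freshly built model has 2048 channels and 15 anchors per location, which suggests it was built from a different config than the one used for training.

```python
# Hypothetical helper (not part of Parsing-R-CNN): expected parameter shapes of
# an RPN head with a shared 3x3 conv, one objectness logit per anchor and four
# box-regression deltas per anchor, as the shapes in the error message imply.
def rpn_head_shapes(in_channels, num_anchors):
    c, a = in_channels, num_anchors
    return {
        "RPN.anchor_generator.cell_anchors.0": (a, 4),  # one 4-value box per anchor
        "RPN.head.conv.weight": (c, c, 3, 3),           # shared 3x3 conv
        "RPN.head.conv.bias": (c,),
        "RPN.head.cls_logits.weight": (a, c, 1, 1),     # one logit per anchor
        "RPN.head.cls_logits.bias": (a,),
        "RPN.head.bbox_pred.weight": (4 * a, c, 1, 1),  # 4 deltas per anchor
        "RPN.head.bbox_pred.bias": (4 * a,),
    }


# Channel and anchor counts read off the error message
checkpoint = rpn_head_shapes(in_channels=256, num_anchors=3)
current = rpn_head_shapes(in_channels=2048, num_anchors=15)

for name, ckpt_shape in checkpoint.items():
    print(f"{name}: checkpoint {ckpt_shape}, model {current[name]}")
```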

pangfeng1985 commented 3 years ago

@melih1996 have you solved this issue?

zimratK commented 3 years ago

Hi! I solved the issue by calling merge_cfg_from_file(path_to_config) before initializing Generalized_RCNN. This way the model is built from the same config the checkpoint was trained with, not from the defaults:

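```python
import torch

from rcnn.core.config import cfg, merge_cfg_from_file
from rcnn.modeling.model_builder import Generalized_RCNN
from utils.checkpointer import get_weights, load_weights
from utils.net import convert_bn2affine_model

# Merge the config the checkpoint was trained with BEFORE building the model;
# otherwise the model is built from the default config and layer shapes
# (e.g. the RPN head) will not match the checkpoint.
path_to_config = "some/path/to/config"
merge_cfg_from_file(path_to_config)
model = Generalized_RCNN(is_train=False)

# Load the trained weights
path_to_model = "some/path/to/model_latest.pth"
cfg.TEST.WEIGHTS = get_weights(cfg.CKPT, path_to_model)
load_weights(model, cfg.TEST.WEIGHTS)

# If the config freezes BatchNorm, convert BN layers to affine layers
if cfg.MODEL.BATCH_NORM == 'freeze':
    model = convert_bn2affine_model(model)

# Switch to inference mode and move the model to the configured device
model.eval()
model.to(torch.device(cfg.DEVICE))
```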
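
Before calling `load_weights`, you can check that the merged config really matches the checkpoint by comparing parameter shapes. Here is a minimal sketch. `shape_mismatches` is a hypothetical helper and not part of the repository. It works on plain `{name: shape}` dicts, so it runs without PyTorch. With PyTorch, the model side could be built as `{k: tuple(v.shape) for k, v in model.state_dict().items()}`, and the checkpoint side from the dict returned by `torch.load(path_to_model, map_location="cpu")`. How the state dict is nested inside Parsing-R-CNN checkpoints is an assumption, so inspect the loaded object first.

```python
# Hypothetical helper (not part of Parsing-R-CNN): compare model and checkpoint
# parameter shapes given as {param_name: shape} mappings.
def shape_mismatches(model_shapes, ckpt_shapes):
    """Return (mismatched, missing, unexpected).

    mismatched: {name: (checkpoint_shape, model_shape)} for shared names whose shapes differ
    missing:    sorted names the model expects but the checkpoint lacks
    unexpected: sorted names in the checkpoint that the model does not have
    """
    shared = model_shapes.keys() & ckpt_shapes.keys()
    mismatched = {
        name: (tuple(ckpt_shapes[name]), tuple(model_shapes[name]))
        for name in shared
        if tuple(ckpt_shapes[name]) != tuple(model_shapes[name])
    }
    missing = sorted(model_shapes.keys() - ckpt_shapes.keys())
    unexpected = sorted(ckpt_shapes.keys() - model_shapes.keys())
    return mismatched, missing, unexpected


# Example using two shapes from the error message in this issue
model_shapes = {"RPN.head.conv.bias": (2048,), "RPN.head.cls_logits.bias": (15,)}
ckpt_shapes = {"RPN.head.conv.bias": (256,), "RPN.head.cls_logits.bias": (3,)}

mismatched, missing, unexpected = shape_mismatches(model_shapes, ckpt_shapes)
for name, (ckpt, model) in sorted(mismatched.items()):
    print(f"{name}: checkpoint {ckpt} vs model {model}")
print("missing:", missing, "unexpected:", unexpected)
```

If all three results are empty, the model built from the merged config has exactly the parameter names and shapes stored in the checkpoint.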