open-mmlab / mmcv

OpenMMLab Computer Vision Foundation
https://mmcv.readthedocs.io/en/latest/
Apache License 2.0

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 512, 512] instead #3012

Open dengrunqin opened 10 months ago

dengrunqin commented 10 months ago

Prerequisite

Environment

mmcv-full: 1.4.5
mmdet: 2.25.1

Reproduces the problem - code sample

I'm trying to use my own trained instance segmentation model to run inference on a real image dataset, but the following error occurs: RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 512, 512] instead
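
For context, F.conv2d expects a batched NCHW tensor, so a 3-D CHW image triggers exactly this error. A minimal sketch in plain PyTorch (outside the mmdet pipeline, with shapes chosen to match the report) that reproduces it and shows the missing batch dimension:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=7)  # weight shape [64, 3, 7, 7], like ResNet's conv1
img = torch.randn(3, 512, 512)          # CHW image, no batch dimension

# conv(img) raises: Expected 4-dimensional input for 4-dimensional weight
# [64, 3, 7, 7], but got 3-dimensional input of size [3, 512, 512]
out = conv(img.unsqueeze(0))            # add the batch dim -> [1, 3, 512, 512]
print(out.shape)                        # torch.Size([1, 64, 506, 506])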

Reproduces the problem - command or script

# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from argparse import ArgumentParser

from mmdet.apis import (async_inference_detector, inference_detector,
                        init_detector, show_result_pyplot)
import os
import tqdm


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        '--img',
        default='G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\data\segment2',
        help='Image file')
    parser.add_argument(
        '--config',
        default='G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\work_dirs\mask_rcnn_r50_fpn_1x_coco\mask_rcnn_r50_fpn_1x_coco.py',
        help='Config file')
    parser.add_argument(
        '--checkpoint',
        default='G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\work_dirs\mask_rcnn_r50_fpn_1x_coco\latest.pth',
        help='Checkpoint file')
    parser.add_argument(
        '--out-file',
        default='G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master/results_record\SZ_result',
        help='Path to output file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--palette',
        default='voc',
        choices=['coco', 'voc', 'citys', 'random'],
        help='Color palette used for visualization')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    args = parser.parse_args()
    return args

def main(args):
    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)

    for filename in tqdm.tqdm(os.listdir(args.img)):
        img = os.path.join(args.img, filename)
        result = inference_detector(model, img)
        out_file = os.path.join(args.out_file, filename)
        show_result_pyplot(
            model,
            img,
            result,
            palette=args.palette,
            score_thr=args.score_thr,
            out_file=out_file)


if __name__ == '__main__':
    args = parse_args()
    main(args)

Reproduces the problem - error message

load checkpoint from local path: G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\work_dirs\mask_rcnn_r50_fpn_1x_coco\latest.pth
The model and loaded state dict do not match exactly

unexpected key in source state_dict: backbone.attention1.fc.0.weight, backbone.attention1.fc.2.weight, backbone.attention2.fc.0.weight, backbone.attention2.fc.2.weight, backbone.attention3.fc.0.weight, backbone.attention3.fc.2.weight

0%| | 0/2788 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\model_output.py", line 135, in <module>
    main(args)
  File "G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\model_output.py", line 122, in main
    result = inference_detector(model, img)
  File "G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\mmdet\apis\inference.py", line 151, in inference_detector
    results = model(return_loss=False, rescale=True, **data)
  File "C:\Users\99745.conda\envs\mmdet\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "g:\dengrq\mmcv-1.4.5\mmcv\runner\fp16_utils.py", line 109, in new_func
    return old_func(*args, **kwargs)
  File "G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\mmdet\models\detectors\base.py", line 174, in forward
    return self.forward_test(img, img_metas, **kwargs)
  File "G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\mmdet\models\detectors\base.py", line 147, in forward_test
    return self.simple_test(imgs[0], img_metas[0], **kwargs)
  File "G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\mmdet\models\detectors\two_stage.py", line 177, in simple_test
    x = self.extract_feat(img)
  File "G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\mmdet\models\detectors\two_stage.py", line 67, in extract_feat
    x = self.backbone(img)
  File "C:\Users\99745.conda\envs\mmdet\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "G:\dengrq\PyCharmProjects\pythonProject\MMDetection\mmdetection-master\mmdet\models\backbones\resnet.py", line 636, in forward
    x = self.conv1(x)
  File "C:\Users\99745.conda\envs\mmdet\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\99745.conda\envs\mmdet\lib\site-packages\torch\nn\modules\conv.py", line 399, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "C:\Users\99745.conda\envs\mmdet\lib\site-packages\torch\nn\modules\conv.py", line 395, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 512, 512] instead

Additional information

1. The expectation is to get instance segmentation results this way. 2. When using the config and checkpoint for object detection, the detection results come out correctly, but when running the instance segmentation model, the above error occurs.
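
One way to narrow this down (a hedged sketch, not from the thread; the paths are placeholders for the instance segmentation config and checkpoint) is to feed the backbone a manually batched tensor. If this runs cleanly, the weights and model definition are fine and the problem sits in the data pipeline before the backbone:

import torch
from mmdet.apis import init_detector

# placeholder paths; substitute the actual instance segmentation config/checkpoint
model = init_detector('path/to/mask_rcnn_config.py', 'path/to/latest.pth', device='cuda:0')

dummy = torch.randn(1, 3, 512, 512, device='cuda:0')  # batched NCHW input
with torch.no_grad():
    feats = model.extract_feat(dummy)  # TwoStageDetector.extract_feat, as seen in the traceback
print([f.shape for f in feats])        # FPN feature maps; no dimension error expected here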

zhouzaida commented 10 months ago

Hi, you can try to add a breakpoint or print the data at https://github.com/open-mmlab/mmdetection/blob/b95583270c57b3b0dc9c0523b2d1ebe46b755cca/mmdet/apis/inference.py#L142 to see its value

data = collate(datas, samples_per_gpu=len(imgs))
print(data)
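
If it helps, the thing to look for in that printout is the shape of the image tensor. A hedged sketch assuming the mmdet 2.x inference pipeline (the exact nesting of DataContainer / list wrappers can differ between versions):

data = collate(datas, samples_per_gpu=len(imgs))
# unwrap the DataContainer the same way inference_detector does a few lines below
imgs_batched = [img.data[0] for img in data['img']]
print(imgs_batched[0].shape)  # expect a batched shape like torch.Size([1, 3, 512, 512]);
                              # a 3-D shape such as [3, 512, 512] would explain the conv2d error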

By the way, did you use GPUs? Please also provide the command to reproduce the error.