open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io
Apache License 2.0

Unable to get TTA to work with inference_detector on single demo image #11742

Open lewj85 opened 1 month ago


I'm trying to test TTA (test-time augmentation) on the demo image, but the call to inference_detector fails because the model's data preprocessor expects a batch (an NCHW tensor). Is there something wrong with the config file?
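For reference, the assertion in the error below separates a batched tensor of shape (N, C, H, W) from a single image of shape (C, H, W). The following is a simplified, hypothetical model of that check: shape tuples stand in for tensors, `accepts_input` and `add_batch_dim` are made-up names, and the rule that each list element is one CHW image is an assumption. Only the "NCHW tensor or a list of tensor" wording comes from the error message.

```python
from typing import List, Tuple, Union

Shape = Tuple[int, ...]


def accepts_input(inputs: Union[Shape, List[Shape]]) -> bool:
    """Hypothetical stand-in for the preprocessor's input check.

    A tuple represents the shape of one tensor; a list represents a list of
    tensors. Per the error message, a single tensor must be 4-D (NCHW).
    For the list case we assume each element is one 3-D (CHW) image.
    """
    if isinstance(inputs, tuple):
        return len(inputs) == 4
    return all(len(shape) == 3 for shape in inputs)


def add_batch_dim(shape: Shape) -> Shape:
    """Turn a CHW shape into an NCHW shape with N = 1."""
    return (1,) + shape


chw = (3, 960, 960)  # the shape reported in the error
print(accepts_input(chw))                    # False: single CHW tensor
print(accepts_input(add_batch_dim(chw)))     # True: NCHW with N = 1
print(accepts_input([chw, (3, 640, 640)]))   # True: list of CHW tensors
```

In this model, a bare (3, 960, 960) tensor is rejected, while the same image with a leading batch dimension, or wrapped in a list, is accepted.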

Code

from mmdet.apis import inference_detector, init_detector
from mmengine.config import Config, ConfigDict

config_file = './checkpoints/rtmdet_tiny_8xb32-300e_coco.py'
checkpoint_file = './checkpoints/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'

config = Config.fromfile(config_file)
config.model = ConfigDict(**config.tta_model, module=config.model)
config.test_dataloader.dataset.pipeline = config.tta_pipeline

# Fails because the TTA-wrapped config.model has no ``backbone``; see https://github.com/open-mmlab/mmdetection/issues/10355
#model = init_detector(config, checkpoint_file, device='cuda:0')

# Solution taken from here https://github.com/open-mmlab/mmdetection/blob/main/demo/large_image_demo.py#L105
#if 'init_cfg' in config.model.backbone:
#    config.model.backbone.init_cfg = None
assert 'tta_model' in config, 'Cannot find ``tta_model`` in config.' \
                                " Can't use tta !"
assert 'tta_pipeline' in config, 'Cannot find ``tta_pipeline`` ' \
                                    "in config. Can't use tta !"
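# NOTE: config.model was already wrapped right after Config.fromfile, so this
# line wraps it a second time (a TTA wrapper inside a TTA wrapper).
config.model = ConfigDict(**config.tta_model, module=config.model)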
test_data_cfg = config.test_dataloader.dataset
while 'dataset' in test_data_cfg:
    test_data_cfg = test_data_cfg['dataset']

test_data_cfg.pipeline = config.tta_pipeline

# TODO: TTA mode will error if cfg_options is not set.
#  This is an mmdet issue and needs to be fixed later.
# build the model from a config file and a checkpoint file
model = init_detector(
    config, checkpoint_file, device='cuda:0', cfg_options={})

img = './demo/demo.jpg'
result = inference_detector(model, img)

Error

AssertionError: The input of ImgDataPreprocessor should be a NCHW tensor or a list of tensor, but got a tensor with shape: torch.Size([3, 960, 960])
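The snippet runs `config.model = ConfigDict(**config.tta_model, module=config.model)` twice (once right after Config.fromfile, once after the asserts), so config.model ends up as a TTA wrapper nested inside another TTA wrapper. That may be what triggers the assertion. Assuming the TTA wrapper splits its input into one entry per augmentation before calling the model it wraps, the outer wrapper would split once and the inner one would split again, which would leave the detector with a single CHW tensor like the [3, 960, 960] one in the error. The sketch below is a hypothetical helper (`apply_tta` and `wrapper_depth` are made-up names) that follows the snippet's TTA setup with plain dicts instead of mmengine's Config/ConfigDict, but wraps the model only once. The config contents are placeholders, not the real RTMDet config.

```python
import copy
from typing import Any, Dict

Cfg = Dict[str, Any]


def apply_tta(cfg: Cfg) -> Cfg:
    """Hypothetical helper: return a copy of cfg set up for TTA.

    Follows the snippet's steps (wrap the model, then give the innermost
    test dataset the TTA pipeline) using plain dicts.
    """
    assert 'tta_model' in cfg, "Cannot find ``tta_model`` in config."
    assert 'tta_pipeline' in cfg, "Cannot find ``tta_pipeline`` in config."
    cfg = copy.deepcopy(cfg)

    # Wrap only if the model is not already the TTA wrapper, so calling
    # this twice does not nest one wrapper inside another.
    if cfg['model'].get('type') != cfg['tta_model'].get('type'):
        cfg['model'] = dict(**cfg['tta_model'], module=cfg['model'])

    # Descend through dataset wrappers to the innermost dataset config.
    data_cfg = cfg['test_dataloader']['dataset']
    while 'dataset' in data_cfg:
        data_cfg = data_cfg['dataset']
    data_cfg['pipeline'] = cfg['tta_pipeline']
    return cfg


def wrapper_depth(model_cfg: Cfg) -> int:
    """Count how many wrappers sit around the innermost model config."""
    depth = 0
    while 'module' in model_cfg:
        model_cfg = model_cfg['module']
        depth += 1
    return depth


# Placeholder config: only the keys the helper touches.
base = {
    'model': {'type': 'RTMDet'},
    'tta_model': {'type': 'DetTTAModel'},
    'tta_pipeline': [{'type': 'LoadImageFromFile'}],
    'test_dataloader': {'dataset': {'type': 'CocoDataset', 'pipeline': []}},
}

# What the snippet effectively does: wrap, then wrap the result again.
twice = dict(**base['tta_model'],
             module=dict(**base['tta_model'], module=base['model']))
print(wrapper_depth(twice))                     # 2

once = apply_tta(base)
print(wrapper_depth(once['model']))             # 1
print(wrapper_depth(apply_tta(once)['model']))  # 1: no second wrap
```

In the original script, the equivalent change would be to keep only one of the two wrapping lines. Whether inference_detector then merges the augmented views the way a full TTA test run does is a separate question that this sketch does not settle.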