open-mmlab / mmsegmentation

OpenMMLab Semantic Segmentation Toolbox and Benchmark.
https://mmsegmentation.readthedocs.io/en/main/
Apache License 2.0

4-Channel CMYK with single output segmentation map - Error during inference: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0 #3567

Open nittifra opened 6 months ago

nittifra commented 6 months ago

Hello,

I'm currently implementing a PSPNet (ResNet50) for segmentation. My goal is to process multimodal images with 4 channels, stored as .png files with CMYK encoding, and output a single-channel segmentation map.

During training, everything appears to work. With RGB images (3 channels), training converges at a Dice coefficient of 0.7. With the same configuration plus a fourth channel, it converges slightly higher, at a Dice coefficient of 0.76. Note that I'm not using a pretrained network.

I'm adding this fourth channel to incorporate another hybrid MRI modality for prostate tumor segmentation. Previously, when working solely with RGB images, the process worked smoothly, and I could compute metrics on the volume. To be clear, the R, G, and B channels represent different MRI modalities.

Now, when I attempt to load CMYK images (4 channels), I encounter a warning:

warnings.warn(
Loads checkpoint by local backend from path: /content/drive/MyDrive/results_hybrid_images/iter_18000.pth
The model and loaded state dict do not match exactly
size mismatch for backbone.stem.0.weight: copying a param with shape torch.Size([32, 3, 3, 3]) from checkpoint, the shape in the current model is torch.Size([32, 4, 3, 3]).
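
The warning above can be read as a shape comparison between the checkpoint and the freshly built model. The following is a minimal sketch of that comparison, not MMEngine's own loading logic: `find_shape_mismatches` is a hypothetical helper, and the two dictionaries are filled in by hand with the shapes quoted in the warning.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    ckpt_shapes: Dict[str, Shape], model_shapes: Dict[str, Shape]
) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, checkpoint_shape, model_shape) for shared keys whose shapes differ."""
    mismatches = []
    for name, ckpt_shape in ckpt_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches.append((name, tuple(ckpt_shape), tuple(model_shape)))
    return mismatches


# Shapes copied by hand from the warning message.
ckpt_shapes = {"backbone.stem.0.weight": (32, 3, 3, 3)}
model_shapes = {"backbone.stem.0.weight": (32, 4, 3, 3)}

for name, ckpt_shape, model_shape in find_shape_mismatches(ckpt_shapes, model_shapes):
    # For a 2D convolution weight, index 1 is the number of input channels.
    print(
        f"{name}: checkpoint has {ckpt_shape[1]} input channels, "
        f"model expects {model_shape[1]}"
    )
```

With PyTorch, the shape dictionaries could be built from real state dicts with something like `{k: tuple(v.shape) for k, v in state_dict.items()}`. Read this way, the warning suggests the checkpoint was saved from a model whose first convolution took 3 input channels.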

Then, in my inference script, I get this error:

result = inference_model(model, cymk_image)  # <------
pred_label = result.pred_sem_seg.data.squeeze()
pred_label = pred_label.cpu().numpy().astype(np.uint8)
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 0

To provide context, I'm calculating the real performance metrics on a test set. I preprocess the original images into hybrid images with 4 channels, maintaining the same field of view (FOV), resulting in images with a shape of (256,256,4). This automation is crucial as I'll be evaluated on a blind test set. Although I process single slices, I eventually reconstruct the volume.
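
Since the error message compares a size of 3 with a size of 4, it may help to confirm that the image, the backbone and the normalization statistics all agree on the channel count before calling the model. The sketch below is a plain-Python sanity check under that assumption. `channel_problems` is a hypothetical helper, not part of MMSegmentation, and the values are taken from the post and its configuration (images of shape (256,256,4), `in_channels=4`, four-element `mean` and `std`).

```python
from typing import List, Sequence


def channel_problems(
    image_shape: Sequence[int],
    in_channels: int,
    mean: Sequence[float],
    std: Sequence[float],
) -> List[str]:
    """List every disagreement between an HWC image and the configured channel counts."""
    # A 2D array is treated as a single-channel image.
    image_channels = image_shape[-1] if len(image_shape) == 3 else 1
    problems = []
    if image_channels != in_channels:
        problems.append(
            f"image has {image_channels} channels, backbone in_channels={in_channels}"
        )
    if len(mean) != image_channels:
        problems.append(f"mean has {len(mean)} entries, image has {image_channels} channels")
    if len(std) != image_channels:
        problems.append(f"std has {len(std)} entries, image has {image_channels} channels")
    return problems


# A 4-channel image against the 4-channel configuration.
print(channel_problems((256, 256, 4), 4, [0, 0, 0, 0], [1, 1, 1, 1]))
# The same configuration against an image that arrives with only 3 channels.
print(channel_problems((256, 256, 3), 4, [0, 0, 0, 0], [1, 1, 1, 1]))
```

In practice the first argument would be the shape of whatever array actually reaches the model. If that turns out to have 3 channels, the fourth channel is being dropped before inference.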

Below is the configuration I'm using, with "<------" highlighting the modifications from the original 3-channel configuration:

Thank you very much for any tips or insights you can provide. I've searched for similar issues, but they are mostly about training rather than inference models.


compile = False
crop_size = (
    256,
    256,
)
data_root = ...
dataset_type = 'prostateMRI'
default_hooks = dict(
    checkpoint=dict(by_epoch=False, interval=2000, type='CheckpointHook'),
    logger=dict(interval=1000, log_metric_by_epoch=False, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(draw=True, interval=500, type='SegVisualizationHook'))
default_scope = 'mmseg'
env_cfg = dict(
    cudnn_benchmark=True,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
img_ratios = [
    0.75,
    1.0,
    1.25,
]
log_level = 'INFO'
log_processor = dict(by_epoch=False)
model = dict(
    auxiliary_head=dict(
        align_corners=False,
        channels=256,
        concat_input=False,
        dropout_ratio=0.1,
        ignore_index=255,
        in_channels=1024,
        in_index=2,
        loss_decode=dict(
            loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False),
        norm_cfg=dict(requires_grad=True, type='BN'),
        num_classes=2,
        num_convs=1,
        type='FCNHead'),
    backbone=dict(
        contract_dilation=True,
        depth=50,
        dilations=(
            1,
            1,
            2,
            4,
        ),

        in_channels=4,  # <--------

        norm_cfg=dict(requires_grad=True, type='BN'),
        norm_eval=False,
        num_stages=4,
        out_indices=(
            0,
            1,
            2,
            3,
        ),
        strides=(
            1,
            2,
            1,
            1,
        ),
        style='pytorch',
        type='ResNetV1c'),
    data_preprocessor=dict(

        mean=[  # <----------
            0,
            0,
            0,
            0,
        ],

        pad_val=0,
        seg_pad_val=255,
        size=(
            256,
            256,
        ),

        std=[  # <----------
            1,
            1,
            1,
            1,
        ],

        type='SegDataPreProcessor'),
    decode_head=dict(
        align_corners=False,
        channels=512,
        dropout_ratio=0.1,
        ignore_index=255,
        in_channels=2048,
        in_index=3,
        loss_decode=[
            dict(loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=False),
        ],
        norm_cfg=dict(requires_grad=True, type='BN'),
        num_classes=2,
        pool_scales=(
            1,
            2,
            3,
            6,
        ),
        type='PSPHead'),
    pretrained=None,
    test_cfg=dict(mode='whole'),
    train_cfg=dict(),
    type='EncoderDecoder')
norm_cfg = dict(requires_grad=True, type='BN')
optim_wrapper = dict(
    clip_grad=dict(max_norm=1, norm_type=2),
    optimizer=dict(
        betas=(
            0.9,
            0.999,
        ), lr=0.0001, type='AdamW', weight_decay=0.001),
    type='AmpOptimWrapper')
optimizer = dict(
    betas=(
        0.9,
        0.999,
    ), lr=0.0001, type='AdamW', weight_decay=0.001)
param_scheduler = dict(
    begin=2000,
    by_epoch=False,
    end=10000,
    gamma=0.1,
    milestones=[
        6000,
        8000,
    ],
    type='MultiStepLR')
randomness = dict(seed=0)
resume = False
save_dir = '/content/drive/MyDrive/ProgettoEIM/results_hybrid_images'
test_cfg = dict(type='TestLoop')
test_dataloader = dict(
    batch_size=1,
    dataset=dict(
        data_prefix=dict(img_path='img_dir/test', seg_map_path='ann_dir/test'),
        data_root='/content/drive/MyDrive/ProgettoEIM/hybrid_images',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
            dict(type='PackSegInputs'),
        ],
        type='prostateMRI'),
    num_workers=1,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict(
    ignore_index=255, iou_metrics=[
        'mIoU',
        'mDice',
    ], type='IoUMetric')
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs'),
]
train_cfg = dict(max_iters=28385, type='IterBasedTrainLoop', val_interval=810)
train_dataloader = dict(
    batch_size=8,
    dataset=dict(
        data_prefix=dict(
            img_path='img_dir/train', seg_map_path='ann_dir/train'),
        data_root='/content/drive/MyDrive/ProgettoEIM/hybrid_images',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
            dict(degree=(
                -5.0,
                5.0,
            ), prob=0.5, type='RandomRotate'),
            dict(type='PackSegInputs'),
        ],
        type='prostateMRI'),
    num_workers=2,
    persistent_workers=True,
    sampler=dict(shuffle=True, type='InfiniteSampler'))
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(degree=(
        -5.0,
        5.0,
    ), prob=0.5, type='RandomRotate'),
    dict(type='PackSegInputs'),
]
val_cfg = dict(type='ValLoop')
val_dataloader = dict(
    batch_size=1,
    dataset=dict(
        data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
        data_root='/content/drive/MyDrive/ProgettoEIM/hybrid_images',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
            dict(type='PackSegInputs'),
        ],
        type='prostateMRI'),
    num_workers=1,
    persistent_workers=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(
    ignore_index=255, iou_metrics=[
        'mIoU',
        'mDice',
    ], type='IoUMetric')
visualizer = dict(
    classes=[
        'healty',
        'tumor',
    ],
    dataset_name='prostateMRI',
    name='visualizer',
    palette=[
        (
            0,
            0,
            0,
        ),
        (
            128,
            0,
            128,
        ),
    ],
    save_dir='/content/drive/MyDrive/ProgettoEIM/results_hybrid_images',
    type='SegLocalVisualizer',
    vis_backends=[
        dict(type='LocalVisBackend'),
        dict(type='TensorboardVisBackend'),
    ])
work_dir = ...

Zoulinx commented 5 months ago

It seems like your weight file only supports three-channel input.
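
One thing that might be worth checking, offered as an assumption rather than a confirmed diagnosis: `LoadImageFromFile` uses `color_type='color'` by default, which decodes images to three channels. If the 4-channel PNGs are losing their fourth channel at load time, a pipeline along these lines would keep the channels as stored. The fragment is a hypothetical variant of the `test_pipeline` in the posted configuration.

```python
# Hypothetical variant of the test pipeline from the posted configuration.
# Assumption: the 4-channel PNGs are currently decoded as 3-channel images.
test_pipeline = [
    # 'unchanged' keeps the channels as stored in the file instead of
    # converting to a 3-channel color image (the default 'color').
    dict(type='LoadImageFromFile', color_type='unchanged'),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs'),
]
```

If this is the cause, the training and validation pipelines would need the same change. The checkpoint would also have to come from a run whose backbone really had 4 input channels, since the one in the warning has a 3-channel stem.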