open-mmlab / mmsegmentation

OpenMMLab Semantic Segmentation Toolbox and Benchmark.
https://mmsegmentation.readthedocs.io/en/main/
Apache License 2.0
8.3k stars 2.62k forks source link

decode.acc_seg drops to 0.0 quickly and leads to 0s evaluation scores on custom dataset #1746

Open zlyin opened 2 years ago

zlyin commented 2 years ago

Hi there, I'm pretty new to this framework and came across a weird situation: my decode.acc_seg drops to 0 quickly and stays there. As a result, my evaluation results are all 0 except for the image Dice metric, as shown here:

2022-07-08 07:06:31,867 - mmseg - INFO - Exp name: baseline_config.py
2022-07-08 07:06:31,867 - mmseg - INFO - Iter [20000/20000]    lr: 1.014e-09, layer_0_lr: 1.001e-09, eta: 0:00:00, time: 0.656, data_time: 0.207, memory: 8618, decode.loss_ce: 0.2459, decode.loss_dice: 0.9530, decode.loss_hd: 0.2908, decode.acc_seg: 0.0000, loss: 1.4897
2022-07-08 07:11:06,708 - mmseg - INFO - per class results:
2022-07-08 07:11:06,709 - mmseg - INFO - 
+-------------+-------+-----+------+
|    Class    | iDice | Acc | Dice |
+-------------+-------+-----+------+
| large_bowel | 62.11 | 0.0 | 0.0  |
| small_bowel | 69.45 | 0.0 | 0.0  |
|   stomach   | 76.57 | 0.0 | 0.0  |
+-------------+-------+-----+------+
2022-07-08 07:11:06,709 - mmseg - INFO - Summary:
2022-07-08 07:11:06,710 - mmseg - INFO - 
+------+-------+--------+------+-------+
| aAcc | fwIoU | miDice | mAcc | mDice |
+------+-------+--------+------+-------+
| 0.0  |  0.0  | 69.38  | 0.0  |  0.0  |
+------+-------+--------+------+-------+

My config file is as follows:

# Number of target classes; forwarded to the decode head below.
num_classes = 3

# ---- model settings ----
# Normalisation-layer config (author's note: segmentation usually uses SyncBN).
# NOTE(review): nothing in this section consumes `norm_cfg` -- confirm whether
# the model is meant to receive it.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}

# Loss terms handed to the decode head as `loss_decode`, in this order.
loss = [
    # Cross-entropy term (`use_sigmoid` enabled), weight 10.
    {
        'type': 'CrossEntropyLoss',
        'loss_name': 'loss_ce',
        'use_sigmoid': True,
        'loss_weight': 10.0,
    },
    # Dice term.
    {
        'type': 'SMPDiceLoss',
        'loss_name': 'loss_dice',
        'mode': 'multilabel',
        'alpha': 0.5,
        'beta': 0.5,
        'loss_weight': 1.0,
    },
    # Hausdorff-distance term.
    {
        'type': 'HausdorffDist',
        'loss_name': 'loss_hd',
        'mode': 'multilabel',
        'loss_weight': 1.0,
    },
]

model = {
    # Author's note: use the SMP model.
    'type': 'SMPUnet',
    # Encoder part; `type` carries the encoder name.
    'backbone': {
        'type': 'timm-efficientnet-b2',
        'pretrained': 'imagenet',
        # depth = 5
    },
    # Decoder part.
    'decode_head': {
        'num_classes': num_classes,
        'align_corners': False,
        # Feed the losses defined above to the model.
        'loss_decode': loss,
    },
    'train_cfg': {},
    # Author's note: `mode` is either "whole" or "sliding".
    'test_cfg': {'mode': 'whole', 'multi_label': True},
}

# ---- dataset settings ----
dataset_type = 'CustomDataset'
data_root = 'data/tract/'
classes = ['large_bowel', 'small_bowel', 'stomach']
palette = [[0, 0, 0], [128, 128, 128], [255, 255, 255]]
# Author's note: images are cast to np.float32 automatically, but they MUST be
# 3-channel images.
img_norm_cfg = {'mean': [0, 0, 0], 'std': [1, 1, 1], 'to_rgb': True}

# Edge length used for both the resize target and the pad size below.
size = 384

# Albumentations-style augmentations. The `Albu` step that would consume this
# list is commented out in `train_pipeline`, so it is currently unused there.
albu_train_transforms = [
    # dict(type='RandomBrightnessContrast', p=0.5),
    {
        'type': 'GridDistortion',
        'num_steps': 5,
        'distort_limit': 0.1,
        'p': 0.2,
    },
    {
        'type': 'ElasticTransform',
        'alpha': 1,
        'sigma': 30,
        'alpha_affine': 30,
        'p': 0.3,
    },
    {
        'type': 'OpticalDistortion',
        'distort_limit': 0.05,
        'shift_limit': 0.05,
        'p': 0.3,
    },
    {
        'type': 'Blur',
        'blur_limit': 3,
        'p': 0.1,
    },
]

# Training-time pipeline, applied in list order.
train_pipeline = [
    {
        'type': 'LoadImageFromFile',
        'to_float32': True,
        'color_type': 'unchanged',
        'max_value': 'max',
        'force_3chan': False,
    },
    {'type': 'LoadAnnotations'},
    {
        'type': 'Resize',
        # Author's note: None or (w, h).
        'img_scale': (size, size),
        # ratio_range=(0.8, 1.2),   # scaling factor range, multiscale
        # Keep the aspect ratio; the author also noted ratio_range=(0.5, 2.0).
        'keep_ratio': True,
    },
    # prob=0.0: the author disables flipping.
    {'type': 'RandomFlip', 'prob': 0.0, 'direction': 'horizontal'},
    {
        'type': 'RandomRotate',
        'prob': 0.3,
        'degree': 3,
        'pad_val': 0,
        'seg_pad_val': 255,
    },
    # dict(type='Albu', transforms=albu_train_transforms),
    # dict(type="CLAHE", clip_limit=40, tile_grid_size=(8, 8)),
    {
        'type': 'PhotoMetricDistortion',
        'brightness_delta': 32,
        'contrast_range': (0.8, 1.2),
        'saturation_range': (0.8, 1.2),
        'hue_delta': 18,
    },
    {'type': 'Normalize', **img_norm_cfg},
    {
        'type': 'Pad',
        'size': (size, size),
        'pad_val': 0,
        'seg_pad_val': 255,
    },
    # Collect data as default.
    {'type': 'DefaultFormatBundle'},
    # Keys that will be sent to the segmentor.
    {'type': 'Collect', 'keys': ['img', 'gt_semantic_seg']},
]

# Evaluation/test-time pipeline, shared by the `val` and `test` datasets.
test_pipeline = [
    {
        'type': 'LoadImageFromFile',
        'to_float32': True,
        'color_type': 'unchanged',
        'max_value': 'max',
        'force_3chan': False,
    },
    {
        # Test-time augmentation wrapper.
        'type': 'MultiScaleFlipAug',
        'img_scale': (size, size),
        # Do not flip in TTA.
        'flip': False,
        'transforms': [
            {'type': 'Resize', 'keep_ratio': True},
            {'type': 'RandomFlip'},
            {'type': 'Normalize', **img_norm_cfg},
            {
                'type': 'Pad',
                'size': (size, size),
                'pad_val': 0,
                'seg_pad_val': 255,
            },
            {'type': 'ImageToTensor', 'keys': ['img']},
            {'type': 'Collect', 'keys': ['img']},
        ],
    },
]

data = {
    # Batch size per GPU.
    'samples_per_gpu': 24,
    'workers_per_gpu': 4,
    'train': {
        'type': dataset_type,
        'multi_label': True,
        'data_root': data_root,
        'img_dir': 'mmseg_train_25D3ChanS2/images',
        'ann_dir': 'mmseg_train_25D3ChanS2/labels',
        # Image / mask file formats.
        'img_suffix': '.png',
        'seg_map_suffix': '.png',
        'split': 'mmseg_train_25D3ChanS2/splits/fold_0.txt',
        'classes': classes,
        'palette': palette,
        'pipeline': train_pipeline,
    },
    'val': {
        'type': dataset_type,
        'multi_label': True,
        'data_root': data_root,
        'img_dir': 'mmseg_train_25D3ChanS2/images',
        'ann_dir': 'mmseg_train_25D3ChanS2/labels',
        # Image / mask file formats.
        'img_suffix': '.png',
        'seg_map_suffix': '.png',
        'split': 'mmseg_train_25D3ChanS2/splits/holdout_0.txt',
        'classes': classes,
        'palette': palette,
        'pipeline': test_pipeline,
    },
    # Author's note: not related to the training pipeline.
    'test': {
        'type': dataset_type,
        'multi_label': True,
        'data_root': data_root,
        'test_mode': True,
        'img_dir': 'test/images',
        'ann_dir': 'test/labels',
        'img_suffix': '.jpg',
        'seg_map_suffix': '.png',
        'classes': classes,
        'palette': palette,
        'pipeline': test_pipeline,
    },
}

# Logging: one text-logger hook, iteration-based, with an interval of 50.
log_config = {
    'interval': 50,
    'hooks': [
        {'type': 'CustomizedTextLoggerHook', 'by_epoch': False},
    ],
}

dist_params = {'backend': 'nccl'}
log_level = 'INFO'
# Pretrained model to load (none).
load_from = None
# Checkpoint to resume from (none).
resume_from = None
workflow = [('train', 1)]
# Author's note: use cudnn to accelerate fixed-size input training.
cudnn_benchmark = True

# Training length in thousands of iterations; multiplied by 1000 for
# `runner['max_iters']` below.
total_iters = 30

# ---- optimizer ----
optimizer = {
    'type': 'AdamW',
    'lr': 5e-4,
    'betas': (0.9, 0.999),
    'weight_decay': 5e-4,
}
optimizer_config = {'type': 'Fp16OptimizerHook', 'loss_scale': 'dynamic'}

# ---- learning policy ----
lr_config = {
    'policy': 'CosineAnnealing',
    'warmup': 'linear',
    'warmup_iters': 100,
    # Author's note: the warmup lr starts from lr * warmup_ratio.
    'warmup_ratio': 1e-6,
    'min_lr': 1e-9,
    'by_epoch': False,
}

# ---- runtime settings ----
find_unused_parameters = True
runner = {
    # Author's note: IterBasedRunner or EpochBasedRunner.
    'type': 'IterBasedRunner',
    'max_iters': total_iters * 1000,
}
checkpoint_config = {
    'by_epoch': False,
    'interval': 2000,
    'save_optimizer': False,
}
evaluation = {
    'by_epoch': False,
    'interval': 2000,
    'metric': ['imDice', 'mDice'],
    'pre_eval': True,
}
fp16 = {}

# Plain string: the original used an f-string with no placeholders.
work_dir = './work_dirs/tract/baseline'

Thanks for your help!

MeowZheng commented 2 years ago

I have a small question about your img_norm_cfg. If you don't mind, could you tell us why you use this img_norm_cfg?

img_norm_cfg = dict(mean=[0,0,0], std=[1,1,1], to_rgb=True) 
zlyin commented 2 years ago

If you don't mind, would you like to tell us why to use this img_norm_cfg

img_norm_cfg = dict(mean=[0,0,0], std=[1,1,1], to_rgb=True) 

Hi @MeowZheng, this is because I don't know the image mean & std that the backbone was trained on. Since I'm applying the model to a medical image dataset, I'm not sure using ImageNet stats would be beneficial here.