open-mmlab / mmsegmentation

OpenMMLab Semantic Segmentation Toolbox and Benchmark.
https://mmsegmentation.readthedocs.io/en/main/
Apache License 2.0

Fine-tuning or transfer learning a model trained on Cityscapes with a subset of classes to improve required class accuracy #1491

Closed waqarsqureshi closed 2 years ago

waqarsqureshi commented 2 years ago

I'd like to segment roads, curbs, line markings, and background on a dataset similar to Cityscapes or CamVid. I see OCRNet is the SOTA for the Cityscapes dataset, and I tried its highest-accuracy pretrained model, ocr_hr48_534x1024_160K. While the results are good for many images, for some the road prediction is not what it should be. The reason is that some roads are in rural areas, and since Cityscapes is trained on 19 classes there is a higher chance of a wrong prediction. Therefore, I decided to train on my own dataset, which is labeled with four classes and has a directory structure similar to Cityscapes. My question, which is not clear from the documentation, is: How can I transfer-learn on my dataset? Do I have to build a new model with a pretrained backbone and train on Cityscapes again with 4 classes plus my dataset, or can I simply fine-tune the pretrained Cityscapes model on my new dataset with fewer classes? Do I have to write my own custom dataset class, or can I do it all through config files?

MengzhangLI commented 2 years ago

I think it may be caused by dataset domain shift, by the different classes in the annotations, and by model performance.

Here are my personal suggestions:

Due to the lack of data, you could re-label the Cityscapes dataset so that it only keeps the 4 classes you want: just write a script to handle the annotations, e.g. the remaining 15 Cityscapes classes could all be set to background, as in the sketch below. Then train on this re-labeled Cityscapes and test on your dataset.
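
A minimal sketch of such a script, assuming the stock Cityscapes folder layout; the kept ids and the background id below are illustrative, not necessarily the exact classes you need:

```python
# Re-label Cityscapes annotations down to a few classes (sketch).
# KEEP maps original Cityscapes labelIds to new trainIds; every other
# labelId collapses into a single background trainId.
import glob

import numpy as np
from PIL import Image

KEEP = {7: 0, 8: 1}  # e.g. road -> 0, sidewalk -> 1 (illustrative choice)
BACKGROUND = 2       # all remaining classes become background

for path in glob.glob('cityscapes/gtFine/**/*_gtFine_labelIds.png',
                      recursive=True):
    ids = np.array(Image.open(path))
    out = np.full_like(ids, BACKGROUND)
    for src, dst in KEEP.items():
        out[ids == src] = dst
    Image.fromarray(out).save(path.replace('labelIds', 'labelTrainIds'))
```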

To sum up:

(1) You can write such a script and train a model on Cityscapes, then use that checkpoint for inference on your own dataset.

(2) I think fine-tuning the pretrained Cityscapes model directly is not suitable, because that checkpoint was trained on 19 classes and its parameters learned representations for a more complicated setting (19 vs. 4 classes), so its performance may be very bad on your 4-class dataset.

(3) I suggest you write your own custom dataset class: just follow the PR for the Vaihingen dataset, i.e. reorganize the dataset folder first, create the dataset class, and set the config. Then change the dataset config like here.

Best,

waqarsqureshi commented 2 years ago

update-1: I used the scripts and changed labels.py in cityscapesscripts to 3 classes to generate my custom Cityscapes *labelTrainIds.png files used by this library. The ids and trainIds:

Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
Label(  'sky'                  , 23 ,        2 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),

All other labels have their trainId set to 255.

update-2: In my custom dataset class file I changed the following

    CLASSES = ('road', 'sidewalk', 'sky')
    PALETTE = [[128, 64, 128], [244, 35, 232], [70, 130, 180]]

to define my dataset class.

update-3: Then I set up the config, since I'd like to use the SegFormer model with a 1024x1024 image size.

update-4: The evaluation of Cityscapes is done using cityscapesscripts' evalPixelLevelSemanticLabeling. That file imports csHelpers (which imports labels.py) and also imports labels from labels.py directly; I changed the import to labels_road in both places. Also, I had forgotten to register the custom dataset "CityScapesRoad" in the __init__.py file, which I did by importing cityscape_road.py from mmseg/datasets and adding the CityScapesRoad dataset (see the sketch below). I then used the command tools/dist_train.sh configs/segformer/segformer_mit-b0_8x1_1024x1024_40k_cityscapes_road.py 1 --deterministic for a 40K-iteration run to check the accuracy. P.S.: I only realized my bug after running training for 160K iterations; when it evaluated the validation set at 32K iterations, it printed pixel accuracy for all 10 classes. The accuracies for road, sidewalk, and sky (which was wrongly printed as building) were 94%, 83%, and 99% respectively.
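
For reference, the registration step is roughly the following (module and class names are the ones used in this thread, not something the library ships):

```python
# mmseg/datasets/__init__.py -- sketch of registering the custom dataset.
from .cityscape_road import CityScapesRoad  # names as used in this thread

__all__ = [
    # ... the existing dataset names ...
    'CityScapesRoad',
]
```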

waqarsqureshi commented 2 years ago

Is there any custom evaluation function in custom.py? @MengzhangLI says NO, and he further added that the metric evaluation function is called here: https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/apis/test.py#L219.

Actually, custom.py does have its own evaluation function: https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/custom.py#L277-L296, which uses from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics for the metric calculation.
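
For anyone curious, here is a toy illustration of that metric path (a sketch assuming the mmseg 0.x API; the exact return keys vary a little across versions):

```python
# Compute mIoU on fake data via the same helper custom.py uses.
import numpy as np
from mmseg.core import eval_metrics

pred = np.random.randint(0, 3, size=(64, 64))  # fake 3-class prediction
gt = pred.copy()
gt[:8, :] = 255                                # a band of ignored pixels

ret = eval_metrics([pred], [gt], num_classes=3, ignore_index=255,
                   metrics=['mIoU'])
print(ret)  # per-class IoU/Acc plus aAcc
```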

waqarsqureshi commented 2 years ago

Quick question:

The '../_base_/models/segformer_mit-b0_road.py' file has a variable model

norm_cfg = dict(type='SyncBN', requires_grad=True)  # assumed from the standard SegFormer base config, since it is referenced below
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='MixVisionTransformer',
        in_channels=3,
        embed_dims=32,
        num_stages=4,
        num_layers=[2, 2, 2, 2],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1),
    decode_head=dict(
        type='SegformerHead',
        in_channels=[32, 64, 160, 256],
        in_index=[0, 1, 2, 3],
        channels=256,
        dropout_ratio=0.1,
        num_classes=4,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

Is this overwritten by the model variable in segformer_mit-b0_8x1_1024x1024_160k_cityscapes_road.py?

model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint='pretrain/mit_b0.pth')),
    test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))

waqarsqureshi commented 2 years ago

The answer is NO. This is the hierarchical approach of the config system: the child config does not overwrite the whole variable, it only updates the keys it specifies. See the documentation on config files, and the demonstration below.
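
To make the merge semantics concrete, here is a tiny self-contained demo (file names and contents are made up for illustration) showing that a child config only updates the keys it names:

```python
# Demonstrate _base_ config merging in mmcv: sibling keys survive.
import os
import tempfile

from mmcv import Config

d = tempfile.mkdtemp()
with open(os.path.join(d, 'base.py'), 'w') as f:
    f.write("model = dict(type='EncoderDecoder', "
            "backbone=dict(type='MixVisionTransformer', embed_dims=32))\n")
with open(os.path.join(d, 'child.py'), 'w') as f:
    f.write("_base_ = './base.py'\n"
            "model = dict(backbone=dict(init_cfg=dict("
            "type='Pretrained', checkpoint='pretrain/mit_b0.pth')))\n")

cfg = Config.fromfile(os.path.join(d, 'child.py'))
print(cfg.model.type)                 # EncoderDecoder (inherited, not lost)
print(cfg.model.backbone.embed_dims)  # 32 (kept, not overwritten)
print(cfg.model.backbone.init_cfg)    # merged in from child.py
```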

namKolorfuL commented 2 years ago

Hi @waqarsqureshi @MengzhangLI, I really appreciate your effort on this. Following your footsteps, I got stuck with the loss and IoU of every other class being NaN and only one class being shown. I did the following:

@DATASETS.register_module()
class Cityscapes5Dataset(CustomDataset):
    """Cityscapes dataset.

    The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
    fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
    """
    CLASSES = ('road', 'traffic sign', 'person', 'car', 'motorcycle')
    PALETTE = [[128, 64, 128], [220, 220, 0], [220, 20, 60], [0, 0, 142],
               [0, 0, 230]]

    def __init__(self,
                 img_suffix='_leftImg8bit.png',
                 seg_map_suffix='_gtFine_labelTrainIds.png',
                 ignore_index=255,
                 reduce_zero_label=False,
                 **kwargs):
        super(Cityscapes5Dataset, self).__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            ignore_index=ignore_index,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

- I also added this class to __init__.py.
- Next, in configs/_base_/datasets I create a config file cityscapes_5class.py that looks like this:

# dataset settings
dataset_type = 'Cityscapes5Dataset'
data_root = 'cityscapes/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 1024),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/train',
        ann_dir='gtFine/train',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/val',
        ann_dir='gtFine/val',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/val',
        ann_dir='gtFine/val',
        pipeline=test_pipeline))

- Finally, I add this config to my main config file and train the model:

from mmseg.apis import set_random_seed

_base_ = [
    'configs/_base_/models/pspnet_r50-d8.py',
    'configs/_base_/datasets/cityscapes_5class.py',
    'configs/_base_/default_runtime.py',
    'configs/_base_/schedules/schedule_20k.py'
]
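
Training can then be launched the same way as earlier in this thread, e.g. `python tools/train.py path/to/my_main_config.py` or `tools/dist_train.sh path/to/my_main_config.py 1` (the config path here is a placeholder).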

- After running for like 1000 iters, the evaluation result is like this:

+--------------+-------+-------+
|    Class     |  IoU  |  Acc  |
+--------------+-------+-------+
|     road     | 82.67 | 100.0 |
| traffic sign |  nan  |  nan  |
|    person    |  nan  |  nan  |
|     car      |  nan  |  nan  |
|  motorcycle  |  nan  |  nan  |
+--------------+-------+-------+


- Printing out the prediction result of one image from the test set got me this:
![test_city](https://user-images.githubusercontent.com/55870032/169667050-450ea0fd-65de-4f4b-985a-be7439df5f36.jpg)

I don't know where it could be wrong, please help me out. I've tried everything to get rid of the NaN in the metrics but failed.

waqarsqureshi commented 2 years ago

How many classes do you have in the 'pspnet_r50-d8.py' file? If your trainIds start from 0, they should match the num_classes variable. Otherwise, you can add one more class both in pspnet and in your dataset class to take care of the 0 index. Do check the trainIds in the *.png files and confirm the ids match your labels.py file (see the sanity check below). I suggest keeping the labels.py filename and only changing its contents, since it is imported in a number of places by the cityscapesscripts library. I did not set all the unwanted classes to 255; instead I put them into a background class and gave it an index. Hope it helps. If this does not work, I suggest defining new custom classes as suggested earlier and following the same steps; then you need neither labels.py nor the cityscapesscripts functions during evaluation.
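
A quick sanity check along those lines (a sketch; the file path and num_classes are placeholders):

```python
# Verify the trainIds present in a generated *_labelTrainIds.png are
# contiguous 0..num_classes-1 (plus 255 for ignored pixels).
import numpy as np
from PIL import Image

num_classes = 5
mask = np.array(Image.open(
    'cityscapes/gtFine/train/aachen/'
    'aachen_000000_000019_gtFine_labelTrainIds.png'))
ids = np.unique(mask)
print('trainIds found:', ids)
assert set(ids) <= set(range(num_classes)) | {255}, \
    'stray or non-contiguous trainIds -- classes will come out as NaN'
```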

namKolorfuL commented 2 years ago

@waqarsqureshi can you explain a bit how you set the unwanted classes to background and assigned the index? Currently, the labels list in my labels.py looks like this; I set the trainId of unwanted classes to 255, which is likely the issue, since I've seen people get NaN metrics by doing this as well.

labels = [
    #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,      255 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,      255 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,      255 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,      255 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,      255 , 'object'          , 3       , False        , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,      255 , 'object'          , 3       , False        , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,      255 , 'object'          , 3       , False        , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,      255 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,      255 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,      255 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,      255 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,      255 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,      255 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,      255 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,      255 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
]

How should I modify this to get your result? Thank you!

waqarsqureshi commented 2 years ago

I think you need to edit your labels.py. The trainIds need to be in sequence, starting from 0 with no gaps, and in the same order as your CLASSES tuple: keep 'road' at 0, then give trainId 1 to 'traffic sign', 2 to 'person', 3 to 'car', and 4 to 'motorcycle'. You can leave all the rest at 255, or fold them into a background class with its own index. In your current list the kept classes sit at trainIds 0, 11, 13, and 17, so ids 1 to 4 never occur in the ground truth and their IoU/Acc come out as NaN.
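
For concreteness, a sketch of the corrected rows in labels.py (ids, categories, and colors copied from your list above; every label not shown keeps trainId 255):

```python
# Contiguous trainIds 0..4, in the same order as
# CLASSES = ('road', 'traffic sign', 'person', 'car', 'motorcycle').
Label(  'road'         ,  7 ,  0 , 'flat'    , 1 , False , False , (128, 64,128) ),
Label(  'traffic sign' , 20 ,  1 , 'object'  , 3 , False , False , (220,220,  0) ),
Label(  'person'       , 24 ,  2 , 'human'   , 6 , True  , False , (220, 20, 60) ),
Label(  'car'          , 26 ,  3 , 'vehicle' , 7 , True  , False , (  0,  0,142) ),
Label(  'motorcycle'   , 32 ,  4 , 'vehicle' , 7 , True  , False , (  0,  0,230) ),
```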