open-mmlab / mmsegmentation

OpenMMLab Semantic Segmentation Toolbox and Benchmark.
https://mmsegmentation.readthedocs.io/en/main/
Apache License 2.0
8.02k stars 2.57k forks source link

ValueError: size shape must match input shape. Input is 2D, size is 3 #672

Open kingbackyang opened 3 years ago

kingbackyang commented 3 years ago

Here is my self-defined dataset class: import os import os.path as osp from collections import OrderedDict from functools import reduce

import mmcv import numpy as np from mmcv.utils import print_log from prettytable import PrettyTable from torch.utils.data import Dataset

from mmseg.core import eval_metrics from mmseg.utils import get_root_logger from .builder import DATASETS from .pipelines import Compose

@DATASETS.register_module()
class CellDataset(Dataset):
    """Custom dataset for semantic segmentation.

    An example of file structure is as follows.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    An img/gt_semantic_seg pair of this dataset should share the same name
    except for the suffix. A valid img/gt_semantic_seg filename pair looks
    like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the extension is
    included in the suffix). If split is given, then ``xxx`` is specified in
    the txt file. Otherwise, all files in ``img_dir/`` and ``ann_dir`` will
    be loaded. Please refer to ``docs/tutorials/new_dataset.md`` for more
    details.

    Each annotation map must be a single-channel (H, W) image whose pixel
    values are class indices into ``CLASSES`` (or ``ignore_index``), not RGB
    colors. RGB masks lead to errors such as
    ``ValueError: size shape must match input shape``.

    Args:
        pipeline (list[dict]): Processing pipeline
        img_dir (str): Path to image directory
        img_suffix (str): Suffix of images. Default: '.jpg'
        ann_dir (str, optional): Path to annotation directory. Default: None
        seg_map_suffix (str): Suffix of segmentation maps.
            Default: '_mask.png'
        split (str, optional): Split txt file. If split is specified, only
            file with suffix in the splits will be loaded. Otherwise, all
            images in img_dir/ann_dir will be loaded. Default: None
        data_root (str, optional): Data root for img_dir/ann_dir. Default:
            None.
        test_mode (bool): If test_mode=True, gt wouldn't be loaded.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default: False
        classes (str | Sequence[str], optional): Specify classes to load.
            If is None, ``cls.CLASSES`` will be used. Default: None.
        palette (Sequence[Sequence[int]] | np.ndarray | None):
            The palette of segmentation map. If None is given, and
            self.PALETTE is None, random palette will be generated.
            Default: None
    """

    # The trailing comma is required: ``('cell')`` is just the string 'cell',
    # whose len() is 4 and whose iteration yields single characters.
    CLASSES = ('cell',)

    PALETTE = [[255, 255, 255]]

def __init__(self,
             pipeline,
             img_dir,
             img_suffix='.jpg',
             ann_dir=None,
             seg_map_suffix='_mask.png',
             split=None,
             data_root=None,
             test_mode=False,
             ignore_index=255,
             reduce_zero_label=False,
             classes=None,
             palette=None):
    """Store the configuration, resolve paths and index all samples.

    See the class docstring for the meaning of each argument.
    """
    self.pipeline = Compose(pipeline)
    self.img_dir, self.img_suffix = img_dir, img_suffix
    self.ann_dir, self.seg_map_suffix = ann_dir, seg_map_suffix
    self.split = split
    self.data_root = data_root
    self.test_mode = test_mode
    self.ignore_index = ignore_index
    self.reduce_zero_label = reduce_zero_label
    # Must exist before get_classes_and_palette(), which may replace it.
    self.label_map = None
    self.CLASSES, self.PALETTE = self.get_classes_and_palette(
        classes, palette)

    root = self.data_root
    if root is not None:
        # img_dir is mandatory, so it is always resolved when relative.
        if not osp.isabs(self.img_dir):
            self.img_dir = osp.join(root, self.img_dir)
        # Optional locations: leave None or absolute paths untouched.
        for attr_name in ('ann_dir', 'split'):
            location = getattr(self, attr_name)
            if location is not None and not osp.isabs(location):
                setattr(self, attr_name, osp.join(root, location))

    self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
                                           self.ann_dir,
                                           self.seg_map_suffix, self.split)

def __len__(self):
    """Return the number of indexed samples (entries of ``img_infos``)."""
    infos = self.img_infos
    return len(infos)

def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
                     split):
    """Load annotation from directory.

    Args:
        img_dir (str): Path to image directory
        img_suffix (str): Suffix of images.
        ann_dir (str|None): Path to annotation directory.
        seg_map_suffix (str|None): Suffix of segmentation maps.
        split (str|None): Split txt file. If split is specified, only file
            with suffix in the splits will be loaded. Otherwise, all images
            in img_dir/ann_dir will be loaded. Default: None

    Returns:
        list[dict]: One dict per image with key ``filename`` and, when
            ``ann_dir`` is given, ``ann`` -> ``dict(seg_map=...)``.
    """

    img_infos = []
    if split is not None:
        with open(split) as f:
            for line in f:
                img_name = line.strip()
                if not img_name:
                    # A blank line would otherwise produce a bogus sample
                    # whose filename is just ``img_suffix``.
                    continue
                img_info = dict(filename=img_name + img_suffix)
                if ann_dir is not None:
                    seg_map = img_name + seg_map_suffix
                    img_info['ann'] = dict(seg_map=seg_map)
                img_infos.append(img_info)
    else:
        for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
            img_info = dict(filename=img)
            if ann_dir is not None:
                # scandir only yields names ending in img_suffix, so swap
                # just that trailing suffix; str.replace would also rewrite
                # earlier occurrences (e.g. a directory named "x.jpg_dir").
                stem = img[:len(img) - len(img_suffix)]
                img_info['ann'] = dict(seg_map=stem + seg_map_suffix)
            img_infos.append(img_info)
        # Directory walk order depends on the filesystem; sort so the
        # sample order is reproducible across machines.
        img_infos = sorted(img_infos, key=lambda x: x['filename'])

    print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
    return img_infos

def get_ann_info(self, idx):
    """Look up the annotation record of one sample.

    Args:
        idx (int): Position of the sample in ``self.img_infos``.

    Returns:
        dict: The ``'ann'`` entry stored for that sample by
            ``load_annotations`` (holds the ``seg_map`` file name).
    """
    sample = self.img_infos[idx]
    return sample['ann']

def pre_pipeline(self, results):
    """Seed ``results`` in place with the keys the pipeline reads.

    Adds an empty ``seg_fields`` list and the image/annotation directory
    prefixes; ``label_map`` is added only when custom classes are active.
    """
    results.update(seg_fields=[],
                   img_prefix=self.img_dir,
                   seg_prefix=self.ann_dir)
    if not self.custom_classes:
        return
    results['label_map'] = self.label_map

def __getitem__(self, idx):
    """Return sample ``idx`` after running it through the pipeline.

    Args:
        idx (int): Index of data.

    Returns:
        dict: Pipeline output; ground-truth annotations are included only
            when ``test_mode`` is False.
    """
    prepare = (self.prepare_test_img
               if self.test_mode else self.prepare_train_img)
    return prepare(idx)

def prepare_train_img(self, idx):
    """Build a training sample: image info plus its annotation.

    Args:
        idx (int): Index of data.

    Returns:
        dict: Whatever ``self.pipeline`` returns for the seeded results
            dict.
    """
    results = {
        'img_info': self.img_infos[idx],
        'ann_info': self.get_ann_info(idx),
    }
    self.pre_pipeline(results)
    return self.pipeline(results)

def prepare_test_img(self, idx):
    """Build a test sample: image info only, no annotation.

    Args:
        idx (int): Index of data.

    Returns:
        dict: Whatever ``self.pipeline`` returns for the seeded results
            dict.
    """
    results = {'img_info': self.img_infos[idx]}
    self.pre_pipeline(results)
    return self.pipeline(results)

def format_results(self, results, **kwargs):
    """Placeholder for dataset-specific result formatting.

    This dataset defines no custom output format, so the call is a no-op
    that returns ``None``.
    """
    return None

def get_gt_seg_maps(self, efficient_test=False):
    """Get ground truth segmentation maps for evaluation.

    Args:
        efficient_test (bool): If True, return the annotation file paths
            instead of loading the images. Default: False.

    Returns:
        list[str | np.ndarray]: One entry per sample, in ``img_infos``
            order.

    Raises:
        ValueError: If a loaded annotation is not a single-channel (H, W)
            label map (e.g. an RGB mask), which would otherwise break or
            silently corrupt metric computation.
    """
    gt_seg_maps = []
    for img_info in self.img_infos:
        seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
        if efficient_test:
            gt_seg_maps.append(seg_map)
            continue
        gt_seg_map = mmcv.imread(
            seg_map, flag='unchanged', backend='pillow')
        if gt_seg_map.ndim != 2:
            raise ValueError(
                f'Annotation {seg_map} has shape {gt_seg_map.shape}; '
                'expected a single-channel (H, W) map of class indices, '
                'not an RGB image.')
        gt_seg_maps.append(gt_seg_map)
    return gt_seg_maps

def get_classes_and_palette(self, classes=None, palette=None):
    """Get class names of current dataset.

    Args:
        classes (Sequence[str] | str | None): If classes is None, use
            default CLASSES defined by builtin dataset. If classes is a
            string, take it as a file name. The file contains the name of
            classes where each line contains one class name. If classes is
            a tuple or list, override the CLASSES defined by the dataset.
        palette (Sequence[Sequence[int]] | np.ndarray | None):
            The palette of segmentation map. If None is given, random
            palette will be generated. Default: None

    Returns:
        tuple: ``(class_names, palette)``. Also sets
            ``self.custom_classes`` and, when custom classes are given and
            ``self.CLASSES`` is non-empty, ``self.label_map``.

    Raises:
        ValueError: If ``classes`` has an unsupported type or names a
            class that is not in ``self.CLASSES``.
    """
    if classes is None:
        self.custom_classes = False
        return self.CLASSES, self.PALETTE

    self.custom_classes = True
    if isinstance(classes, str):
        # take it as a file path
        class_names = mmcv.list_from_file(classes)
    elif isinstance(classes, (tuple, list)):
        class_names = classes
    else:
        raise ValueError(f'Unsupported type {type(classes)} of classes.')

    if self.CLASSES:
        # Work on the parsed ``class_names``, not the raw ``classes``
        # argument: for a file path, ``set(classes)`` is the set of its
        # characters and ``classes.index`` is a substring search.
        if not set(class_names).issubset(self.CLASSES):
            raise ValueError('classes is not a subset of CLASSES.')

        # Old label id -> new label id (-1 for dropped classes). It is
        # handed to the pipeline via results['label_map'] in pre_pipeline.
        self.label_map = {}
        for i, c in enumerate(self.CLASSES):
            if c not in class_names:
                self.label_map[i] = -1
            else:
                self.label_map[i] = class_names.index(c)

    palette = self.get_palette_for_custom_classes(class_names, palette)

    return class_names, palette

def get_palette_for_custom_classes(self, class_names, palette=None):
    """Build the palette that matches ``class_names``.

    Args:
        class_names (Sequence[str]): Class names actually in use.
        palette (Sequence[Sequence[int]] | np.ndarray | None): Palette
            supplied by the user. It is ignored when ``self.label_map`` is
            set.

    Returns:
        The palette to use:

        * If ``self.label_map`` is set, the kept entries of
          ``self.PALETTE`` ordered by their new label id. The result has
          the same container type as ``self.PALETTE``.
        * Otherwise the given ``palette``.
        * Otherwise ``self.PALETTE``.
        * Otherwise a reproducible random ``(len(class_names), 3)`` int
          array.
    """
    if self.label_map is not None:
        # return subset of palette, ordered by new label id
        subset = [
            self.PALETTE[old_id]
            for old_id, new_id in sorted(
                self.label_map.items(), key=lambda x: x[1])
            if new_id != -1
        ]
        if isinstance(self.PALETTE, np.ndarray):
            # np.ndarray(subset) would interpret the list as a *shape*.
            palette = np.array(subset)
        else:
            palette = type(self.PALETTE)(subset)

    elif palette is None:
        if self.PALETTE is None:
            # Seed locally so the generated colours are identical across
            # runs, then restore the global RNG state so other randomness
            # (e.g. data augmentation) is unaffected.
            state = np.random.get_state()
            np.random.seed(42)
            palette = np.random.randint(0, 255, size=(len(class_names), 3))
            np.random.set_state(state)
        else:
            palette = self.PALETTE

    return palette

def evaluate(self,
             results,
             metric='mIoU',
             logger=None,
             efficient_test=False,
             **kwargs):
    """Evaluate the dataset.

    Args:
        results (list): Testing results of the dataset. If every element
            is a ``str`` (file paths), those files are deleted after
            evaluation.
        metric (str | list[str]): Metrics to be evaluated. 'mIoU',
            'mDice' and 'mFscore' are supported.
        logger (logging.Logger | None | str): Logger used for printing
            related information during evaluation. Default: None.
        efficient_test (bool): Forwarded to ``get_gt_seg_maps``. If True,
            ground truths are passed to ``eval_metrics`` as file paths
            instead of loaded arrays. Default: False.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        dict[str, float]: Summary metrics (``'aAcc'`` and ``'m' + name``
            for every other metric) plus per-class entries keyed
            ``'<metric>.<class name>'``. Values are percentages rounded to
            2 decimals, then divided by 100.

    Raises:
        KeyError: If ``metric`` contains an unsupported metric name.
    """

    if isinstance(metric, str):
        metric = [metric]
    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
    if not set(metric).issubset(set(allowed_metrics)):
        raise KeyError('metric {} is not supported'.format(metric))
    eval_results = {}
    gt_seg_maps = self.get_gt_seg_maps(efficient_test)
    if self.CLASSES is None:
        # Infer the class count from the label values present in the GT.
        # NOTE(review): with efficient_test=True, gt_seg_maps holds file
        # paths, so np.unique runs on strings here. Any ignore_index pixels
        # (e.g. 255) are also counted as a class. Confirm this path is unused.
        num_classes = len(
            reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
    else:
        num_classes = len(self.CLASSES)
    ret_metrics = eval_metrics(
        results,
        gt_seg_maps,
        num_classes,
        self.ignore_index,
        metric,
        label_map=self.label_map,
        reduce_zero_label=self.reduce_zero_label)

    if self.CLASSES is None:
        class_names = tuple(range(num_classes))
    else:
        class_names = self.CLASSES

    # summary table
    # nanmean: presumably NaN marks classes without a defined score.
    ret_metrics_summary = OrderedDict({
        ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
        for ret_metric, ret_metric_value in ret_metrics.items()
    })

    # each class table
    # 'aAcc' only appears in the summary, so drop it before the per-class
    # table is built (this mutates ret_metrics, after the summary above).
    ret_metrics.pop('aAcc', None)
    ret_metrics_class = OrderedDict({
        ret_metric: np.round(ret_metric_value * 100, 2)
        for ret_metric, ret_metric_value in ret_metrics.items()
    })
    ret_metrics_class.update({'Class': class_names})
    # Make 'Class' the first column of the printed table.
    ret_metrics_class.move_to_end('Class', last=False)

    # for logger
    class_table_data = PrettyTable()
    for key, val in ret_metrics_class.items():
        class_table_data.add_column(key, val)

    summary_table_data = PrettyTable()
    for key, val in ret_metrics_summary.items():
        # Metric keys are bare names; the averaged ones get an 'm' prefix.
        if key == 'aAcc':
            summary_table_data.add_column(key, [val])
        else:
            summary_table_data.add_column('m' + key, [val])

    print_log('per class results:', logger)
    print_log('\n' + class_table_data.get_string(), logger=logger)
    print_log('Summary:', logger)
    print_log('\n' + summary_table_data.get_string(), logger=logger)

    # each metric dict
    for key, value in ret_metrics_summary.items():
        if key == 'aAcc':
            eval_results[key] = value / 100.0
        else:
            eval_results['m' + key] = value / 100.0

    ret_metrics_class.pop('Class', None)
    for key, value in ret_metrics_class.items():
        eval_results.update({
            key + '.' + str(name): value[idx] / 100.0
            for idx, name in enumerate(class_names)
        })

    # Results given as file paths (presumably temporary files written
    # during efficient testing) are removed once evaluated.
    if mmcv.is_list_of(results, str):
        for file_name in results:
            os.remove(file_name)
    return eval_results

I changed "CLASSES" and "PALETTE" in custom.py. The error suggests that the ground-truth segmentation image is not converted to a grayscale image and is still a color image. How can I fix this? Thank you.

kingbackyang commented 3 years ago

The error report:

Traceback (most recent call last):
  File "/home/mk/mmsegmentation/tools/train.py", line 166, in <module>
    main()
  File "/home/mk/mmsegmentation/tools/train.py", line 162, in main
    meta=meta)
  File "/home/mk/mmsegmentation/mmseg/apis/train.py", line 116, in train_segmentor
    runner.run(data_loaders, cfg.workflow)
  File "/home/mk/anaconda3/envs/mmdet213/lib/python3.7/site-packages/mmcv/runner/iter_based_runner.py", line 131, in run
    iter_runner(iter_loaders[i], **kwargs)
  File "/home/mk/anaconda3/envs/mmdet213/lib/python3.7/site-packages/mmcv/runner/iter_based_runner.py", line 60, in train
    outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
  File "/home/mk/anaconda3/envs/mmdet213/lib/python3.7/site-packages/mmcv/parallel/data_parallel.py", line 67, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/home/mk/mmsegmentation/mmseg/models/segmentors/base.py", line 137, in train_step
    losses = self(**data_batch)
  File "/home/mk/anaconda3/envs/mmdet213/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/mk/anaconda3/envs/mmdet213/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 97, in new_func
    return old_func(*args, **kwargs)
  File "/home/mk/mmsegmentation/mmseg/models/segmentors/base.py", line 107, in forward
    return self.forward_train(img, img_metas, **kwargs)
  File "/home/mk/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py", line 143, in forward_train
    gt_semantic_seg)
  File "/home/mk/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py", line 87, in _decode_head_forward_train
    self.train_cfg)
  File "/home/mk/mmsegmentation/mmseg/models/decode_heads/decode_head.py", line 185, in forward_train
    losses = self.losses(seg_logits, gt_semantic_seg)
  File "/home/mk/anaconda3/envs/mmdet213/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 184, in new_func
    return old_func(*args, **kwargs)
  File "/home/mk/mmsegmentation/mmseg/models/decode_heads/decode_head.py", line 220, in losses
    align_corners=self.align_corners)
  File "/home/mk/mmsegmentation/mmseg/ops/wrappers.py", line 26, in resize
    return F.interpolate(input, size, scale_factor, mode, align_corners)
  File "/home/mk/anaconda3/envs/mmdet213/lib/python3.7/site-packages/torch/nn/functional.py", line 3080, in interpolate
    'Input is {}D, size is {}'.format(dim, len(size)))
ValueError: size shape must match input shape. Input is 2D, size is 3

Junjun2016 commented 3 years ago

Hi @kingbackyang, sorry for the late reply! The annotation labels should be label indices that represent the categories.

meiyihTan commented 2 years ago

Hi @kingbackyang Sorry for the late reply! The annotation labels should be label index to represent the categories.

Hi @Junjun2016, I am getting the same error: ValueError: size shape must match input shape. Input is 2D, size is 3.

Could you give a more detailed explanation or an example of what is meant by "The annotation labels should be label index to represent the categories."? What and where are the annotation labels and the label indices?

Thank you !

I did the same thing as @kingbackyang in his example and created a new dataset as follows:

import os.path as osp

from .builder import DATASETS from .custom import CustomDataset

@DATASETS.register_module()
class ownL1Dataset(CustomDataset):
    """ownL1 dataset.

    In segmentation map annotation for ownL1, 0 stands for background, which
    is included in 2 categories. ``reduce_zero_label`` is fixed to False. The
    ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
    '.png'.

    Annotation maps must be single-channel (H, W) images whose pixel values
    are the class indices 0 (background) and 1 (foreground). They must not
    be RGB images coloured with ``PALETTE``, which leads to errors such as
    ``ValueError: size shape must match input shape`` during training.
    """

    CLASSES = ('background', 'foreground')

    PALETTE = [[120, 120, 120], [6, 230, 230]]

    def __init__(self, **kwargs):
        super(ownL1Dataset, self).__init__(
            img_suffix='.jpg',
            seg_map_suffix='.png',
            reduce_zero_label=False,
            **kwargs)
        assert osp.exists(self.img_dir), \
            f'img_dir {self.img_dir} does not exist'

I changed "CLASSES" and "PALETTE" in custom.py,

and the error I get is essentially the same (see below):


ValueError                                Traceback (most recent call last)
<ipython-input-12-2ebe9c9d56e4> in <module>()
     31 mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
     32 train_segmentor(model, datasets, cfg, distributed=False, validate=True, 
---> 33                 meta=meta) #train_segmentor() in /content/focal_phi_loss_mmsegmentation/mmseg/apis/train.py 

14 frames
/content/focal_phi_loss_mmsegmentation/mmseg/apis/train.py in train_segmentor(model, dataset, cfg, distributed, validate, timestamp, meta)
    114     elif cfg.load_from:
    115         runner.load_checkpoint(cfg.load_from)
--> 116     runner.run(data_loaders, cfg.workflow)

/usr/local/lib/python3.7/dist-packages/mmcv/runner/iter_based_runner.py in run(self, data_loaders, workflow, max_iters, **kwargs)
    128                     if mode == 'train' and self.iter >= self._max_iters:
    129                         break
--> 130                     iter_runner(iter_loaders[i], **kwargs)
    131 
    132         time.sleep(1)  # wait for some hooks like loggers to finish

/usr/local/lib/python3.7/dist-packages/mmcv/runner/iter_based_runner.py in train(self, data_loader, **kwargs)
     58         self.call_hook('before_train_iter')
     59         data_batch = next(data_loader)
---> 60         outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
     61         if not isinstance(outputs, dict):
     62             raise TypeError('model.train_step() must return a dict')

/usr/local/lib/python3.7/dist-packages/mmcv/parallel/data_parallel.py in train_step(self, *inputs, **kwargs)
     65 
     66         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
---> 67         return self.module.train_step(*inputs[0], **kwargs[0])
     68 
     69     def val_step(self, *inputs, **kwargs):

/content/focal_phi_loss_mmsegmentation/mmseg/models/segmentors/base.py in train_step(self, data_batch, optimizer, **kwargs)
    150                 averaging the logs.
    151         """
--> 152         losses = self(**data_batch)
    153         loss, log_vars = self._parse_losses(losses)
    154 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/usr/local/lib/python3.7/dist-packages/mmcv/runner/fp16_utils.py in new_func(*args, **kwargs)
     82                                 'method of nn.Module')
     83             if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
---> 84                 return old_func(*args, **kwargs)
     85             # get the arg spec of the decorated method
     86             args_info = getfullargspec(old_func)

/content/focal_phi_loss_mmsegmentation/mmseg/models/segmentors/base.py in forward(self, img, img_metas, return_loss, **kwargs)
    120         """
    121         if return_loss:
--> 122             return self.forward_train(img, img_metas, **kwargs)
    123         else:
    124             return self.forward_test(img, img_metas, **kwargs)

/content/focal_phi_loss_mmsegmentation/mmseg/models/segmentors/encoder_decoder.py in forward_train(self, img, img_metas, gt_semantic_seg)
    156 
    157         loss_decode = self._decode_head_forward_train(x, img_metas,
--> 158                                                       gt_semantic_seg)
    159         losses.update(loss_decode)
    160 

/content/focal_phi_loss_mmsegmentation/mmseg/models/segmentors/encoder_decoder.py in _decode_head_forward_train(self, x, img_metas, gt_semantic_seg)
    100         loss_decode = self.decode_head.forward_train(x, img_metas,
    101                                                      gt_semantic_seg,
--> 102                                                      self.train_cfg)
    103 
    104         losses.update(add_prefix(loss_decode, 'decode'))

/content/focal_phi_loss_mmsegmentation/mmseg/models/decode_heads/decode_head.py in forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg)
    185         """
    186         seg_logits = self.forward(inputs)
--> 187         losses = self.losses(seg_logits, gt_semantic_seg)
    188         return losses
    189 

/usr/local/lib/python3.7/dist-packages/mmcv/runner/fp16_utils.py in new_func(*args, **kwargs)
    162                                 'method of nn.Module')
    163             if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
--> 164                 return old_func(*args, **kwargs)
    165             # get the arg spec of the decorated method
    166             args_info = getfullargspec(old_func)

/content/focal_phi_loss_mmsegmentation/mmseg/models/decode_heads/decode_head.py in losses(self, seg_logit, seg_label)
    220             size=seg_label.shape[2:],
    221             mode='bilinear',
--> 222             align_corners=self.align_corners)
    223         if self.sampler is not None:
    224             seg_weight = self.sampler.sample(seg_logit, seg_label)

/content/focal_phi_loss_mmsegmentation/mmseg/ops/wrappers.py in resize(input, size, scale_factor, mode, align_corners, warning)
     27     if isinstance(size, torch.Size):
     28         size = tuple(int(x) for x in size)
---> 29     return F.interpolate(input, size, scale_factor, mode, align_corners)
     30 
     31 

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor)
   3078             if len(size) != dim:
   3079                 raise ValueError('size shape must match input shape. '
-> 3080                                  'Input is {}D, size is {}'.format(dim, len(size)))
   3081             output_size = size
   3082         else:

ValueError: size shape must match input shape. Input is 2D, size is 3

Thank you.

WangZX-0630 commented 2 years ago

If the classes are [background, road, car], then the label's pixel values for those classes are 0, 1 and 2, respectively. Also, the label must have only one channel, with shape (h, w), not 3 channels.