open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io
Apache License 2.0

DetLocalVisualizer artifacts in semantic segmentation #11453

Open collinmccarthy opened 5 months ago

collinmccarthy commented 5 months ago

Describe the bug

Semantic segmentations are not visualized accurately with DetLocalVisualizer, which produces the same artifacts as those shown in mmengine issue 741 and in the images below. I don't think the root issue is in DetLocalVisualizer itself, but it can be fixed by changing the implementation of _draw_sem_seg(). This "fixed" implementation mirrors the _draw_panoptic_seg() method and is also significantly faster.

The current implementation calls Visualizer.draw_binary_masks() and Visualizer.draw_texts() once per label ID in the semantic segmentation. This is causing artifacts around the borders of the image. I tried modifying the interpolation and the implementation of draw_binary_masks to only modify the masked region, but it still causes artifacts outside of a masked region. The solution seems to simply be calling draw_binary_masks() / draw_texts() methods once, drawing all masks and labels at the same time. For the images below I also added draw_polygons() calls to draw outlines as in instance segmentation. This more easily shows the issue is with the borders of the masked regions (and the text boxes).

| Visualizer | Image 002006 | Image 433204 | Image 226802 |
|---|---|---|---|
| Orig | 000000002006_gt_orig | 000000433204_gt_orig | 000000226802_gt_orig |
| Fixed | 000000002006_gt_fix | 000000433204_gt_fix | 000000226802_gt_fix |
| Orig w/ Poly | 000000002006_gt_nofix_poly | 000000433204_gt_nofix_poly | 000000226802_gt_nofix_poly |
| Fixed w/ Poly | 000000002006_gt_fix_poly | 000000433204_gt_fix_poly | 000000226802_gt_fix_poly |
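
To make the "draw everything once" idea concrete, here is a minimal, dependency-free sketch. It is not how mmengine's draw_binary_masks() works (that goes through a matplotlib canvas); it only illustrates compositing all label colors in a single pass, so that pixels outside any labelled region are never touched. The function name `blend_labels_once` and the nested-list image format are assumptions made for this illustration.

```python
from typing import Dict, List, Sequence, Tuple

Pixel = Tuple[int, int, int]


def blend_labels_once(
    image: List[List[Pixel]],
    label_map: Sequence[Sequence[int]],
    palette: Dict[int, Pixel],
    alpha: float = 0.8,
    ignore_index: int = 255,
) -> List[List[Pixel]]:
    """Blend every labelled pixel with its class color in a single pass.

    Hypothetical helper: pixels whose label is `ignore_index` (or has no
    palette entry) are copied through unchanged.
    """
    out: List[List[Pixel]] = []
    for img_row, lbl_row in zip(image, label_map):
        out_row: List[Pixel] = []
        for pixel, label in zip(img_row, lbl_row):
            if label == ignore_index or label not in palette:
                out_row.append(pixel)  # untouched outside labelled regions
                continue
            color = palette[label]
            blended = tuple(
                round((1 - alpha) * p + alpha * c) for p, c in zip(pixel, color)
            )
            out_row.append(blended)
        out.append(out_row)
    return out


# Tiny 2x3 gray image with two classes and one ignored column.
image = [[(100, 100, 100)] * 3 for _ in range(2)]
label_map = [[0, 0, 255], [1, 1, 255]]
palette = {0: (200, 0, 0), 1: (0, 200, 0)}
result = blend_labels_once(image, label_map, palette, alpha=0.5)
```

With this approach the ignored column keeps its original value exactly, which is the property the per-label draw loop fails to preserve in the images above.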

Reproduction

  1. What command or script did you run?

This was discovered as part of a larger project. I created a minimal working example with two files: show_draw_sem_seg_issue.py, which imports the fixed visualizer, and fix_local_visualizer.py, which defines it. Both are in show_draw_sem_seg_issue.zip. They assume COCO-Stuff has been set up with tools/dataset_converters/coco_stuff164k.py as shown in the docs (and shown below for clarity).

To reproduce you may do the following (also in the docstring of show_draw_sem_seg_issue.py). The images above are output with these calls.

Example:

# Assumes mmdetection was cloned to ~/mmdetection
# Assumes data is symbolically linked to ~/mmdetection/data
python ~/mmdetection/tools/dataset_converters/coco_stuff164k.py \
    ~/mmdetection/data/coco \
    --nproc 8

Example:

# Assumes `show_draw_sem_seg_issue.py` (this script) is copied to:
#    ~/mmdetection/fix_draw_sem_seg/show_draw_sem_seg_issue.py
# Assumes `fix_local_visualizer.py` is copied to same directory:
#    ~/mmdetection/fix_draw_sem_seg/fix_local_visualizer.py
# Assumes data root, e.g. ~/mmdetection/data above, is set to MMDET_DATASETS

cd ~/mmdetection/fix_draw_sem_seg
MMDET_DATASETS=~/mmdetection/data

# Output original GT files to 'gt_orig' dir
python ./show_draw_sem_seg_issue.py --data-root=$MMDET_DATASETS --orig-visualizer --gt-name=gt_orig

# Output fixed GT files to `gt_fix` dir
python ./show_draw_sem_seg_issue.py --data-root=$MMDET_DATASETS --gt-name=gt_fix

# Output fixed GT files with polygon outlines to `gt_fix_poly` dir
python ./show_draw_sem_seg_issue.py --data-root=$MMDET_DATASETS --draw-polygons --gt-name=gt_fix_poly

# Output original GT files with polygon outlines to `gt_nofix_poly` to better see issue
python ./show_draw_sem_seg_issue.py --data-root=$MMDET_DATASETS --no-single-draw-call --draw-polygons --gt-name=gt_nofix_poly
  2. Did you make any modifications on the code or config? Did you understand what you have modified? All code is above and is my own.

  3. What dataset did you use? COCO-Stuff as described above.

Environment

  1. Please run python mmdet/utils/collect_env.py to collect necessary environment information and paste it here.
$ python mmdet/utils/collect_env.py 
sys.platform: linux
Python: 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0]
CUDA available: True
numpy_random_seed: 2147483648
GPU 0: NVIDIA TITAN V
CUDA_HOME: /usr/local/cuda-11.8
NVCC: Cuda compilation tools, release 11.8, V11.8.89
GCC: gcc (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
PyTorch: 2.0.1+cu118
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.8
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.9
    - Built with CuDNN 8.7
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

TorchVision: 0.15.2+cu118
OpenCV: 4.8.1
MMEngine: 0.10.2
MMDetection: 3.2.0+fe3f809
  2. You may add additional information that may be helpful for locating the problem, such as
    • How you installed PyTorch [e.g., pip, conda, source]
    • Other environment variables that may be related (such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.)

I set up my environment with conda/pip/mim. If you need the scripts/instructions I can share them.

Error traceback

N/A

Bug fix

If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!

I can submit my fix with a PR, or someone on your team can if that's easier. The fix is in fix_local_visualizer.py in the zip above, but here is the core of it. This version has a flag to implement the single-draw fix (sem_seg_single_draw_call) and also has a flag to output the polygon outlines (sem_seg_draw_polygons).

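# Imports needed by this excerpt (paths as used in mmdet 3.x's local_visualizer.py)
from typing import List, Optional, Tuple

import cv2
import numpy as np
import torch
from mmengine.structures import PixelData

from mmdet.structures.mask import bitmap_to_polygon
from mmdet.visualization import DetLocalVisualizer
from mmdet.visualization.palette import _get_adaptive_scales


class FixDetLocalVisualizer(DetLocalVisualizer):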
    """DetLocalVisualizer with new `_draw_sem_seg()` implementation to optionally:
    - Draw all binary masks / texts in one call rather than iteratively (fixes the issue)
    - Draw polygon outlines as in `DetLocalVisualizer._draw_instances()`

    Have to have the 'name' argument (non-keyword) or we get 'must have `name` argument'
    from ManagerMeta init.
    """

    def __init__(
        self,
        name: str,
        *args,
        sem_seg_single_draw_call: bool = True,  # True fixes the issue, False does not
        sem_seg_draw_polygons: bool = False,  # Optional, True helps see the output a bit better
        polygon_line_width: int = 1,
        **kwargs,
    ):
        self.sem_seg_single_draw_call = sem_seg_single_draw_call
        self.sem_seg_draw_polygons = sem_seg_draw_polygons
        self.polygon_line_width = polygon_line_width
        super().__init__(*args, name=name, **kwargs)

    def _draw_sem_seg(
        self,
        image: np.ndarray,
        sem_seg: PixelData,
        classes: Optional[List],
        palette: Optional[List],
    ) -> np.ndarray:
        """Draw semantic seg of GT or prediction.

        Same as DetLocalVisualizer._draw_sem_seg() but uses `self.sem_seg_single_draw_call`:
        - If true, calls draw_binary_masks()/draw_texts() once, which is faster / minimizes artifacts
        - If false, calls draw_binary_masks()/draw_texts() once per mask like in the original version

        We were having issues (artifacts) with the original version and sem_seg_single_draw_call=True helps.
        See https://github.com/open-mmlab/mmengine/issues/741

        Args:
            image (np.ndarray): The image to draw.
            sem_seg (:obj:`PixelData`): Data structure for pixel-level
                annotations or predictions.
            classes (list, optional): Input classes for result rendering, as
                the prediction of segmentation model is a segment map with
                label indices, `classes` is a list which includes items
                responding to the label indices. If classes is not defined,
                visualizer will take `cityscapes` classes by default.
                Defaults to None.
            palette (list, optional): Input palette for result rendering, which
                is a list of color palette responding to the classes.
                Defaults to None.

        Returns:
            np.ndarray: the drawn image which channel is RGB.
        """
        sem_seg_data = sem_seg.sem_seg
        if isinstance(sem_seg_data, torch.Tensor):
            sem_seg_data = sem_seg_data.numpy()

        # 0 ~ num_class, the value 0 means background
        ids = np.unique(sem_seg_data)
        ignore_index = sem_seg.metainfo.get("ignore_index", 255)
        ids = ids[ids != ignore_index]

        if "label_names" in sem_seg:
            # open set semseg
            label_names = sem_seg.metainfo["label_names"]
        else:
            label_names = classes

        labels = np.array(ids, dtype=np.int64)
        colors = [palette[label] for label in labels]

        self.set_image(image)

        # draw all binary masks / texts at once
        masks: List[np.ndarray] = []
        mask_colors: List[Tuple] = []
        polygons: List[np.ndarray] = []
        label_texts: List[str] = []
        label_centroids: List[cv2.typing.MatLike] = []
        label_font_sizes: List[int] = []

        # draw semantic masks
        for i, (label, color) in enumerate(zip(labels, colors)):
            mask = sem_seg_data == label

            contours: Optional[List[np.ndarray]] = None
            if self.sem_seg_draw_polygons:  # Added for clarity
                assert mask.shape[0] == 1, "Expected shape (1,H,W)"
                contours, _ = bitmap_to_polygon(mask[0])

            if self.sem_seg_single_draw_call:  # Added (main fix)
                masks.append(mask)
                mask_colors.append(color)
                if contours is not None:
                    polygons.extend(contours)

            else:
                self.draw_binary_masks(mask, colors=[color], alphas=self.alpha)  # Original code
                if contours is not None:  # Added for clarity
                    self.draw_polygons(
                        contours,
                        edge_colors="w",
                        alpha=self.alpha,
                        line_widths=self.polygon_line_width,
                    )

            label_text = label_names[label]
            _, _, stats, centroids = cv2.connectedComponentsWithStats(
                mask[0].astype(np.uint8), connectivity=8
            )
            if stats.shape[0] > 1:
                largest_id = np.argmax(stats[1:, -1]) + 1
                centroids = centroids[largest_id]

                areas = stats[largest_id, -1]
                scales = _get_adaptive_scales(areas)

                if self.sem_seg_single_draw_call:  # Added, main fix
                    assert np.isscalar(scales), "Expected single value for font scaling"
                    label_texts.append(label_text)
                    label_centroids.append(centroids)
                    label_font_sizes.append(int(13 * scales))
                else:  # Original code
                    self.draw_texts(
                        label_text,
                        centroids,
                        colors=(255, 255, 255),
                        font_sizes=int(13 * scales),
                        horizontal_alignments="center",
                        bboxes=[
                            {"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}
                        ],
                    )

        # Added, main fix. Skip drawing when every label was filtered out by
        # ignore_index, since np.concatenate() raises on an empty list.
        if self.sem_seg_single_draw_call and len(masks) > 0:
            masks = np.concatenate(masks, axis=0)  # N x (1,H,W) -> (N,H,W)
            self.draw_binary_masks(masks, colors=mask_colors, alphas=self.alpha)

            if len(polygons) > 0:
                self.draw_polygons(
                    polygons, edge_colors="w", alpha=self.alpha, line_widths=self.polygon_line_width
                )

            if len(label_texts) > 0:  # np.stack() also raises on an empty list
                label_centroids = np.stack(label_centroids, axis=0)  # N x (2,) -> (N,2)
                bboxes_kwargs = {
                    "facecolor": "black",
                    "alpha": self.alpha,
                    "pad": 0.7,
                    "edgecolor": "none",
                }
                self.draw_texts(
                    label_texts,
                    label_centroids,
                    colors=(255, 255, 255),
                    font_sizes=label_font_sizes,
                    horizontal_alignments="center",
                    bboxes=[bboxes_kwargs] * len(label_texts),
                )

        return self.get_image()

If submitting my own PR I have a couple of other suggestions for an updated DetLocalVisualizer which I have implemented in my own codebase:

Thank you for reviewing this.

collinmccarthy commented 5 months ago

I stumbled upon the root of the problem here. I still stand by my solution above, but the issue stems from Visualizer.get_image():

    @master_only
    def get_image(self) -> np.ndarray:
        """Get the drawn image. The format is RGB.

        Returns:
            np.ndarray: the drawn image which channel is RGB.
        """
        assert self._image is not None, 'Please set image using `set_image`'
        return img_from_canvas(self.fig_save_canvas)  # type: ignore

It looks like img_from_canvas() is not pixel-accurate. It sometimes produces small off-by-one errors, likely due to the underlying rendering method. I'm not sure how the mechanics of the Visualizer class work, but I think the point is that we should call get_image() only when necessary. Since get_image() is called from draw_binary_masks(), calling draw_binary_masks() once is much better than calling it many times (from a pixel-wise accuracy point of view).
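
One way to check this kind of pixel-level drift is to compare the image before and after a draw call and list the pixels that changed even though they lie outside the mask. The sketch below is a hypothetical, dependency-free diagnostic (the name `changed_outside_mask` and the nested-list image format are assumptions for illustration, not part of mmengine or mmdetection).

```python
from typing import List, Sequence, Tuple

Pixel = Tuple[int, int, int]


def changed_outside_mask(
    before: Sequence[Sequence[Pixel]],
    after: Sequence[Sequence[Pixel]],
    mask: Sequence[Sequence[bool]],
) -> List[Tuple[int, int]]:
    """Return (row, col) of pixels that differ although the mask is False there."""
    changed: List[Tuple[int, int]] = []
    for r, (row_b, row_a, row_m) in enumerate(zip(before, after, mask)):
        for c, (pix_b, pix_a, inside) in enumerate(zip(row_b, row_a, row_m)):
            if not inside and pix_b != pix_a:
                changed.append((r, c))
    return changed


# Toy data: one pixel is legitimately recolored inside the mask, and one pixel
# outside the mask drifts by one intensity level.
before = [[(10, 10, 10)] * 3 for _ in range(2)]
mask = [[True, False, False], [False, False, False]]
after = [list(row) for row in before]
after[0][0] = (200, 0, 0)   # expected change inside the mask
after[1][2] = (11, 10, 10)  # off-by-one change outside the mask
leaks = changed_outside_mask(before, after, mask)
```

An empty result would mean the draw call only touched the masked region; any entries point at the artifact locations.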