TIO-IKIM / CellViT

CellViT: Vision Transformers for Precise Cell Segmentation and Classification
https://doi.org/10.1016/j.media.2024.103143
Other
236 stars 41 forks source link

plot_results() function in "cell_segmentation/inference/inference_cellvit_experiment_pannuke.py" #60

Open PingjunChen opened 2 months ago

PingjunChen commented 2 months ago

Describe the bug: the current plot_results() does not correctly handle the attribute-style access of predictions/ground_truth or the shapes of their tensors. My updated version is below:

def plot_results(
    self,
    imgs: Union[torch.Tensor, np.ndarray],
    predictions: dict,
    ground_truth: dict,
    img_names: List,
    num_nuclei_classes: int,
    outdir: Union[Path, str],
    scores: List[List[float]] = None,
) -> None:
    """Generate example plots comparing ground truth and prediction for each image.

    For every image a 2 x 7 grid is saved as ``pred_<name>`` (top row: ground truth,
    bottom row: prediction) with the columns Image, Binary-Cells, HV-Map-0,
    HV-Map-1, Instances, Nuclei-Pred and Contours. The two contour overlays are
    additionally saved as ``raw_gt_<name>`` and ``raw_pred_<name>``.

    Args:
        imgs (Union[torch.Tensor, np.ndarray]): Normalized images to process.
            Shape: (batch_size, 3, H, W). They are converted to channels-last and
            cropped with ``cropping_center`` to the ground-truth size (H', W').
        predictions: Predictions of the model. Despite the ``dict`` annotation,
            values are read as attributes (e.g. ``predictions.hv_map``):
            "nuclei_type_map": Shape: (batch_size, num_nuclei_classes, H', W')
            "nuclei_binary_map": Shape: (batch_size, 2, H', W'); channel 1 is plotted
            "hv_map": Shape: (batch_size, 2, H', W')
            "instance_map": Shape: (batch_size, H', W')
            "instance_types": per-image mapping whose values are dicts with
                "contour" (array; columns 0 and 1 are used as polygon x/y points)
                and "type" (int index into the contour color list)
        ground_truth: Ground truth values, read as attributes like ``predictions``.
            Same attributes, except that "nuclei_binary_map" is indexed per image
            as a 2-D map, i.e. Shape: (batch_size, H', W').
        img_names (List): Names of images as list (used in the output file names)
        num_nuclei_classes (int): Number of total nuclei classes including background
        outdir (Union[Path, str]): Output directory where images should be stored
        scores (List[List[float]], optional): List with scores for each image.
            Each list entry is a list with 3 scores: Dice, Jaccard and bPQ for the image.
            Defaults to None.
    """
    outdir = Path(outdir)
    outdir.mkdir(exist_ok=True, parents=True)

    h = ground_truth.hv_map.shape[2]
    w = ground_truth.hv_map.shape[3]

    # create colormaps
    hv_cmap = plt.get_cmap("jet")
    binary_cmap = plt.get_cmap("jet")
    instance_cmap = plt.get_cmap("viridis")
    cell_colors = ["#ffffff", "#ff0000", "#00ff00", "#1e00ff", "#feff00", "#ffbf00"]

    column_labels = [
        "Image",
        "Binary-Cells",
        "HV-Map-0",
        "HV-Map-1",
        "Instances",
        "Nuclei-Pred",
        "Contours",
    ]
    n_cols = len(column_labels)

    def _to_numpy(x) -> np.ndarray:
        # Accept tensors (possibly on GPU / requiring grad) as well as numpy arrays.
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _normalize_range(arr: np.ndarray) -> np.ndarray:
        # Min-max scale to [0, 1]; the epsilon keeps the denominator positive
        # for constant maps (e.g. an image without any instances).
        arr = arr.astype(np.float64)
        low = arr.min()
        return (arr - low) / (arr.max() - low + 1e-10)

    def _col(k: int) -> slice:
        # Horizontal pixel range of grid column k in the placeholder canvas.
        return slice(k * w, (k + 1) * w)

    def _draw_contours(image: np.ndarray, instance_types: dict, save_path: Path) -> np.ndarray:
        # Draw every instance contour in its class color onto a copy of the image,
        # save the overlay and return it as an RGB float array in [0, 1].
        cell_image = Image.fromarray((image * 255).astype(np.uint8)).convert("RGB")
        drawing = ImageDraw.Draw(cell_image)
        for instance in instance_types.values():
            contour = instance["contour"]
            drawing.polygon(
                list(zip(contour[:, 0], contour[:, 1])),
                outline=cell_colors[instance["type"]],
                width=2,
            )
        cell_image.save(save_path)
        return np.asarray(cell_image) / 255

    # convert to channels-last (RGB) and crop to the ground-truth size
    sample_images = np.ascontiguousarray(np.transpose(_to_numpy(imgs), (0, 2, 3, 1)))
    sample_images = cropping_center(sample_images, (h, w), True)

    pred_binary = _to_numpy(predictions.nuclei_binary_map)[:, 1]
    pred_hv = _to_numpy(predictions.hv_map)
    pred_instances = _to_numpy(predictions.instance_map)
    pred_types = np.argmax(_to_numpy(predictions.nuclei_type_map), axis=1)

    gt_binary = _to_numpy(ground_truth.nuclei_binary_map)
    gt_hv = _to_numpy(ground_truth.hv_map)
    gt_instances = _to_numpy(ground_truth.instance_map)
    gt_types = np.argmax(_to_numpy(ground_truth.nuclei_type_map), axis=1)

    # invert the normalization of the sample images
    transform_settings = self.run_conf["transformations"]
    if "normalize" in transform_settings:
        mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
        std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
    # Normalize computed x_norm = (x - mean) / std per channel, hence
    # x = x_norm * std + mean (channels-last, so the (3,) arrays broadcast).
    # Clip so the uint8 conversion for the contour images cannot wrap around.
    sample_images = np.clip(
        sample_images * np.asarray(std, dtype=np.float64)
        + np.asarray(mean, dtype=np.float64),
        0.0,
        1.0,
    )

    gt_row = slice(0, h)
    pred_row = slice(h, 2 * h)

    for i, img_name in enumerate(img_names):
        fig, axs = plt.subplots(figsize=(6, 2), dpi=300)
        placeholder = np.zeros((2 * h, n_cols * w, 3))

        # orig image
        placeholder[gt_row, _col(0)] = sample_images[i]
        placeholder[pred_row, _col(0)] = sample_images[i]

        # binary maps: pass floats in [0, 1] so the float colormap is not
        # saturated (multiplying by 255 pushed every value > 1/255 to the top color)
        placeholder[gt_row, _col(1)] = rgba2rgb(
            binary_cmap(gt_binary[i].astype(np.float64))
        )
        placeholder[pred_row, _col(1)] = rgba2rgb(
            binary_cmap(pred_binary[i].astype(np.float64))
        )

        # hv maps, shifted with (x + 1) / 2, i.e. assuming values in [-1, 1]
        for channel in (0, 1):
            placeholder[gt_row, _col(2 + channel)] = rgba2rgb(
                hv_cmap((gt_hv[i, channel] + 1) / 2)
            )
            placeholder[pred_row, _col(2 + channel)] = rgba2rgb(
                hv_cmap((pred_hv[i, channel] + 1) / 2)
            )

        # instance maps
        placeholder[gt_row, _col(4)] = rgba2rgb(
            instance_cmap(_normalize_range(gt_instances[i]))
        )
        placeholder[pred_row, _col(4)] = rgba2rgb(
            instance_cmap(_normalize_range(pred_instances[i]))
        )

        # type maps
        placeholder[gt_row, _col(5)] = rgba2rgb(
            binary_cmap(gt_types[i] / num_nuclei_classes)
        )
        placeholder[pred_row, _col(5)] = rgba2rgb(
            binary_cmap(pred_types[i] / num_nuclei_classes)
        )

        # contours
        placeholder[gt_row, _col(6)] = _draw_contours(
            sample_images[i],
            ground_truth.instance_types[i],
            outdir / f"raw_gt_{img_name}",
        )
        placeholder[pred_row, _col(6)] = _draw_contours(
            sample_images[i],
            predictions.instance_types[i],
            outdir / f"raw_pred_{img_name}",
        )

        # plotting
        axs.imshow(placeholder)
        axs.set_xticks(np.arange(w / 2, n_cols * w, w))
        axs.set_xticklabels(column_labels, fontsize=6)
        axs.xaxis.tick_top()

        axs.set_yticks(np.arange(h / 2, 2 * h, h))
        axs.set_yticklabels(["GT", "Pred."], fontsize=6)
        axs.tick_params(axis="both", which="both", length=0)

        # separators between every pair of adjacent columns (including the
        # last, Contours, column) and between the two rows
        for x_seg in np.arange(w, n_cols * w, w):
            axs.axvline(x_seg, color="black")
        for y_seg in np.arange(h, 2 * h, h):
            axs.axhline(y_seg, color="black")

        if scores is not None:
            axs.text(
                20,
                1.85 * h,
                f"Dice: {str(np.round(scores[i][0], 2))}\nJac.: {str(np.round(scores[i][1], 2))}\nbPQ: {str(np.round(scores[i][2], 2))}",
                bbox={"facecolor": "white", "pad": 2, "alpha": 0.5},
                fontsize=4,
            )
        fig.suptitle(f"Patch Predictions for {img_name}")
        fig.tight_layout()
        fig.savefig(outdir / f"pred_{img_name}")
        plt.close(fig)

With this version, the generated plots look reasonable on my test images (pred_4_3 and pred_4_57).