kazuto1011 / deeplab-pytorch

PyTorch re-implementation of DeepLab v2 on COCO-Stuff / PASCAL VOC datasets
MIT License
1.09k stars 282 forks source link

Visualize the segmentation results on the validation set #117

Closed — henry189 closed this issue 6 months ago

henry189 commented 6 months ago

After running the training and validation commands, I got the scores and the saved .npy logit files. How can I visualize the segmentation results for the validation set images?

kazuto1011 commented 6 months ago

For example:

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf

from libs.datasets import get_dataset

# Change the following paths to yours
logits_dir = Path("./data/features/cocostuff164k/deeplabv2_resnet101_msc/val2017/logit")
config_path = Path("./configs/cocostuff164k.yaml")
save_dir = Path(".")  # directory where one "<image_id>.png" per image is written
CONFIG = OmegaConf.load(config_path)

# mean_bgr=(0, 0, 0) and augment=False so the returned image can be displayed
# as-is (presumably raw 0-255 pixel values; it is divided by 255 below).
dataset = get_dataset(CONFIG.DATASET.NAME)(
    root=CONFIG.DATASET.ROOT,
    split=CONFIG.DATASET.SPLIT.VAL,
    ignore_label=CONFIG.DATASET.IGNORE_LABEL,
    mean_bgr=(0, 0, 0),
    augment=False,
)

n_classes = CONFIG.DATASET.N_CLASSES
ignore_label = CONFIG.DATASET.IGNORE_LABEL

for image_id, image, gt_label in dataset:
    logit_path = logits_dir / f"{image_id}.npy"
    if not logit_path.exists():
        # Only images whose logits were saved during evaluation can be shown.
        continue

    # Saved logits are (C, H, W); add a batch axis for F.interpolate -> (1, C, H, W).
    logit = torch.from_numpy(np.load(logit_path))[None]
    # Upsample to the label resolution before taking the per-pixel argmax.
    logit = F.interpolate(
        logit, size=gt_label.shape, mode="bilinear", align_corners=False
    )
    pred_label = logit.argmax(dim=1)[0].numpy()

    # CHW -> HWC and scale to [0, 1] for imshow.
    image = image.transpose(1, 2, 0) / 255
    # Mask ignore-label pixels so they are left blank in the GT panel instead
    # of being clipped to the colour of the last class by vmax.
    gt_masked = np.ma.masked_equal(gt_label, ignore_label)

    fig, ax = plt.subplots(1, 3, figsize=(10, 3))
    ax[0].set_title(f"Input ({logit_path.name})")
    # Reverse the channel axis (presumably BGR -> RGB, cf. mean_bgr) for display.
    ax[0].imshow(image[..., ::-1], vmin=0, vmax=1)
    ax[1].set_title("Prediction")
    ax[1].imshow(pred_label, vmin=0, vmax=n_classes - 1)
    ax[2].set_title("GT")
    ax[2].imshow(gt_masked, vmin=0, vmax=n_classes - 1)
    for a in ax:
        a.axis("off")
    fig.tight_layout()
    fig.savefig(save_dir / f"{image_id}.png", dpi=300, bbox_inches="tight")
    # Release the figure; otherwise one figure per validation image stays
    # alive for the whole run and memory keeps growing.
    plt.close(fig)

000000000139