Closed henry189 closed 6 months ago
For example:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from libs.datasets import get_dataset
# Edit these two locations to match your own setup:
# the directory holding the saved per-image logit ".npy" files, and the
# YAML config describing the dataset.
logits_dir = Path("./data/features/cocostuff164k/deeplabv2_resnet101_msc/val2017/logit")
config_path = Path("./configs/cocostuff164k.yaml")

CONFIG = OmegaConf.load(config_path)

# Look up the dataset class by its configured name, then instantiate it on
# the split configured as SPLIT.VAL with augmentation turned off.
dataset_cls = get_dataset(CONFIG.DATASET.NAME)
dataset = dataset_cls(
    root=CONFIG.DATASET.ROOT,
    split=CONFIG.DATASET.SPLIT.VAL,
    ignore_label=CONFIG.DATASET.IGNORE_LABEL,
    # NOTE(review): a zero mean presumably leaves pixel values unshifted so
    # the images can be displayed directly — confirm against the dataset class.
    mean_bgr=(0, 0, 0),
    augment=False,
)
# For every sample in the dataset that has a saved logit file, draw a
# three-panel figure (input / prediction / ground truth) and save it as
# "<image_id>.png" in the current working directory.
for image_id, image, gt_label in dataset:
    logit_path = logits_dir / f"{image_id}.npy"
    if not logit_path.exists():
        # No logits were saved for this sample, so there is nothing to draw.
        continue

    logit = torch.from_numpy(np.load(logit_path))[None]  # (1,C,H,W)
    # Resize the logits to the label resolution before taking the per-pixel
    # argmax. align_corners=False is the documented default; passing it
    # explicitly keeps results unchanged and avoids the "default behavior"
    # warning some PyTorch versions emit for bilinear mode.
    logit = F.interpolate(
        logit, size=gt_label.shape, mode="bilinear", align_corners=False
    )
    pred_label = logit.argmax(dim=1)[0]

    # assumes image is a (C,H,W) array with values in 0-255 — TODO confirm
    image = image.transpose(1, 2, 0) / 255

    fig, ax = plt.subplots(1, 3, figsize=(10, 3))

    ax[0].set_title(f"Input ({logit_path.name})")
    # Reverse the channel axis for display (presumably BGR -> RGB, given the
    # dataset's `mean_bgr` argument — verify against the dataset class).
    ax[0].imshow(image[..., ::-1], vmin=0, vmax=1)

    # Prediction and GT share the same value range so their colors match.
    ax[1].set_title("Prediction")
    ax[1].imshow(pred_label, vmin=0, vmax=CONFIG.DATASET.N_CLASSES - 1)

    ax[2].set_title("GT")
    ax[2].imshow(gt_label, vmin=0, vmax=CONFIG.DATASET.N_CLASSES - 1)

    for a in ax:
        a.axis("off")

    plt.tight_layout()
    plt.savefig(f"{image_id}.png", dpi=300, bbox_inches="tight")
    # pyplot keeps every figure alive until it is explicitly closed; without
    # this, memory grows with each image in the validation set.
    plt.close(fig)
After running the training and validation commands, I got the scores and the .npy files, and I would like to know how to visualize the segmentation results for the images in the validation set.