Describe the bug
The current plot_results() does not correctly handle attribute access on predictions/ground_truth, nor the shapes of the maps. My updated version is below:
def plot_results(
    self,
    imgs: Union[torch.Tensor, np.ndarray],
    predictions: dict,
    ground_truth: dict,
    img_names: List,
    num_nuclei_classes: int,
    outdir: Union[Path, str],
    scores: List[List[float]] = None,
) -> None:
    """Generate example plots comparing prediction and ground truth per image.

    For every image a 2 x 7 grid (top row: ground truth, bottom row:
    prediction) is saved to ``outdir / f"pred_{img_name}"``. The columns are:
    image, binary cell map, HV-map channel 0, HV-map channel 1, instance map,
    nuclei type map and instance contours drawn on the image. The contour
    overlays are additionally saved as ``raw_gt_{img_name}`` and
    ``raw_pred_{img_name}``.

    Args:
        imgs (Union[torch.Tensor, np.ndarray]): Normalized images to process.
            Shape: (batch_size, 3, H', W')
        predictions (dict): Predictions of the model, either a dict or an
            object exposing the entries as attributes. Entries:
            "nuclei_type_map": Shape: (batch_size, num_nuclei_classes, H', W')
            "nuclei_binary_map": Shape: (batch_size, 2, H', W')
            "hv_map": Shape: (batch_size, 2, H', W')
            "instance_map": Shape: (batch_size, H', W')
            "instance_types": One mapping per image whose values are dicts
                with "contour" (2D array; columns 0 and 1 are used as x / y)
                and "type" (int index into the contour colour list).
        ground_truth (dict): Ground-truth values with the same entries and
            access styles as ``predictions``. "nuclei_binary_map" may be
            either (batch_size, H', W') or one-hot (batch_size, 2, H', W').
        img_names (List): Names of images as list
        num_nuclei_classes (int): Number of total nuclei classes including background
        outdir (Union[Path, str]): Output directory where images should be stored
        scores (List[List[float]], optional): List with scores for each image.
            Each list entry is a list with 3 scores: Dice, Jaccard and bPQ for the image.
            Defaults to None.
    """
    outdir = Path(outdir)
    outdir.mkdir(exist_ok=True, parents=True)

    # contour colours, indexed by the instance "type" value
    cell_colors = ["#ffffff", "#ff0000", "#00ff00", "#1e00ff", "#feff00", "#ffbf00"]

    def _field(container, key):
        # Accept both plain dicts and attribute-style containers (e.g. dataclasses).
        if isinstance(container, dict):
            return container[key]
        return getattr(container, key)

    def _to_numpy(tensor):
        # Detach from the autograd graph and move to host memory.
        return tensor.detach().cpu().numpy()

    def _min_max_scale(arr):
        # Scale to [0, 1] for the colormap; the epsilon avoids a division by
        # zero for constant maps (e.g. an image without any instances).
        arr_min = np.min(arr)
        return (arr - arr_min) / (np.max(arr) - arr_min + 1e-10)

    def _draw_contours(image, instance_types, save_path):
        # Draw every instance contour in its type colour on a copy of the
        # image, save it, and return it scaled to [0, 1] for the summary grid.
        cell_image = Image.fromarray((image * 255).astype(np.uint8)).convert("RGB")
        drawing = ImageDraw.Draw(cell_image)
        for instance in instance_types.values():
            contour = instance["contour"]
            drawing.polygon(
                list(zip(contour[:, 0], contour[:, 1])),
                outline=cell_colors[instance["type"]],
                width=2,
            )
        cell_image.save(save_path)
        return np.asarray(cell_image) / 255

    gt_hv_map = _field(ground_truth, "hv_map")
    h = gt_hv_map.shape[2]
    w = gt_hv_map.shape[3]

    # convert to channel-last numpy and crop to the ground-truth size
    if isinstance(imgs, np.ndarray):
        imgs = torch.from_numpy(imgs)
    sample_images = imgs.permute(0, 2, 3, 1).contiguous().cpu().numpy()
    sample_images = cropping_center(sample_images, (h, w), True)

    pred_sample_binary_map = _to_numpy(
        _field(predictions, "nuclei_binary_map")[:, 1, :, :]
    )
    pred_sample_hv_map = _to_numpy(_field(predictions, "hv_map"))
    pred_sample_instance_maps = _to_numpy(_field(predictions, "instance_map"))
    pred_sample_type_maps = _to_numpy(
        torch.argmax(_field(predictions, "nuclei_type_map"), dim=1)
    )

    gt_sample_binary_map = _to_numpy(_field(ground_truth, "nuclei_binary_map"))
    if gt_sample_binary_map.ndim == 4:
        # one-hot ground truth: keep the foreground channel, as for the prediction
        gt_sample_binary_map = gt_sample_binary_map[:, 1, :, :]
    gt_sample_hv_map = _to_numpy(gt_hv_map)
    gt_sample_instance_map = _to_numpy(_field(ground_truth, "instance_map"))
    gt_sample_type_map = _to_numpy(
        torch.argmax(_field(ground_truth, "nuclei_type_map"), dim=1)
    )

    # create colormaps
    hv_cmap = plt.get_cmap("jet")
    binary_cmap = plt.get_cmap("jet")
    instance_map = plt.get_cmap("viridis")

    # invert the normalization of the sample images
    transform_settings = self.run_conf["transformations"]
    if "normalize" in transform_settings:
        mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
        std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
    # Normalize computes (x - m) / s, so the inverse is
    # Normalize(mean=-m / s, std=1 / s), which yields x * s + m.
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    )
    inv_samples = inv_normalize(torch.tensor(sample_images).permute(0, 3, 1, 2))
    # Clip: rounding can leave values slightly outside [0, 1], which would
    # wrap around in the uint8 conversion used for the contour images.
    sample_images = np.clip(
        inv_samples.permute(0, 2, 3, 1).detach().cpu().numpy(), 0.0, 1.0
    )

    for i, img_name in enumerate(img_names):
        fig, axs = plt.subplots(figsize=(6, 2), dpi=300)
        placeholder = np.zeros((2 * h, 7 * w, 3))
        # orig image
        placeholder[:h, :w, :3] = sample_images[i]
        placeholder[h : 2 * h, :w, :3] = sample_images[i]
        # binary prediction (the * 255 maps integer foreground labels to the
        # top of the colormap lookup table)
        placeholder[:h, w : 2 * w, :3] = rgba2rgb(
            binary_cmap(gt_sample_binary_map[i] * 255)
        )
        placeholder[h : 2 * h, w : 2 * w, :3] = rgba2rgb(
            binary_cmap(pred_sample_binary_map[i] * 255)
        )
        # hv maps, shifted from [-1, 1] to [0, 1]
        placeholder[:h, 2 * w : 3 * w, :3] = rgba2rgb(
            hv_cmap((gt_sample_hv_map[i, 0, :, :] + 1) / 2)
        )
        placeholder[h : 2 * h, 2 * w : 3 * w, :3] = rgba2rgb(
            hv_cmap((pred_sample_hv_map[i, 0, :, :] + 1) / 2)
        )
        placeholder[:h, 3 * w : 4 * w, :3] = rgba2rgb(
            hv_cmap((gt_sample_hv_map[i, 1, :, :] + 1) / 2)
        )
        placeholder[h : 2 * h, 3 * w : 4 * w, :3] = rgba2rgb(
            hv_cmap((pred_sample_hv_map[i, 1, :, :] + 1) / 2)
        )
        # instance_predictions
        placeholder[:h, 4 * w : 5 * w, :3] = rgba2rgb(
            instance_map(_min_max_scale(gt_sample_instance_map[i]))
        )
        placeholder[h : 2 * h, 4 * w : 5 * w, :3] = rgba2rgb(
            instance_map(_min_max_scale(pred_sample_instance_maps[i]))
        )
        # type_predictions
        placeholder[:h, 5 * w : 6 * w, :3] = rgba2rgb(
            binary_cmap(gt_sample_type_map[i] / num_nuclei_classes)
        )
        placeholder[h : 2 * h, 5 * w : 6 * w, :3] = rgba2rgb(
            binary_cmap(pred_sample_type_maps[i] / num_nuclei_classes)
        )
        # contours
        placeholder[:h, 6 * w : 7 * w, :3] = _draw_contours(
            sample_images[i],
            _field(ground_truth, "instance_types")[i],
            outdir / f"raw_gt_{img_name}",
        )
        placeholder[h : 2 * h, 6 * w : 7 * w, :3] = _draw_contours(
            sample_images[i],
            _field(predictions, "instance_types")[i],
            outdir / f"raw_pred_{img_name}",
        )

        # plotting
        axs.imshow(placeholder)
        axs.set_xticks(np.arange(w / 2, 7 * w, w))
        axs.set_xticklabels(
            [
                "Image",
                "Binary-Cells",
                "HV-Map-0",
                "HV-Map-1",
                "Instances",
                "Nuclei-Pred",
                "Contours",
            ],
            fontsize=6,
        )
        axs.xaxis.tick_top()
        axs.set_yticks(np.arange(h / 2, 2 * h, h))
        axs.set_yticklabels(["GT", "Pred."], fontsize=6)
        axs.tick_params(axis="both", which="both", length=0)
        # separators between all 7 columns and the 2 rows
        for x_seg in np.arange(w, 7 * w, w):
            axs.axvline(x_seg, color="black")
        for y_seg in np.arange(h, 2 * h, h):
            axs.axhline(y_seg, color="black")
        if scores is not None:
            axs.text(
                20,
                1.85 * h,
                f"Dice: {str(np.round(scores[i][0], 2))}\nJac.: {str(np.round(scores[i][1], 2))}\nbPQ: {str(np.round(scores[i][2], 2))}",
                bbox={"facecolor": "white", "pad": 2, "alpha": 0.5},
                fontsize=4,
            )
        fig.suptitle(f"Patch Predictions for {img_name}")
        fig.tight_layout()
        fig.savefig(outdir / f"pred_{img_name}")
        plt.close(fig)
With this version, the generated plots look reasonable on my test images.
Describe the bug: The current plot_results() does not correctly handle attribute access on predictions/ground_truth, nor the shapes of the maps. My updated version is below:
With this version, the generated plots look reasonable on my test images.