vicoslab / mixed-segdec-net-comind2021

Official PyTorch implementation for "Mixed supervision for surface-defect detection: from weakly to fully supervised learning"
Other
290 stars 89 forks source link

Modified the distance transform function to fix the problem of ignoring small defects. #14

Closed: Yangly0 closed this issue 3 years ago.

Yangly0 commented 3 years ago

Result

image

Code

import os
import cv2
import numpy as np
from scipy.ndimage.morphology import distance_transform_edt
import matplotlib.pyplot as plt

def distance_transform(mask: np.ndarray, max_val: float, p: float) -> np.ndarray:
    """Turn a mask into per-pixel weights via one global distance transform.

    Each non-zero pixel is weighted by ``(d / d_max) ** p * max_val``, where
    ``d`` is its Euclidean distance to the nearest zero pixel and ``d_max`` is
    the largest such distance over the whole mask. Zero pixels get weight 1.

    Because ``d_max`` is shared by all regions, a small region alongside a
    much larger one can be pushed to weights close to 0.

    Returns:
        A float32 array with the same shape as ``mask``.
    """
    distances = distance_transform_edt(mask)
    peak = distances.max()
    if peak > 0:
        distances = (distances / peak) ** p * max_val
    background = mask == 0
    distances[background] = 1
    return distances.astype(np.float32)

def distance_transform_new(mask, max_val, p):
    """Turn a mask into per-pixel weights, normalizing each defect separately.

    Unlike ``distance_transform``, which divides by one global maximum
    distance, every 8-connected foreground region (non-zero pixels) is scaled
    by its own maximum distance, so each region peaks at ``max_val`` regardless
    of its size. Inside a region a pixel gets
    ``(d / d_region_max) ** p * max_val``, where ``d`` is its Euclidean
    distance to the nearest zero pixel; zero pixels get weight 1.

    Args:
        mask: 2-D array; any non-zero value counts as foreground. Any dtype
            (bool, uint8, float, ...) is accepted.
        max_val: Weight given to the innermost pixel(s) of every region.
        p: Exponent controlling how quickly weights fall off toward the
            edge of a region.

    Returns:
        A float32 array with the same shape as ``mask``.
    """
    from scipy import ndimage

    foreground = np.asarray(mask) != 0
    # A full 3x3 structuring element gives 8-connectivity, matching
    # cv2.connectedComponents(..., connectivity=8).
    labels, num_labels = ndimage.label(
        foreground, structure=np.ones((3, 3), dtype=bool)
    )

    # One EDT over the whole mask equals running it region by region: regions
    # are never 8-adjacent, so the nearest zero pixel seen from inside a region
    # is always true background, never another region. This replaces one
    # full-image EDT per region with a single pass.
    dst_trf = ndimage.distance_transform_edt(foreground)

    if num_labels > 0:
        # peaks[k] = largest distance inside region k; index 0 (background) unused.
        peaks = np.zeros(num_labels + 1, dtype=np.float64)
        peaks[1:] = ndimage.maximum(
            dst_trf, labels, index=np.arange(1, num_labels + 1)
        )
        region_peak = peaks[labels[foreground]]
        values = dst_trf[foreground]
        # Regions without a positive peak are left unscaled, as in the
        # per-region loop this replaces.
        scalable = region_peak > 0
        values[scalable] = (values[scalable] / region_peak[scalable]) ** p * max_val
        dst_trf[foreground] = values

    dst_trf[~foreground] = 1
    return np.array(dst_trf, dtype=np.float32)

# Demo: compare the global distance-transform weights with the per-component
# version on one KSDD2 training sample and show them side by side.
image_name = './KSDD2/train/20922.png'
mask_name = './KSDD2/train/20922_GT.png'

img = cv2.imread(image_name)
# cv2.imread returns None (instead of raising) when a file is missing or
# unreadable; fail here with a clear message rather than inside cv2.resize.
if img is None:
    raise FileNotFoundError(f"could not read image: {image_name}")
# OpenCV's dsize is (width, height): the result has 600 rows and 224 columns.
img = cv2.resize(img, dsize=(224, 600))

lbl = cv2.imread(mask_name, cv2.IMREAD_GRAYSCALE)
if lbl is None:
    raise FileNotFoundError(f"could not read mask: {mask_name}")
# NOTE(review): default (bilinear) interpolation can introduce intermediate
# values along mask edges; both transforms treat any non-zero as foreground.
lbl = cv2.resize(lbl, dsize=(224, 600))

# Grow the label with a 7x7 dilation before computing distance weights.
dilate_lbl = cv2.dilate(lbl, np.ones((7, 7)))

distance_transform_lbl = distance_transform(dilate_lbl, max_val=3.0, p=2)
distance_transform_lbl_new = distance_transform_new(dilate_lbl, max_val=3.0, p=2)

# (title, data, colormap) for each panel, left to right.
panels = [
    ("img", cv2.cvtColor(img, cv2.COLOR_BGR2RGB), None),
    ("lbl", lbl, 'gray'),
    ("dilate7_lbl", dilate_lbl, 'gray'),
    ("DT_lbl", distance_transform_lbl, 'gray'),
    ("DT_lbl_new", distance_transform_lbl_new, 'gray'),
]

plt.figure(figsize=(10, 6))
plt.suptitle("KSDD2 train 20922.png and 20922_GT.png")
for position, (title, data, cmap) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.imshow(data, cmap=cmap)
    plt.title(title)
plt.tight_layout()
plt.show()