qubvel / segmentation_models

Segmentation models with pretrained backbones. Keras and TensorFlow Keras.
MIT License

[OPEN QUESTION] How to evaluate a segmentation model with empty masks? #407

Open LAUBENicolas opened 4 years ago

LAUBENicolas commented 4 years ago

I created multiple models that can do binary segmentation on images, and I would like to evaluate them.

I would like to know how good each model is.

Currently, I'm using the following scores for that purpose: IoU, Dice, and pixel accuracy, plus precision and recall at several IoU/Dice thresholds.

A good intuitive explanation of these scores can be found at https://www.jeremyjordan.me/evaluating-image-segmentation-models/

The scores listed above are well suited to verifying that the model predicts correctly segmented objects with precise borders. However, my test dataset also contains images with no objects in them, and I would like to evaluate how good my model is at recognizing that there is nothing to predict.

For the moment, I get an IoU and a Dice score of 0 when there is no object in the ground-truth mask and my model correctly predicts an empty mask, which is wrong.
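
To make the failure concrete, here is a minimal NumPy sketch (shapes and values are made up for illustration): when both masks are empty the score is 0/0, which NumPy evaluates to nan (with a RuntimeWarning), and as soon as the prediction contains a single stray pixel the score drops to exactly 0, indistinguishable from a completely wrong mask.

import numpy as np

true_mask = np.zeros((4, 4), dtype="uint8")       # empty ground truth
predicted_mask = np.zeros((4, 4), dtype="uint8")  # model correctly predicts nothing

# Both masks empty: intersection and union are both 0, so IoU is 0/0 (nan).
print(np.logical_and(true_mask, predicted_mask).sum()
      / np.logical_or(true_mask, predicted_mask).sum())

# One stray predicted pixel: IoU is exactly 0, same as a fully wrong mask.
predicted_mask[0, 0] = 1
print(np.logical_and(true_mask, predicted_mask).sum()
      / np.logical_or(true_mask, predicted_mask).sum())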

Here is my code:

import statistics
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import numpy as np
from nptyping import NDArray  # nptyping 1.x style NDArray[(Any, Any), float] annotations
from tqdm import tqdm

# SegmentationModel, Dataset2D, BatchLoader, show_image_color and
# show_mask_on_image are project-specific helpers not shown here.

class SegmentationEvaluator():

    def __init__(self, model: SegmentationModel, dataset: Dataset2D, batch_size: int = 10, iou_thresholds: List[float] = [0.1, 0.5, 0.7, 0.8], dice_thresholds: List[float] = [0.1, 0.5, 0.7, 0.8]):
        self.model = model
        self.dataset = dataset
        self.batch_size = batch_size
        self.batch_loader = BatchLoader(self.dataset, batch_size=batch_size, shuffle=False)
        self.ious = []
        self.dices = []
        self.accuracies = []
        self.true_positive_iou = [0] * len(iou_thresholds)
        self.false_positive_iou = [0] * len(iou_thresholds)
        self.false_negative_iou = [0] * len(iou_thresholds)
        self.true_positive_dice = [0] * len(dice_thresholds)
        self.false_positive_dice = [0] * len(dice_thresholds)
        self.false_negative_dice = [0] * len(dice_thresholds)
        self.iou_thresholds = iou_thresholds
        self.dice_thresholds = dice_thresholds

    def evaluate(self, metrics: List[str] = ["iou", "accuracy", "dice", "precision_iou", "recall_iou",
                 "precision_dice", "recall_dice", "tp_iou", "fp_iou", "fn_iou"]) -> List[Tuple[str, Any]]:
        # Accumulate the base scores (IoU, Dice, accuracy) for every sample;
        # the derived metrics requested in `metrics` are computed from them below.
        for batch in tqdm(self.batch_loader):
            images, masks = batch
            predicted_masks = self.model.predict_masks(images)
            for metric in ["iou", "accuracy", "dice"]:
                # the three base scores are always collected, whatever `metrics` contains
                for true_mask, predicted_mask in zip(masks, predicted_masks):
                    if metric == "iou":
                        iou = get_iou(true_mask, predicted_mask)
                        self.update_tp_tn_fp_fn_iou(iou)
                        self.ious.append(iou)
                    elif metric == "dice":
                        dice = get_dice(true_mask, predicted_mask)
                        self.update_tp_tn_fp_fn_dice(dice)
                        self.dices.append(dice)
                    elif metric == "accuracy":
                        self.accuracies.append(get_accuracy(true_mask, predicted_mask))

        ### Prepare the results to display ###
        results = defaultdict(float)
        # IOU MEDIAN + MEAN
        if "iou" in metrics:
            results["mean_iou:"] = self.get_mean_iou()
            results["median iou:"] = self.get_median_iou()
        # DICE MEDIAN + MEAN
        if "dice" in metrics:
            results["mean dice:"] = self.get_mean_dice()
            results["median dice:"] = self.get_median_dice()
        # PRECISION IOU
        if "precision_iou" in metrics:
            precision_iou = self.get_precision_iou()
            for threshold in self.iou_thresholds:
                results[f"precision_iou: {threshold}"] = precision_iou[threshold]
        # PRECISION DICE
        if "precision_dice" in metrics:
            precision_dice = self.get_precision_dice()
            for threshold in self.dice_thresholds:
                results[f"precision_dice: {threshold}"] = precision_dice[threshold]
        # RECALL IOU
        if "recall_iou" in metrics:
            recall_iou = self.get_recall_iou()
            for threshold in self.iou_thresholds:
                results[f"recall_iou: {threshold}"] = recall_iou[threshold]
        # RECALL DICE
        if "recall_dice" in metrics:
            recall_dice = self.get_recall_dice()
            for threshold in self.dice_thresholds:
                results[f"recall_dice: {threshold}"] = recall_dice[threshold]
        # ACCURACY
        if "accuracy" in metrics:
            results["accuracy"] = self.get_accuracy()
        # TRUE POSITIVES
        if "tp_iou" in metrics:
            for i, threshold in enumerate(self.iou_thresholds):
                results[f"tp_iou {threshold}:"] = self.true_positive_iou[i]
        # FALSE NEGATIVES
        if "fn_iou" in metrics:
            for i, threshold in enumerate(self.iou_thresholds):
                results[f"fn_iou {threshold}:"] = self.false_negative_iou[i]
        # FALSE POSITIVES
        if "fp_iou" in metrics:
            for i, threshold in enumerate(self.iou_thresholds):
                results[f"fp_iou {threshold}:"] = self.false_positive_iou[i]
        return list(results.items())

    def get_accuracy(self):
        """Computes the pixel accuracy on the test dataset."""
        return sum(self.accuracies) / len(self.accuracies)

    def get_median_dice(self) -> float:
        """Computes the median dice on test dataset."""
        return statistics.median(self.dices)

    def get_mean_dice(self) -> float:
        """Computes the mean dice on test dataset."""
        return sum(self.dices) / len(self.dices)

    def get_median_iou(self) -> float:
        """Computes the median iou on test dataset."""
        return statistics.median(self.ious)

    def get_mean_iou(self) -> float:
        """Computes the mean iou on test dataset."""
        return sum(self.ious) / len(self.ious)

    def update_tp_tn_fp_fn_dice(self, dice: float) -> None:
        for i, threshold in enumerate(self.dice_thresholds):
            if dice >= threshold:
                self.true_positive_dice[i] += 1
            else:
                # a below-threshold score is counted both as a missed object (FN)
                # and as a spurious prediction (FP)
                self.false_negative_dice[i] += 1
                self.false_positive_dice[i] += 1

    def update_tp_tn_fp_fn_iou(self, iou: float) -> None:
        for i, threshold in enumerate(self.iou_thresholds):
            if iou >= threshold:
                self.true_positive_iou[i] += 1
            else:
                # same FN + FP convention as in update_tp_tn_fp_fn_dice
                self.false_negative_iou[i] += 1
                self.false_positive_iou[i] += 1

    def __show_result(self, image: NDArray[(Any, Any), float], true_mask: NDArray[(Any, Any), int]) -> None:
        predicted_mask = self.model.predict_masks([image])[0]
        show_image_color(image=image, true_mask=true_mask, predicted_mask=predicted_mask)
        image = image.astype("uint8")
        show_mask_on_image(image=image, mask=true_mask)
        print(f"IOU: {get_iou(true_mask, predicted_mask)}")
        print(f"DICE: {get_dice(true_mask, predicted_mask)}")

    def show_results(self, display: int) -> None:
        j = 1
        for batch in self.batch_loader:
            images, masks = batch
            for image, mask in zip(images, masks):
                if j <= display:
                    self.__show_result(image, mask)
                    j += 1
                else:
                    return None

    def get_precision_iou(self) -> Dict[float, float]:
        precision_iou = defaultdict(float)
        for i, iou_threshold in enumerate(self.iou_thresholds):
            precision_iou[iou_threshold] = self.true_positive_iou[i] / (self.true_positive_iou[i] + self.false_positive_iou[i])
        return precision_iou

    def get_recall_iou(self) -> Dict[float, float]:
        recall_iou = defaultdict(float)
        for i, iou_threshold in enumerate(self.iou_thresholds):
            recall_iou[iou_threshold] = self.true_positive_iou[i] / (self.true_positive_iou[i] + self.false_negative_iou[i])
        return recall_iou

    def get_precision_dice(self) -> Dict[float, float]:
        precision_dice = defaultdict(float)
        for i, dice_threshold in enumerate(self.dice_thresholds):
            precision_dice[dice_threshold] = self.true_positive_dice[i] / (self.true_positive_dice[i] + self.false_positive_dice[i])
        return precision_dice

    def get_recall_dice(self) -> Dict[float, float]:
        recall_dice = defaultdict(float)
        for i, dice_threshold in enumerate(self.dice_thresholds):
            recall_dice[dice_threshold] = self.true_positive_dice[i] / (self.true_positive_dice[i] + self.false_negative_dice[i])
        return recall_dice

def get_accuracy(true_mask: NDArray[(Any, Any), float], predicted_mask: NDArray[(Any, Any), float]) -> float:
    true_positive, true_negative, false_positive, false_negative = get_pixel_tp_tn_fp_fn(true_mask, predicted_mask)
    return (true_positive + true_negative) / (true_negative + true_positive + false_negative + false_positive)

def get_pixel_tp_tn_fp_fn(true_mask: NDArray[(Any, Any), float], predicted_mask: NDArray[(Any, Any), float]) -> Tuple[int, int, int, int]:
    # true positive represents a pixel that is correctly predicted to belong to the tumor class
    true_positive = np.sum(np.logical_and(true_mask, predicted_mask))
    # true negative represents a pixel that is correctly identified as non-tumoral
    inversed_true_mask = 1 - true_mask
    inversed_predicted_mask = 1 - predicted_mask
    true_negative = np.sum(np.logical_and(inversed_true_mask, inversed_predicted_mask))
    # false positive represents a pixel that is wrongly predicted to belong to the tumor class
    # (cast to a signed dtype first so the subtraction cannot wrap around with uint8 masks)
    diff_mask = true_mask.astype(int) - predicted_mask.astype(int)
    diff_mask_copy = np.copy(diff_mask)
    diff_mask[diff_mask == 1] = 0
    false_positive = -np.sum(diff_mask)
    # false negative represents a pixel that should have been predicted as tumoral but wasn't
    diff_mask_copy[diff_mask_copy == -1] = 0
    false_negative = np.sum(diff_mask_copy)
    return true_positive, true_negative, false_positive, false_negative

def get_iou(true_mask: NDArray[(Any, Any), float], predicted_mask: NDArray[(Any, Any), float]) -> float:   
    """
    Computes the iou score for binary segmentation.
    """
    intersection = np.logical_and(true_mask, predicted_mask).astype("uint8")
    union = np.logical_or(true_mask, predicted_mask).astype("uint8")
    return np.sum(intersection) / np.sum(union)

def get_dice(true_mask: NDArray[(Any, Any), float], predicted_mask: NDArray[(Any, Any), float]) -> float:
    """
    Computes the dice score for binary segmentation.
    """
    true_mask = true_mask.astype(bool)
    predicted_mask = predicted_mask.astype(bool)
    masks_sum = np.sum(true_mask) + np.sum(predicted_mask)
    return 2 * np.sum(np.logical_and(true_mask, predicted_mask)) / masks_sum

I also tried to invert the masks when the ground-truth mask is empty, but it gives too high IoU and Dice scores even when the predicted mask isn't completely empty.

def get_iou(true_mask: NDArray[(Any, Any), float], predicted_mask: NDArray[(Any, Any), float]) -> float:
    """
    Computes the iou score for binary segmentation.
    """
    if np.sum(true_mask) == 0:
        true_mask = 1 - true_mask
        predicted_mask = 1 - predicted_mask
    intersection = np.logical_and(true_mask, predicted_mask).astype("uint8")
    union = np.logical_or(true_mask, predicted_mask).astype("uint8")
    return np.sum(intersection) / np.sum(union)

def get_dice(true_mask: NDArray[(Any, Any), float], predicted_mask: NDArray[(Any, Any), float]) -> float:
    """
    Computes the dice score for binary segmentation.
    """
    if np.sum(true_mask) == 0:
        true_mask = 1 - true_mask
        predicted_mask = 1 - predicted_mask
    true_mask = true_mask.astype(bool)
    predicted_mask = predicted_mask.astype(bool)
    masks_sum = np.sum(true_mask) + np.sum(predicted_mask)
    return 2 * np.sum(np.logical_and(true_mask, predicted_mask)) / masks_sum
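
To quantify why the inversion inflates the scores, consider a hypothetical 256x256 image with an empty ground truth and 10 false-positive pixels (all numbers here are made up for illustration): the inverted masks agree on almost every background pixel, so the IoU comes out near 1 even though the prediction is wrong.

import numpy as np

true_mask = np.zeros((256, 256), dtype="uint8")   # empty ground truth
predicted_mask = np.zeros((256, 256), dtype="uint8")
predicted_mask[0, :10] = 1                        # 10 false-positive pixels

# Inverting both masks, as in the modified get_iou above:
inv_true = 1 - true_mask
inv_pred = 1 - predicted_mask
iou = np.logical_and(inv_true, inv_pred).sum() / np.logical_or(inv_true, inv_pred).sum()
print(iou)  # ~0.99985: a near-perfect score despite the 10 wrong pixels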

So, my questions are: how should IoU and Dice be evaluated when the ground-truth mask is empty, and how can a correctly predicted empty mask be rewarded instead of scored 0?

Thank you in advance for your help and ideas!

MuhammadKhalid3975 commented 1 year ago

I think when both the ground-truth mask and the prediction mask are empty, the metric should evaluate to null (e.g. NaN), and later we can exclude the nulls when reducing to a single-value metric.
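
A minimal sketch of that idea, assuming plain NumPy arrays (the toy mask pairs below are made up): the per-sample metric returns NaN when both masks are empty, and np.nanmean excludes those samples when aggregating. Returning 1.0 for the both-empty case is another convention one sees, which instead rewards a correctly predicted empty mask.

import numpy as np

def get_iou(true_mask: np.ndarray, predicted_mask: np.ndarray) -> float:
    """IoU that is NaN (undefined) when both masks are empty."""
    union = np.logical_or(true_mask, predicted_mask).sum()
    if union == 0:
        # both masks empty: nothing to segment, leave this sample out of the average
        return float("nan")
    intersection = np.logical_and(true_mask, predicted_mask).sum()
    return float(intersection / union)

# Toy data: one perfect match, one empty-vs-empty pair (excluded), one complete miss.
true_masks = [np.ones((4, 4)), np.zeros((4, 4)), np.ones((4, 4))]
predicted_masks = [np.ones((4, 4)), np.zeros((4, 4)), np.zeros((4, 4))]

scores = [get_iou(t, p) for t, p in zip(true_masks, predicted_masks)]
print(scores)              # [1.0, nan, 0.0]
print(np.nanmean(scores))  # 0.5 -- the undefined sample is excluded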