sandylaker / saliency-metrics

An Open-source Framework for Benchmarking Explanation Methods in Computer Vision
https://saliency-metrics.readthedocs.io
MIT License

[Feature Request] Implementation of Insertion-Deletion #15

Open sandylaker opened 2 years ago

sandylaker commented 2 years ago

Implementation of Insertion-Deletion

As the name indicates, Insertion-Deletion consists of two experiments: insertion and deletion.

ProgressivePerturbation

To implement these two experiments, we first need a ProgressivePerturbation class, which looks like this:

from typing import Iterator, Tuple

import torch


class ProgressivePerturbation:

    def __init__(
        self,
        input_tensor: torch.Tensor,
        replace_tensor: torch.Tensor,
        sorted_inds: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        # work on a copy so the caller's tensor is never modified in place
        self._current_tensor = input_tensor.clone()
        self._replace_tensor = replace_tensor
        # row and column indices of all pixels, sorted by descending saliency
        self._row_inds, self._col_inds = sorted_inds
        ...

    def _perturb_by_inds(self, row_inds: torch.Tensor, col_inds: torch.Tensor) -> None:
        # copy the selected pixels (across all leading dims, e.g. channels) from the replacement tensor
        self._current_tensor[..., row_inds, col_inds] = self._replace_tensor[..., row_inds, col_inds]

    @property
    def current_tensor(self) -> torch.Tensor:
        return self._current_tensor

    def perturb(self, forward_batch_size: int = 128, perturb_step_size: int = 10) -> Iterator[torch.Tensor]:
        ...
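
The body of perturb is left open above. Below is a minimal sketch of one way to fill it, assuming perturb yields stacked batches of at most forward_batch_size perturbed copies, where each step replaces the next perturb_step_size pixels in saliency order. To keep the example self-contained it uses NumPy arrays in place of torch tensors (the `[..., rows, cols]` indexing behaves the same for both), and the class name ProgressivePerturbationSketch is hypothetical, not part of the project.

```python
from typing import Iterator, List, Tuple

import numpy as np


class ProgressivePerturbationSketch:
    """Hypothetical NumPy stand-in for ProgressivePerturbation."""

    def __init__(
        self,
        input_arr: np.ndarray,
        replace_arr: np.ndarray,
        sorted_inds: Tuple[np.ndarray, np.ndarray],
    ) -> None:
        self._current = input_arr.copy()  # never modify the caller's array
        self._replace = replace_arr
        self._row_inds, self._col_inds = sorted_inds

    def _perturb_by_inds(self, rows: np.ndarray, cols: np.ndarray) -> None:
        self._current[..., rows, cols] = self._replace[..., rows, cols]

    def perturb(
        self, forward_batch_size: int = 128, perturb_step_size: int = 10
    ) -> Iterator[np.ndarray]:
        num_pixels = len(self._row_inds)
        batch: List[np.ndarray] = []
        for start in range(0, num_pixels, perturb_step_size):
            stop = start + perturb_step_size
            # replace the next perturb_step_size most salient pixels
            self._perturb_by_inds(self._row_inds[start:stop], self._col_inds[start:stop])
            # store a snapshot, because self._current keeps changing
            batch.append(self._current.copy())
            if len(batch) == forward_batch_size:
                yield np.stack(batch, axis=0)  # shape: (forward_batch_size, C, H, W)
                batch = []
        if batch:
            yield np.stack(batch, axis=0)  # last, possibly smaller batch
```

This sketch yields only the perturbed states; whether the unperturbed input should also appear as the first point of the score curve is a design choice it leaves open.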

InsertionDeletionResult


from typing import Dict


class InsertionDeletionResult(SerializableResult):

    def __init__(
        self,
        summarized: bool = False,
        ...
    ) -> None:

        self.summarized = summarized
        ...

    def add_single_result(self, single_result: Dict) -> None:
        ...

    def dump(self, file_path: str) -> None:
        ...

The single_result is a dictionary containing the following fields: del_scores, ins_scores, img_path, ins_auc and del_auc.

The flag summarized specifies whether to dump all the raw single results into one JSON file, or to compute the mean and standard deviation of ins_auc and del_auc and dump only these statistics. If summarized = True, a dict containing the fields mean_ins_auc, std_ins_auc, mean_del_auc, std_del_auc and num_samples should be dumped to a JSON file, i.e. a JSON file like this:

{"mean_ins_auc": 0.5, "std_ins_auc": 0.2, "mean_del_auc": 0.4, "std_del_auc": 0.3, "num_samples": 100}

Otherwise, the JSON file looks like this:

[
    {"del_scores": [0.8, 0.7, 0.6], "ins_scores": [0.1, 0.2, 0.3], "img_path": "path/to/img0.JPEG", "ins_auc": 0.5, "del_auc": 0.4},
    {"del_scores": [0.82, 0.71, 0.65], "ins_scores": [0.12, 0.27, 0.38], "img_path": "path/to/img1.JPEG", "ins_auc": 0.52, "del_auc": 0.45},
    {"del_scores": [0.81, 0.77, 0.69], "ins_scores": [0.12, 0.21, 0.34], "img_path": "path/to/img2.JPEG", "ins_auc": 0.58, "del_auc": 0.47}
]
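
A minimal sketch of the two dump modes, assuming the raw results are kept in a list and the summary uses the population standard deviation (statistics.pstdev, which matches NumPy's default np.std); the issue does not say which estimator to use. The class name InsertionDeletionResultSketch and its summary helper are hypothetical, and the sketch does not subclass the project's base class.

```python
import json
import statistics
from typing import Dict, List, Union


class InsertionDeletionResultSketch:
    """Hypothetical stand-in showing the summarized / raw dump behaviour."""

    def __init__(self, summarized: bool = False) -> None:
        self.summarized = summarized
        self._results: List[Dict] = []

    def add_single_result(self, single_result: Dict) -> None:
        self._results.append(single_result)

    def summary(self) -> Dict:
        ins = [r["ins_auc"] for r in self._results]
        dels = [r["del_auc"] for r in self._results]
        return {
            "mean_ins_auc": statistics.mean(ins),
            "std_ins_auc": statistics.pstdev(ins),  # population std (ddof=0)
            "mean_del_auc": statistics.mean(dels),
            "std_del_auc": statistics.pstdev(dels),
            "num_samples": len(self._results),
        }

    def dump(self, file_path: str) -> None:
        # either only the statistics, or every raw single result
        data: Union[Dict, List[Dict]] = self.summary() if self.summarized else self._results
        with open(file_path, "w") as f:
            json.dump(data, f)
```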

InsertionDeletion

from typing import Dict, List

import numpy as np
import torch
from torchvision.transforms import GaussianBlur


class InsertionDeletion(ReInferenceMetric):

    def __init__(
        self,
        classifier_cfg: Dict,
        forward_batch_size: int = 128,
        perturb_step_size: int = 10,
        sigma: float = 5.0,
        summarized: bool = False,
        ...
    ) -> None:

        self._result = InsertionDeletionResult(summarized, ...)

        self.classifier = ...
        # GaussianBlur needs an odd kernel size; int(2 * sigma - 1) is odd only for integer sigma
        self.gaussian_blur = GaussianBlur(int(2 * sigma - 1), sigma)
        self.forward_batch_size = forward_batch_size
        self.perturb_step_size = perturb_step_size
        ...

    def evaluate(self, img: torch.Tensor, smap: torch.Tensor, target: int) -> Dict:
        # sort all pixel positions by saliency, most salient first
        num_pixels = torch.numel(smap)
        _, inds = torch.topk(smap.flatten(), num_pixels)
        row_inds, col_inds = (torch.tensor(x) for x in np.unravel_index(inds.cpu().numpy(), smap.size()))

        # deletion: progressively replace the most salient pixels of img
        del_perturbation = ProgressivePerturbation(...)
        del_scores: List[np.ndarray] = []
        with torch.no_grad():
            for ... in del_perturbation.perturb(...):
                scores = torch.softmax(self.classifier(...), -1)[:, target]
                del_scores.append(scores.cpu().numpy())

        # a new name avoids redefining del_scores with a different type (mypy error)
        del_scores_arr: np.ndarray = np.concatenate(del_scores, 0)
        # compute AUC
        ...

        # insertion: start from a blurred copy and progressively restore the most salient pixels
        blurred_img = self.gaussian_blur(img)
        ins_perturbation = ProgressivePerturbation(blurred_img, img, (row_inds, col_inds))
        ins_scores: List[np.ndarray] = []
        with torch.no_grad():
            for ... in ins_perturbation.perturb(...):
                scores = torch.softmax(self.classifier(...), -1)[:, target]
                ins_scores.append(scores.cpu().numpy())

        ins_scores_arr: np.ndarray = np.concatenate(ins_scores, 0)
        # compute AUC
        ...

        single_result = dict(...)
        return single_result

    def update(self, single_result: Dict) -> None:
        ...
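
Both "# compute AUC" steps are left open in the proposal. One common choice, assumed here rather than taken from the issue, is the trapezoidal rule with the x-axis (fraction of perturbed pixels) normalized to [0, 1]; auc is a hypothetical helper name. It treats the scores as evenly spaced, which is only approximate when the last step perturbs fewer than perturb_step_size pixels.

```python
from typing import Sequence


def auc(scores: Sequence[float]) -> float:
    """Trapezoidal area under a score curve, x-axis normalized to [0, 1]."""
    n = len(scores)
    if n < 2:
        raise ValueError("need at least two scores to form a curve")
    # interior points have weight 1, the two end points weight 1/2
    return (sum(scores) - scores[0] / 2 - scores[-1] / 2) / (n - 1)
```

For example, auc([1.0, 0.5, 0.0]) returns 0.5.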

rjagtani commented 2 years ago

Thanks for this. Can you explain why we use a forward batch here? For example, in deletion the aim is to remove perturb_step_size pixels at a time from the image (in order of saliency-map importance) until all pixels are perturbed, computing a score at each stage. Why do we need to process this in batches of size forward_batch_size?

sandylaker commented 2 years ago

For a single image, we progressively replace pixels with values from another pre-defined tensor. At each step we replace perturb_step_size pixels, e.g. 10. For an image with a spatial size of 224 × 224 (50,176 pixels), that takes about 5,000 steps. Forwarding each perturbed tensor individually would be too slow, so instead we stack the perturbed tensors into batches and forward them batch by batch to make full use of the GPU.
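
A quick check of these numbers, per image and per experiment; count_forward_passes is a hypothetical helper. With a 224 × 224 image, perturb_step_size = 10 and forward_batch_size = 128, batching turns about 5,000 single-image forward passes into 40 batched ones.

```python
from typing import Tuple


def count_forward_passes(
    height: int, width: int, perturb_step_size: int, forward_batch_size: int
) -> Tuple[int, int]:
    num_pixels = height * width
    # ceiling division: the last step may perturb fewer pixels
    num_steps = (num_pixels + perturb_step_size - 1) // perturb_step_size
    # the last batch may hold fewer than forward_batch_size tensors
    num_batches = (num_steps + forward_batch_size - 1) // forward_batch_size
    return num_steps, num_batches


print(count_forward_passes(224, 224, 10, 128))  # (5018, 40)
```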

rjagtani commented 2 years ago

Thanks, it's clear now :).

rjagtani commented 2 years ago

What is the purpose of the update method in the InsertionDeletion class? Is it different from the add_single_result method in the InsertionDeletionResult class?

sandylaker commented 2 years ago

@rjagtani It basically calls self._result.add_single_result inside the method. Note that self._result is private by convention, so add_single_result should not be called on it from outside the InsertionDeletion class. Therefore, we wrap it in update. In the end, usage looks like this:

ins_del = InsertionDeletion(...)
single_result = ins_del.evaluate(...)
# maybe some post-processing of the single_result, e.g. add img_path to the dict
...
ins_del.update(single_result)
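
A minimal sketch of that delegation; ResultStub and MetricSketch are hypothetical stand-ins for InsertionDeletionResult and InsertionDeletion, reduced to the two methods involved.

```python
from typing import Dict, List


class ResultStub:
    """Hypothetical stand-in for InsertionDeletionResult."""

    def __init__(self) -> None:
        self.results: List[Dict] = []

    def add_single_result(self, single_result: Dict) -> None:
        self.results.append(single_result)


class MetricSketch:
    """Hypothetical stand-in for InsertionDeletion."""

    def __init__(self) -> None:
        self._result = ResultStub()  # private by convention

    def update(self, single_result: Dict) -> None:
        # public wrapper, so callers never touch self._result directly
        self._result.add_single_result(single_result)
```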

rjagtani commented 2 years ago

I'm writing tests for the Insertion-Deletion metric, and one problem I'm facing is that I don't have a constant output to test my code against. The test is assert expected_output == ins_del.evaluate(...), but I get different prediction scores for the same image because the weights are initialized randomly (pretrained is set to False). Any ideas would be appreciated. I'd also like to discuss some changes I have made to the code and tests. I have published the code on my GitHub, in the branch 'id3'.

sandylaker commented 2 years ago

Please do not use any pre-trained or randomly initialized torchvision or timm models in your tests; they are too heavy. Instead, create your own small, shallow dummy CNN and initialize its weights to fixed constants.

In addition, it is not necessary to test numeric equality in every case. Sometimes testing the object types or array shapes is sufficient.

Regarding the updated code, please send a Draft PR so that I can comment and send suggestions.
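
In the project this would be a tiny torch.nn.Module whose parameters are set with torch.nn.init.constant_. To keep the sketch dependency-light, the same idea is shown in plain NumPy: DummyCNN and softmax are hypothetical names, and the architecture (1x1 convolution, ReLU, global average pooling, linear layer) is an assumption. Because every weight is a fixed constant, two freshly built models give identical outputs, so a test can compare exact values or just check shapes and value ranges.

```python
import numpy as np


class DummyCNN:
    """Hypothetical shallow CNN: 1x1 conv -> ReLU -> global average pool -> linear."""

    def __init__(self, in_channels: int = 3, hidden: int = 4, num_classes: int = 2) -> None:
        # fixed constants instead of random initialization, so outputs are reproducible
        self.conv_w = np.full((hidden, in_channels), 0.1)
        self.fc_w = np.arange(num_classes * hidden, dtype=float).reshape(num_classes, hidden) * 0.01
        self.fc_b = np.zeros(num_classes)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x: (B, C, H, W); a 1x1 convolution is a channel-mixing matrix product
        h = np.einsum("oc,bchw->bohw", self.conv_w, x)
        h = np.maximum(h, 0.0)  # ReLU
        pooled = h.mean(axis=(2, 3))  # global average pooling -> (B, hidden)
        return pooled @ self.fc_w.T + self.fc_b  # logits: (B, num_classes)


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    shifted = logits - logits.max(axis=axis, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)
```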

rjagtani commented 2 years ago

Thanks, I'll make these changes and send a Draft PR.

rjagtani commented 2 years ago

I get this error when I run mypy saliency_metrics: Signature of "evaluate" incompatible with supertype "ReInferenceMetric". It seems the issue has been discussed here, and fixing it would require changes to ReInferenceMetric: https://stackoverflow.com/questions/51003146/python-3-6-signature-of-method-incompatible-with-super-type-class. I've created a pull request nevertheless.

sandylaker commented 2 years ago

@rjagtani The mypy error occurs because you changed the method signature to def evaluate(self, img: Tensor, smap: Tensor, img_path: str = None) -> Dict, which drops the target parameter of the base-class method. Please use def evaluate(self, img: Tensor, smap: Tensor, target: int, **kwargs: Any) -> Dict instead. The img_path can then be retrieved from kwargs as img_path: str = kwargs["img_path"].
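
A minimal sketch of the suggested override, using hypothetical stand-in classes (ReInferenceMetricStub, InsertionDeletionStub) rather than the project's real ones and Any in place of Tensor. Because the subclass keeps the base-class parameters and only reads extra information from **kwargs, mypy treats the override as compatible.

```python
from typing import Any, Dict


class ReInferenceMetricStub:
    """Hypothetical stand-in mirroring the suggested evaluate signature."""

    def evaluate(self, img: Any, smap: Any, target: int, **kwargs: Any) -> Dict:
        raise NotImplementedError


class InsertionDeletionStub(ReInferenceMetricStub):
    def evaluate(self, img: Any, smap: Any, target: int, **kwargs: Any) -> Dict:
        # extra, metric-specific inputs travel through kwargs
        img_path: str = kwargs["img_path"]
        return {"img_path": img_path, "target": target}
```

A caller then passes the path as a keyword argument, e.g. evaluate(img, smap, target, img_path="path/to/img0.JPEG").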