sandylaker / saliency-metrics

An Open-source Framework for Benchmarking Explanation Methods in Computer Vision
https://saliency-metrics.readthedocs.io
MIT License

[Feature Request] Implementation of Sanity Check #14

Open sandylaker opened 2 years ago

sandylaker commented 2 years ago

Implementation of Sanity Check

The implementation can be divided into the following parts:

SanityCheckResult

This class should implement SerializableResult. It records the sanity check results of the images in a dataset and can be dumped into a JSON file.

The class should look like this:

class SanityCheckResult(SerializableResult):
    def __init__(self, summarized: bool = False, ...) -> None:
        ...

    def add_single_result(self, single_result: Dict) -> None:
        ...

    def dump(self, file_path: str) -> None:
        ...

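A SanityCheckResult instance caches the per-image results in a container, e.g. a dict or a list. The result of a single image should be a dict; judging from the raw-result example below, it contains fields such as img_path (the image path) and ssim (one SSIM value per perturbation setting).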

The dump method dumps the cached result into a JSON file. mmcv.dump can be used as a helper function here.
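A minimal sketch of the raw (unsummarized) caching and dumping, assuming a plain list as the container and the standard-library json module in place of mmcv.dump. RawSanityCheckResult is a hypothetical name; it does not inherit from SerializableResult so that it runs on its own.

```python
import json
from typing import Dict, List


class RawSanityCheckResult:
    """Hypothetical stand-in for SanityCheckResult with summarized=False."""

    def __init__(self) -> None:
        # One dict per image, e.g. {"img_path": ..., "ssim": [...]}.
        self._results: List[Dict] = []

    def add_single_result(self, single_result: Dict) -> None:
        # Cache the result of one image; nothing is written to disk yet.
        self._results.append(single_result)

    def dump(self, file_path: str) -> None:
        # Stdlib json stands in for mmcv.dump in this sketch.
        with open(file_path, "w") as f:
            json.dump(self._results, f, indent=4)
```

Keeping the cache as a list preserves the order in which images were evaluated, which makes the dumped file easy to align with the dataset.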

The flag summarized specifies whether to dump the raw results or to compute the mean and standard deviation of the SSIM for each layer-perturbation setting. The raw result should look like this:

[
    {"img_path": "path/to/img_0.jpg", "ssim": [0.9, 0.8, 0.7]},
    {"img_path": "path/to/img_1.jpg", "ssim": [0.8, 0.7, 0.6]}
]

while the summarized result should be like this:

{
    "mean_ssim": [0.85, 0.75, 0.65],
    "std_ssim": [0.05, 0.05, 0.05],
    "num_samples": 2
}
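The summarized form can be derived from the raw per-image entries by stacking the ssim lists into a 2-D array and reducing over the image axis. A sketch with a hypothetical helper summarize_ssim; it assumes every entry has the same number of perturbation settings and uses the population standard deviation (NumPy's default, ddof=0), which is an assumption about how std_ssim is defined.

```python
from typing import Dict, List

import numpy as np


def summarize_ssim(raw_results: List[Dict]) -> Dict:
    """Hypothetical helper: per-setting mean/std of SSIM over all images."""
    # Shape: (num_samples, num_perturbation_settings).
    ssim = np.asarray([r["ssim"] for r in raw_results], dtype=np.float64)
    return {
        # Reduce over the image axis, one value per perturbation setting.
        "mean_ssim": ssim.mean(axis=0).tolist(),
        # np.std defaults to the population standard deviation (ddof=0).
        "std_ssim": ssim.std(axis=0).tolist(),
        "num_samples": len(raw_results),
    }
```

Applied to the two raw entries above, it returns means of about [0.85, 0.75, 0.65] and standard deviations of about [0.05, 0.05, 0.05]; .tolist() keeps the values JSON-serializable.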

SanityCheck Metric

The class SanityCheck should implement the ReInferenceMetric protocol. It progressively perturbs the model layers and compares the saliency map obtained under each perturbation setting with the original saliency map, i.e., the one obtained from the unperturbed model. The similarity between saliency maps is measured by the structural similarity index (SSIM).
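For reference, SSIM combines luminance, contrast and structure terms: SSIM(x, y) = ((2 μx μy + C1)(2 σxy + C2)) / ((μx² + μy² + C1)(σx² + σy² + C2)), with C1 = (0.01 L)² and C2 = (0.03 L)² for data range L. Library implementations (presumably what ssim_args is forwarded to, e.g. scikit-image's structural_similarity) evaluate this over a sliding window and average the result. The sketch below is a simplified single-window variant over the whole saliency map, assuming maps normalized to [0, 1]; global_ssim is a hypothetical helper, not the project's SSIM.

```python
import numpy as np


def global_ssim(x: np.ndarray, y: np.ndarray, data_range: float = 1.0) -> float:
    """Single-window SSIM over the whole map (no sliding window)."""
    x = x.astype(np.float64)
    y = y.astype(np.float64)
    # Stabilizing constants from the standard SSIM definition.
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    var_x, var_y = x.var(), y.var()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    num = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    den = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return float(num / den)
```

Identical maps give 1.0, and the score drops as the perturbed model's map diverges from the original, which is the behavior the sanity check relies on.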

The class should look like this:

from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn

# SerializableResult, ReInferenceMetric, AttributionMethod and SanityCheckResult
# come from the saliency_metrics package (import paths omitted here).


class SanityCheck(ReInferenceMetric):
    def __init__(
        self,
        classifier: nn.Module,
        perturb_layers: List[str],
        attr_method: AttributionMethod,
        ssim_args: Optional[Dict] = None,
        summarized: bool = False,
    ) -> None:
        self._result = SanityCheckResult(summarized)
        # Deep copy, so that later perturbations do not alter the saved weights.
        self._ori_state_dict = deepcopy(classifier.state_dict())
        self._classifier_layers: List[str] = self._filter_names([n[0] for n in classifier.named_modules()])
        ...

    def _reload_ckpt(self, device: Optional[Union[str, torch.device]] = None) -> None:
        ...

    def evaluate(self, img: torch.Tensor, smap: np.ndarray, target: int, **kwargs: Any) -> Dict:
        ...

    def _sanity_check_single(self, img: torch.Tensor, target: int, consecutive_perturb_layers: List[str]) -> float:
        ...

    def update(self, single_result: Dict) -> None:
        ...

    @staticmethod
    def _perturb_classifier(classifier: nn.Module, layers: List[str]) -> None:
        ...

    def get_result(self) -> SerializableResult:
        return self._result

    @staticmethod
    def _filter_names(names: List[str]) -> List[str]:
        # named_modules() yields a container (including the root '') right before
        # its first child, so dropping every name contained in the next name
        # leaves only the leaf modules.
        res: List[str] = []
        for i in range(len(names) - 1):
            if names[i] not in names[i + 1]:
                res.append(names[i])
        res.append(names[-1])
        return res
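Each perturb_layers entry is read as "perturb from the last layer to this layer", so one way to expand an entry into the consecutive_perturb_layers argument is a suffix slice of the filtered _classifier_layers. A minimal sketch, assuming the slice includes the named layer itself and that entries are names from the filtered list; consecutive_perturb_layers and cascading_settings are hypothetical helpers, and the layer names in the test are illustrative.

```python
from typing import List


def consecutive_perturb_layers(classifier_layers: List[str], start_layer: str) -> List[str]:
    """Hypothetical helper: all layers from start_layer through the last layer."""
    if start_layer not in classifier_layers:
        raise ValueError(f"unknown layer: {start_layer!r}")
    # Layers are ordered input -> output, so the suffix ends at the last layer.
    return classifier_layers[classifier_layers.index(start_layer):]


def cascading_settings(classifier_layers: List[str], perturb_layers: List[str]) -> List[List[str]]:
    # One perturbation setting per perturb_layers entry; listing entries from
    # the output towards the input yields a growing (cascading) chunk.
    return [consecutive_perturb_layers(classifier_layers, name) for name in perturb_layers]
```

Validating the name up front gives a clear error for a typo in perturb_layers instead of an opaque list.index failure deep inside evaluate.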
  1. In the initialization method, attr_method is the attribution method, which implements AttributionMethod. _ori_state_dict is a deep copy of the state dict of the unperturbed model.

  2. The argument perturb_layers stores the progressive perturbation settings. Each element is a layer name, and it can be interpreted as "perturbing from the last layer to this layer".

  3. The static method _filter_names filters the layer names. The names returned by classifier.named_modules(), which traverses the modules recursively, can look like this:

    ['', 'conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer1.0', 'layer1.0.conv1', 'layer1.0.bn1', 'layer1.0.act1', 'layer1.0.conv2', 'layer1.0.bn2', 'layer1.0.act2', 'layer1.1', 'layer1.1.conv1', 'layer1.1.bn1', 'layer1.1.act1', 'layer1.1.conv2', 'layer1.1.bn2', 'layer1.1.act2', 'layer2', 'layer2.0', 'layer2.0.conv1', 'layer2.0.bn1', 'layer2.0.act1', 'layer2.0.conv2', 'layer2.0.bn2', 'layer2.0.act2', 'layer2.0.downsample', ...]

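    After filtering, only the leaf modules remain:

    ['conv1', 'bn1', 'act1', 'maxpool', 'layer1.0.conv1', 'layer1.0.bn1', 'layer1.0.act1', 'layer1.0.conv2', 'layer1.0.bn2', 'layer1.0.act2', 'layer1.1.conv1', ...]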
  4. The _reload_ckpt method does the following:

    • reload the original (unperturbed) state_dict;
    • move the model to the given device if device is not None;
    • switch the classifier to eval mode;
    • freeze the entire model.

It should be called before the sanity check under each perturbation setting, i.e., at the beginning of each iteration of the for-loop:

for layer in self.perturb_layers:
    self._reload_ckpt(device=img.device)
    ...
    self._sanity_check_single(...)
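The four _reload_ckpt steps map onto standard torch.nn.Module calls. A minimal standalone sketch, written as a free function (reload_ckpt is a hypothetical name) that receives the classifier and the saved state dict instead of reading them from self; interpreting "freeze" as setting requires_grad to False on every parameter is an assumption.

```python
from typing import Dict, Optional, Union

import torch
import torch.nn as nn


def reload_ckpt(
    classifier: nn.Module,
    ori_state_dict: Dict[str, torch.Tensor],
    device: Optional[Union[str, torch.device]] = None,
) -> None:
    """Standalone version of the _reload_ckpt steps (sketch)."""
    # 1. Restore the unperturbed weights (copied into the existing parameters).
    classifier.load_state_dict(ori_state_dict)
    # 2. Move to the target device only if one is given.
    if device is not None:
        classifier.to(device)
    # 3. Eval mode: BatchNorm uses running statistics, dropout is disabled.
    classifier.eval()
    # 4. Freeze every parameter; input gradients for saliency still work.
    for param in classifier.parameters():
        param.requires_grad_(False)
```

Because load_state_dict copies values in place, reloading also works after the model has been moved to another device than the one the saved state dict lives on.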
  5. _sanity_check_single performs the sanity check under a single perturbation setting, given by the argument consecutive_perturb_layers.

  6. The static method _perturb_classifier perturbs a chunk of consecutive model layers.
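The last two items can be sketched as free functions. The issue does not say how a layer is perturbed, so this sketch assumes overwriting its parameters with small Gaussian noise (std 0.01); saliency_fn and similarity_fn are hypothetical stand-ins for attr_method and the SSIM computation, and perturb_classifier and sanity_check_single are hypothetical names.

```python
from typing import Callable, List

import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-in for AttributionMethod: (model, img, target) -> saliency map.
SaliencyFn = Callable[[nn.Module, torch.Tensor, int], np.ndarray]


def perturb_classifier(classifier: nn.Module, layers: List[str], std: float = 0.01) -> None:
    """Overwrite the parameters of the named layers with Gaussian noise (assumed perturbation)."""
    modules = dict(classifier.named_modules())
    with torch.no_grad():
        for name in layers:
            for param in modules[name].parameters():
                param.normal_(mean=0.0, std=std)


def sanity_check_single(
    classifier: nn.Module,
    img: torch.Tensor,
    target: int,
    ori_smap: np.ndarray,
    consecutive_perturb_layers: List[str],
    saliency_fn: SaliencyFn,
    similarity_fn: Callable[[np.ndarray, np.ndarray], float],
) -> float:
    """Perturb one chunk of layers, recompute the saliency map, compare it with the original."""
    perturb_classifier(classifier, consecutive_perturb_layers)
    smap = saliency_fn(classifier, img, target)
    return similarity_fn(ori_smap, smap)
```

Inside the class, _reload_ckpt runs before each call (as in the for-loop above), so every setting starts from the unperturbed weights and only the listed chunk is perturbed.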