Implementation of Sanity Check
The implementation can be divided into the following parts:
SanityCheckResult
This class should implement SerializableResult. It records the sanity check results for the images in a dataset and can be dumped into a JSON file.
The class should look like this:
A SanityCheckResult instance caches the per-image results in a container, e.g. a dict or a list. The result of a single image should be a dict that contains the following fields:
"img_path": the path of the input image.
"ssim": an ndarray with shape (num_perturb_settings,). It holds the structural similarity (SSIM) values, one per layer perturbation setting.
The dump method dumps the cached results into a JSON file. mmcv.dump can be used as a helper function here.
The flag summarized specifies whether we dump the raw results or compute the means and stds of SSIM for each layer perturbation setting. I.e., the raw result should be like this:
while the summarized result should be like this:
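A minimal sketch of the result class and its two output forms, under several assumptions: the method name add_single_result and the summary keys num_samples, ssim_mean and ssim_std are hypothetical, plain Python lists stand in for ndarray so the sketch has no dependencies, and json.dump is used in place of mmcv.dump. The std here is the population standard deviation over images.

```python
import json
import statistics


class SanityCheckResult:
    """Caches per-image SSIM results and dumps them as JSON (sketch)."""

    def __init__(self, summarized=True):
        self.summarized = summarized
        self._results = []  # one dict per image

    def add_single_result(self, single_result):
        # Hypothetical method name. Expects {"img_path": str, "ssim": sequence}.
        self._results.append({
            "img_path": single_result["img_path"],
            # Convert to plain floats so the value is JSON serializable
            # (this also works for a 1-D ndarray).
            "ssim": [float(v) for v in single_result["ssim"]],
        })

    def _summarize(self):
        # Transpose: each column holds one perturbation setting across images.
        per_setting = list(zip(*(r["ssim"] for r in self._results)))
        return {
            "num_samples": len(self._results),
            "ssim_mean": [statistics.fmean(col) for col in per_setting],
            "ssim_std": [statistics.pstdev(col) for col in per_setting],
        }

    def dump(self, file_path):
        # Raw form: list of per-image dicts. Summarized form: per-setting stats.
        content = self._summarize() if self.summarized else self._results
        with open(file_path, "w") as f:
            json.dump(content, f, indent=2)


result = SanityCheckResult(summarized=True)
result.add_single_result({"img_path": "a.jpg", "ssim": [1.0, 0.5]})
result.add_single_result({"img_path": "b.jpg", "ssim": [0.0, 0.5]})
summary = result._summarize()
```

With summarized=False, dump writes the cached per-image dicts unchanged. mmcv.dump could take the place of json.dump here, as the text suggests.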
SanityCheck Metric
The class SanityCheck should implement the protocol ReInferenceMetric. It progressively perturbs the model layers and compares the saliency map obtained under each perturbation setting with the original saliency map, i.e., the saliency map obtained from the unperturbed model. The similarity between saliency maps is measured by structural similarity (SSIM).
The class should look like this:
In the initialization method, attr_method represents the attribution method that implements AttributionMethod. _ori_state_dict is the state dict of the unperturbed model.
The argument perturb_layers stores the progressive perturbation settings. Each element is a layer name, and it can be interpreted as "perturbing from the last layer to this layer".
The static method _filter_names helps to filter the layer names. The layer names of a recursively nested module can be like this:
After being filtered, they are like this:
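As a minimal sketch, assuming the unfiltered names resemble what torch.nn.Module.named_modules() yields (an empty name for the root, then every nested container and layer) and assuming that filtering means keeping only leaf layers, the filter could work as follows. Both assumptions, the function body, and the example names are hypothetical and not confirmed by this write-up.

```python
def filter_names(names):
    """Keep only leaf layer names, preserving order (assumed behavior)."""
    # Drop the root module, whose name is the empty string.
    names = [name for name in names if name]
    leaves = []
    for name in names:
        # A name is a container if another name extends it with a dot.
        is_container = any(other.startswith(name + ".") for other in names)
        if not is_container:
            leaves.append(name)
    return leaves


# Hypothetical names of a nested model, containers included.
recursive_names = [
    "",
    "features",
    "features.0",
    "features.1",
    "classifier",
    "classifier.0",
    "classifier.3",
]
filtered = filter_names(recursive_names)
# filtered == ["features.0", "features.1", "classifier.0", "classifier.3"]
```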
The _reload_ckpt function does the following jobs:
It should be called when performing the sanity check under each perturbation setting, i.e. at the beginning of the for-loop:
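The loop can be sketched as below, assuming that the main job of the reload step is to restore the unperturbed weights so that perturbations from one setting do not leak into the next. ToyModel is a dependency-free stand-in for a real model, and reload_ckpt, perturb, and the settings list are hypothetical names and values chosen for illustration.

```python
import copy
import random


class ToyModel:
    """Stand-in for a real model: weights are a dict of float lists."""

    def __init__(self):
        self._weights = {"conv1": [0.1, 0.2], "conv2": [0.3, 0.4], "fc": [0.5, 0.6]}

    def state_dict(self):
        return self._weights

    def load_state_dict(self, state_dict):
        # Copy so that later in-place perturbation never touches the source.
        self._weights = copy.deepcopy(state_dict)


def reload_ckpt(model, ori_state_dict):
    # Assumed job: restore the unperturbed weights.
    model.load_state_dict(ori_state_dict)


def perturb(model, layer_names, rng):
    # Overwrite the weights of the given layers with Gaussian noise.
    weights = model.state_dict()
    for name in layer_names:
        weights[name][:] = [rng.gauss(0.0, 1.0) for _ in weights[name]]


model = ToyModel()
ori_state_dict = copy.deepcopy(model.state_dict())
rng = random.Random(0)

# Hypothetical settings: each lists the consecutive layers to perturb.
settings = [["fc"], ["fc", "conv2"], ["fc", "conv2", "conv1"]]

changed_per_setting = []
for consecutive_perturb_layers in settings:
    reload_ckpt(model, ori_state_dict)  # start from the unperturbed model
    perturb(model, consecutive_perturb_layers, rng)
    changed = sorted(
        name for name, w in model.state_dict().items() if w != ori_state_dict[name]
    )
    changed_per_setting.append(changed)

reload_ckpt(model, ori_state_dict)  # leave the model unperturbed afterwards
```

Because the weights are restored at the top of every iteration, each entry of changed_per_setting contains exactly the layers of its own setting.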
_sanity_check_single performs the sanity check under a specific perturbation setting, which is specified by the argument consecutive_perturb_layers.
The static method _perturb_classifier perturbs a chunk of consecutive model layers.
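One way to derive consecutive_perturb_layers from an entry of perturb_layers is sketched below. It follows the reading "perturbing from the last layer to this layer" and assumes layer_names is ordered from the first layer to the last. The helper name expand_perturb_settings and the example layer names are hypothetical.

```python
def expand_perturb_settings(layer_names, perturb_layers):
    """Map each setting to the layers from the last layer back to the named one."""
    settings = {}
    for layer in perturb_layers:
        if layer not in layer_names:
            raise ValueError(f"unknown layer: {layer}")
        start = layer_names.index(layer)
        # Slice from the named layer to the end, then reverse so the list
        # runs from the last layer back to the named layer.
        settings[layer] = layer_names[start:][::-1]
    return settings


# Hypothetical model with four layers, ordered from input to output.
layer_names = ["conv1", "conv2", "fc1", "fc2"]
settings = expand_perturb_settings(layer_names, ["fc2", "fc1", "conv1"])
# settings["fc1"] == ["fc2", "fc1"]
```

Each value could then be passed as consecutive_perturb_layers to _sanity_check_single.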