sicara / tf-explain

Interpretability Methods for tf.keras models with Tensorflow 2.x
https://tf-explain.readthedocs.io
MIT License

Preprocessing image #107

Closed Cospel closed 4 years ago

Cospel commented 4 years ago

Hi, thanks for the great project. It works very well.

I have a technical question about preprocessing the input image for the CNN.

I looked into your examples, but normally every image fed into VGG, ResNet, etc. should be normalized before inference.

Why is normalization not applied here? The results of the pretrained networks in the examples folder on non-normalized images can be wrong.

Thank you!

Cospel commented 4 years ago

I tried to extend your code in the following way, and now it works for other models such as ResNet, MobileNet and others (not only VGG); tested with version 0.1.0:

# Imports assumed for this snippet (module paths as in recent tf-explain releases)
import cv2
import numpy as np

from tf_explain.core.grad_cam import GradCAM
from tf_explain.utils.display import grid_display, heatmap_display


class GradCAMFixed(GradCAM):
    def __init__(self, preprocess_input=None):
        super().__init__()
        self.preprocess_input = preprocess_input

    def explain(
        self,
        validation_data,
        model,
        class_index,
        layer_name=None,
        colormap=cv2.COLORMAP_VIRIDIS,
    ):
        images, _ = validation_data

        # Preprocess copies of the images for the forward pass only
        if self.preprocess_input:
            input_images = [self.preprocess_input(image.copy()) for image in images]
        else:
            input_images = images

        if layer_name is None:
            layer_name = self.infer_grad_cam_target_layer(model)

        outputs, guided_grads = GradCAM.get_gradients_and_filters(
            model, input_images, layer_name, class_index
        )

        cams = GradCAM.generate_ponderated_output(outputs, guided_grads)

        # Overlay the heatmaps on the original (non-preprocessed) images
        heatmaps = np.array(
            [
                heatmap_display(cam.numpy(), image, colormap)
                for cam, image in zip(cams, images)
            ]
        )

        grid = grid_display(heatmaps)

        return grid
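
A hypothetical usage of the subclass above, assuming a Keras ResNet50 and its matching preprocess_input (the model, the random image batch, and the class index are illustrative only):

import numpy as np
import tensorflow as tf

model = tf.keras.applications.ResNet50(weights="imagenet")
# Placeholder batch of raw (non-normalized) RGB images
images = np.random.randint(0, 256, (2, 224, 224, 3)).astype("float32")

explainer = GradCAMFixed(
    preprocess_input=tf.keras.applications.resnet50.preprocess_input
)
grid = explainer.explain((images, None), model, class_index=281)  # 281 = ImageNet "tabby cat"
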
RaphaelMeudec commented 4 years ago

@Cospel Why not do the preprocessing before calling the explainer?

images, _ = validation_data

input_images = [preprocess_input(image.copy()) for image in images]

preprocessed_validation_data = (input_images, None)

explanations = GradCAM().explain(preprocessed_validation_data, ...)
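
For completeness, a self-contained sketch of this approach (the ResNet50 model, preprocess function, image batch, and class index below are assumptions for illustration, not part of the library's examples):

import numpy as np
import tensorflow as tf
from tf_explain.core.grad_cam import GradCAM

model = tf.keras.applications.ResNet50(weights="imagenet")
preprocess_input = tf.keras.applications.resnet50.preprocess_input

# Placeholder batch of raw RGB images
images = np.random.randint(0, 256, (2, 224, 224, 3)).astype("float32")

# Preprocess copies so the raw images stay untouched
input_images = np.array([preprocess_input(image.copy()) for image in images])

explanations = GradCAM().explain((input_images, None), model, class_index=281)
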
Cospel commented 4 years ago

Then the resulting image will not look nice. You need to overlay the result (the heatmap) on the non-preprocessed image so the explanation looks right, while the input image for the model must be preprocessed. This is tested with tf-explain==0.1.0.

Cospel commented 4 years ago

Ahh, in the new version 0.2.0 you added the method image_to_uint_255 for heatmap visualization, which helps a lot with the input normalization problem! So the resulting images, with preprocessing done before calling the explainer, look much better. Thank you for this update!

There are, however, architectures/models that use a different image normalization, not necessarily in the interval [0.0, 1.0], [-1.0, 1.0], or integer [0, 255], which can lead to some artifacts in the resulting explanations.
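
For example, a model trained with per-channel mean/std standardization (the statistics below are the common ImageNet values, used purely as an illustration) receives inputs well outside those intervals, which the visualization step may then clip or distort:

import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype="float32")
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype="float32")

def standardize(image):
    # Maps uint8 RGB values to roughly [-2.1, 2.6] per channel,
    # i.e. not [0, 1], [-1, 1] or [0, 255]
    return (image.astype("float32") / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
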

I think you can close this for now. It would be great if the README/docs/examples mentioned this, i.e. that the user should (depending on the architecture) normalize images before passing them to the explain method.

RaphaelMeudec commented 4 years ago

In a future release, I'll split the creation of the attribution map from the visualization, so people have more control over the output and can create custom visualizations. Thanks for raising this!