I tried to extend your code in this way, and now it works for other models such as ResNet, MobileNet, and others (not only VGG); tested with version 0.1.0:
```python
import cv2
import numpy as np

from tf_explain.core.grad_cam import GradCAM
from tf_explain.utils.display import grid_display, heatmap_display


class GradCAMFixed(GradCAM):
    def __init__(self, preprocess_input=None):
        super().__init__()
        self.preprocess_input = preprocess_input

    def explain(
        self,
        validation_data,
        model,
        class_index,
        layer_name=None,
        colormap=cv2.COLORMAP_VIRIDIS,
    ):
        images, _ = validation_data

        # Preprocess copies of the images for the model, but keep the raw
        # images for the heatmap overlay
        if self.preprocess_input:
            input_images = [self.preprocess_input(image.copy()) for image in images]
        else:
            input_images = images

        if layer_name is None:
            layer_name = self.infer_grad_cam_target_layer(model)

        outputs, guided_grads = GradCAM.get_gradients_and_filters(
            model, input_images, layer_name, class_index
        )
        cams = GradCAM.generate_ponderated_output(outputs, guided_grads)

        # Overlay each CAM on the original (non-preprocessed) image
        heatmaps = np.array(
            [
                heatmap_display(cam.numpy(), image, colormap)
                for cam, image in zip(cams, images)
            ]
        )

        grid = grid_display(heatmaps)
        return grid
```
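Usage then looks something like this (just a sketch: the ResNet50 backbone, the random images, and the class index are only example inputs):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.applications.ResNet50(weights="imagenet")

# Raw RGB images resized to the model's input size, values in [0, 255]
images = np.random.randint(0, 256, size=(4, 224, 224, 3)).astype("float32")

explainer = GradCAMFixed(
    preprocess_input=tf.keras.applications.resnet50.preprocess_input
)
grid = explainer.explain((images, None), model, class_index=281)  # 281 = ImageNet "tabby cat"
```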
@Cospel Why not do the preprocessing before calling the explainer?
```python
images, _ = validation_data
input_images = [preprocess_input(image.copy()) for image in images]
preprocessed_validation_data = (input_images, None)

explanations = GradCAM().explain(preprocessed_validation_data, ...)
```
Then the resulting image will not look right. You need to overlay the result (the heatmap) on the non-preprocessed image for the explanation to look fine, while the input image to the model must be preprocessed. This was tested with tf-explain==0.1.0.
Ah, in the new version 0.2.0 you added the image_to_uint_255 method for heatmap visualization, which helps a lot with the input normalization problem! With it, preprocessing before calling the explainer gives much nicer resulting images. Thank you for this update!
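If I understand it correctly, the idea behind image_to_uint_255 is roughly the following (my own sketch of the idea, not the actual library code):

```python
import numpy as np

def image_to_uint_255_sketch(image):
    # Already integer pixels in [0, 255]: nothing to do
    if image.dtype == np.uint8:
        return image
    # Negative values are taken as a hint that the image is in [-1, 1],
    # otherwise [0, 1] is assumed
    if image.min() < 0:
        image = (image + 1.0) / 2.0
    return (image * 255).astype("uint8")
```

So only those few ranges are handled, which is what my next point is about.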
There are, however, architectures/models that use a different image normalization, not necessarily in the interval [0.0, 1.0], [-1.0, 1.0], or integer [0, 255], which can lead to some artifacts in the resulting explanations.
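For example, if the preprocessed images are what gets passed to the explainer, the Keras applications preprocessing functions already produce quite different ranges (a quick illustration; the exact values depend on the backbone):

```python
import numpy as np
import tensorflow as tf

# Fake image batch with uint8-range pixel values
images = np.random.randint(0, 256, size=(1, 224, 224, 3)).astype("float32")

# MobileNet uses "tf"-style preprocessing: values scaled to roughly [-1, 1]
mobilenet_in = tf.keras.applications.mobilenet.preprocess_input(images.copy())
print(mobilenet_in.min(), mobilenet_in.max())

# ResNet50/VGG16 use "caffe"-style preprocessing: BGR with per-channel mean
# subtraction, so values land roughly in [-124, 151], far outside [0, 1]
resnet_in = tf.keras.applications.resnet50.preprocess_input(images.copy())
print(resnet_in.min(), resnet_in.max())
```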
I think you can close this for now. It would be great if the README/docs/examples mentioned this and noted that the user should (depending on the architecture) normalize images before passing them to the explain method.
In a future release, I'll split creation of the attribution map from visualization, so people have more control over what they want to output and possibly create custom visualizations. Thanks for raising this!
Hi, thanks for the great project. It works very well.
I have a technical question about preprocessing the images for the CNN.
I looked into your examples, but normally every image fed into VGG, ResNet, ... should be normalized before inference.
Why is normalization not applied here? The result of the pretrained networks in the examples folder on non-normalized images can be wrong ....
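By normalization I mean the usual per-model preprocessing, something like this (a minimal sketch; the image path is just a placeholder):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.applications.VGG16(weights="imagenet")

img = tf.keras.preprocessing.image.load_img("cat.jpg", target_size=(224, 224))
x = tf.keras.preprocessing.image.img_to_array(img)[np.newaxis, ...]

# Without this step the pretrained weights see out-of-distribution pixel values
x = tf.keras.applications.vgg16.preprocess_input(x)

predictions = model.predict(x)
```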
Thank you!