Lightning-Universe / lightning-bolts

Toolbox of models, callbacks, and datasets for AI/ML researchers.
https://lightning-bolts.readthedocs.io
Apache License 2.0

Add Grad-CAM implementation #243

Open asaporta opened 3 years ago

asaporta commented 3 years ago

🚀 Feature

Implementation of Grad-CAM, probably as a vision callback.

Motivation

Grad-CAM is a widely used localization method that uses the gradient information flowing into the last convolutional layer of a CNN to visualize the model's prediction by highlighting the "important" pixels in the image.

The technique does not require any modifications to the existing model architecture, so it can be applied to any CNN-based architecture, including those for image captioning and visual question answering.
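
For readers who haven't seen it before, here is a minimal sketch of the core Grad-CAM computation in plain PyTorch; the choice of `resnet18` and `layer4` is purely illustrative and not something specified in this issue:

```python
# Minimal Grad-CAM sketch in plain PyTorch. resnet18/layer4 are illustrative.
import torch
import torch.nn.functional as F
from torchvision import models

model = models.resnet18(weights=None).eval()
target_layer = model.layer4  # last convolutional block

activations, gradients = {}, {}

def save_activations(module, inputs, output):
    activations["value"] = output.detach()

def save_gradients(module, grad_input, grad_output):
    gradients["value"] = grad_output[0].detach()

target_layer.register_forward_hook(save_activations)
target_layer.register_full_backward_hook(save_gradients)

img = torch.randn(1, 3, 224, 224)              # stand-in for a real image
scores = model(img)
class_idx = scores.argmax(dim=1).item()
scores[0, class_idx].backward()                # gradients w.r.t. the predicted class

# Weight each feature map by its spatially averaged gradient, then ReLU.
weights = gradients["value"].mean(dim=(2, 3), keepdim=True)
cam = F.relu((weights * activations["value"]).sum(dim=1, keepdim=True))
cam = F.interpolate(cam, size=img.shape[-2:], mode="bilinear", align_corners=False)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize to [0, 1]
```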

Pitch

Right now, for research we're doing in the lab I'm working in, I've been using a modified version of this PyTorch implementation of Grad-CAM, which only works with batch_size = 1 (but it looks like there are many other PyTorch implementations, with many stars, on GitHub that we could work off of).

For the above research, I've added an if statement to the test_step function in our LightningModule so that if we want the cams to be saved during inference, it calls a separate util function localize that does the forward and backward pass to create the feature maps. I'm not sure this is ideal, though, because we later do another forward and backward pass on the same image to get the prediction, so there is duplicated work.

I was thinking that it would be nice to have some sort of callback that can just generate (and save?) the cams for you, without having to mess with the training pipeline. I guess we would have to figure out where the cams would be saved, and whether just the heatmap would be saved or the heatmap overlaid on the original image (which is probably the most helpful?).
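
As a very rough sketch of what such a callback could look like, assuming a plain image classifier: `GradCAMCallback`, `target_layer`, and `output_dir` below are all hypothetical names, not an existing bolts API.

```python
# Hypothetical callback sketch -- not an existing bolts API.
import os
import torch
import torch.nn.functional as F
from pytorch_lightning import Callback


class GradCAMCallback(Callback):
    def __init__(self, target_layer, output_dir="cams"):
        super().__init__()
        self.target_layer = target_layer
        self.output_dir = output_dir
        self._acts = None
        self._grads = None
        os.makedirs(output_dir, exist_ok=True)

    def setup(self, trainer, pl_module, stage=None):
        self.target_layer.register_forward_hook(self._save_acts)
        self.target_layer.register_full_backward_hook(self._save_grads)

    def _save_acts(self, module, inputs, output):
        self._acts = output.detach()

    def _save_grads(self, module, grad_input, grad_output):
        self._grads = grad_output[0].detach()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        x, _ = batch
        # The test loop itself runs without gradients, so this repeats the
        # forward/backward pass -- the duplicated work mentioned above.
        # (The Trainer may also need inference_mode=False for this to work.)
        with torch.enable_grad():
            scores = pl_module(x)
            scores.gather(1, scores.argmax(dim=1, keepdim=True)).sum().backward()
        weights = self._grads.mean(dim=(2, 3), keepdim=True)
        cam = F.relu((weights * self._acts).sum(dim=1, keepdim=True))
        cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
        torch.save(cam.cpu(), os.path.join(self.output_dir, f"cam_batch_{batch_idx}.pt"))
```

Whether to save the raw heatmaps or heatmaps overlaid on the original images could then just be a constructor flag on the callback.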

Alternatives

Other localization methods include Integrated Gradients, WILDCAT, and Grad-CAM++, but Grad-CAM seems to be the most widely used.

Additional context

I'm not sure if this is helpful, but to clarify further how I'm currently doing this in Lightning: I've created inference_step and inference_epoch_end functions that both my validation functions and my test functions in the LightningModule call (that way, we can make sure that validation and test do inference in the same way). Only my test_step has a separate if statement that runs only if the user also wants to generate cams.
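
Concretely, that structure looks roughly like the sketch below; the class and the `localize` stub are reconstructions of the description above, not actual code from the project.

```python
# Reconstruction of the structure described above; localize stands in for the
# separate Grad-CAM util mentioned earlier and is left as a stub here.
import torch.nn.functional as F
from torch import nn
from pytorch_lightning import LightningModule


def localize(pl_module, batch):
    """Stub for the util that does the extra forward/backward pass to build the CAMs."""


class LitClassifier(LightningModule):
    def __init__(self, save_cams: bool = False):
        super().__init__()
        self.save_hyperparameters()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10)
        )

    def forward(self, x):
        return self.backbone(x)

    def inference_step(self, batch, batch_idx):
        # Shared by validation and test so both do inference the same way.
        x, y = batch
        logits = self(x)
        return {"loss": F.cross_entropy(logits, y)}

    def validation_step(self, batch, batch_idx):
        return self.inference_step(batch, batch_idx)

    def test_step(self, batch, batch_idx):
        out = self.inference_step(batch, batch_idx)
        if self.hparams.save_cams:  # only the test loop optionally generates CAMs
            localize(self, batch)
        return out
```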

Anyway, I'd love to help out on this in any way I can! I've never written a callback before, though, so would need some guidance on how to approach that.

github-actions[bot] commented 3 years ago

Hi! Thanks for your contribution, great first issue!

edgarriba commented 3 years ago

@ASaporta in case this goes through: after a quick look at the code, I see at least a couple of potential places where some utilities from kornia could be used here, namely resize and normalize_min_max.
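
For example, just to show where those two functions could slot into the CAM post-processing (the shapes here are made up):

```python
# Where kornia's resize and normalize_min_max could fit in CAM post-processing.
import torch
import kornia

cam = torch.rand(4, 1, 7, 7)  # stand-in for raw CAMs from the last conv layer
cam = kornia.geometry.transform.resize(cam, (224, 224), interpolation="bilinear")
cam = kornia.enhance.normalize_min_max(cam, min_val=0.0, max_val=1.0)
```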

Kshitij09 commented 3 years ago

How about integrating captum for the same? They have a wide range of interpretability algorithms implemented with rigorous testing. We could simply write a Callback for captum and add 'captum' as an explicit dependency, throwing an exception at runtime if it isn't installed. Captum also provides utility functions for visualization, so it would save a lot of effort if we just use the library.

cc: @ASaporta
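
For reference, captum already ships a LayerGradCam attribution, so a callback could be a thin wrapper around something like the following (the resnet18/layer4 choice is just for illustration):

```python
# Minimal example of captum's LayerGradCam; resnet18/layer4 are illustrative.
import torch
from torchvision import models
from captum.attr import LayerGradCam, LayerAttribution

model = models.resnet18(weights=None).eval()
x = torch.randn(1, 3, 224, 224)

gradcam = LayerGradCam(model, model.layer4)           # target the last conv block
target = model(x).argmax(dim=1)                       # explain the predicted class
attr = gradcam.attribute(x, target=target)            # raw CAM, e.g. (1, 1, 7, 7)
cam = LayerAttribution.interpolate(attr, (224, 224))  # upsample to the input size
```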

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

Spawnfile commented 3 years ago

> How about integrating captum for the same? They have a wide range of interpretability algorithms implemented with rigorous testing. We could simply write a Callback for captum and add 'captum' as an explicit dependency, throwing an exception at runtime if it isn't installed. Captum also provides utility functions for visualization, so it would save a lot of effort if we just use the library.
>
> cc: @ASaporta

Thanks for the suggestion.