frgfm / torch-cam

Class activation maps for your PyTorch models (CAM, Grad-CAM, Grad-CAM++, Smooth Grad-CAM++, Score-CAM, SS-CAM, IS-CAM, XGrad-CAM, Layer-CAM)
https://frgfm.github.io/torch-cam/
Apache License 2.0
1.99k stars, 205 forks

Make target_layer also a `nn.Module` #81

FrancescoSaverioZuppichini closed 2 years ago

FrancescoSaverioZuppichini commented 3 years ago

🚀 Feature

Hello there,

I would like to pass an nn.Module instance to any *Cam constructor, but currently I can only pass a 'str'.

Motivation

Well, if your model is well designed, it is much easier to pass a reference than a key.

Thanks!

Francesco

frgfm commented 3 years ago

Hi @FrancescoSaverioZuppichini :wave:

Thanks for the suggestion! I thought about it a while ago, but couldn't really see cases where it's better on the user-side than passing the string :thinking:

Currently

# Define your model
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval()

# Set your CAM extractor
from torchcam.cams import SmoothGradCAMpp
cam_extractor = SmoothGradCAMpp(model, 'layer4')

and with your suggestion

# Define your model
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval()

# Set your CAM extractor
from torchcam.cams import SmoothGradCAMpp
cam_extractor = SmoothGradCAMpp(model, model.layer4)

I agree that passing a reference, rather than a string that we later resolve to a reference, is cleaner from a programming point of view, but I can see at least one drawback: if you mistakenly pass a layer that belongs to another model, the CAM won't work. It also introduces some redundancy in the arguments: on top of the target layer, the whole model still has to be passed (to clear gradients, and for forward passes in some specific CAM methods).

Could you tell me in which cases the string argument is not enough on your end please? :pray:

FrancescoSaverioZuppichini commented 3 years ago

Hi @frgfm ,

Thank you for your reply. I hope you are doing great ;) IMHO, passing a reference is always better than passing a string.

First of all, if I pass a string, you need to trace the model to find out where the layer is. Secondly, maybe I have no idea what my "key" is. What if I want to reference a layer by its index in a list of layers? I cannot do that with a string, whereas by reference it is as easy as model.layers[idx]. Or what if I have a deeply nested model, something like model.encoder.blocks[1].head.projection? By reference it is dead easy to pass what you need, but by string? And what about collisions? If I have two layers with the same name, as is common in practice, the current implementation cannot tell which one is the correct one.

You can always support string keys by having something like .from_key(model, key: str)
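A minimal sketch of that idea, assuming a hypothetical `ActivationHook` helper (this is not torch-cam's API, and `from_key` is only the name floated above): the main constructor takes a module reference, and a classmethod resolves a dotted string name through `named_modules()` before delegating to it.

```python
import torch
import torch.nn as nn


class ActivationHook:
    """Hypothetical helper (not torch-cam's API): stores a layer's output."""

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.activation = None
        # Hook directly on the reference; no name lookup is needed here.
        self._handle = target_layer.register_forward_hook(self._store)

    @classmethod
    def from_key(cls, model: nn.Module, key: str) -> "ActivationHook":
        # Resolve the dotted name to a module, then reuse the main constructor.
        modules = dict(model.named_modules())
        if key not in modules:
            raise ValueError(f"no submodule named {key!r}")
        return cls(model, modules[key])

    def _store(self, module, inputs, output):
        self.activation = output.detach()

    def remove(self):
        self._handle.remove()


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2)).eval()
by_ref = ActivationHook(model, model[0])      # pass the layer itself
by_key = ActivationHook.from_key(model, "0")  # pass its qualified name

with torch.no_grad():
    model(torch.ones(1, 4))

shapes_match = by_ref.activation.shape == by_key.activation.shape == torch.Size([1, 3])
```

Both routes end up hooking the same module object, so the string interface becomes a thin convenience layer over the reference-based one.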

It is not a big deal. I would love to use your *Cams in my projects, but I can just yoink the code and change it

Thank you again!

Cheers,

Francesco

frgfm commented 3 years ago

Being a Python perfectionist, I obviously agree with you that passing a ref is better than a string :ok_hand:

However here there are a few things to consider:

I chose the string interface initially because it matches the way state_dict keys are created for a given Module. In both your examples, you would get the exact same result with a string argument by passing "layers.idx" (with idx written as the integer index) and "encoder.blocks.1.head.projection" respectively. So I'm really curious about why the string argument is a limitation usage-wise; I'd really like to understand so that I can improve the library's support :pray:
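To make the equivalence concrete, here is a small sketch on a hypothetical nested model (the `Model`, `Encoder`, `Block` and `Head` classes are made up to mirror the path from the discussion). It shows that the dotted name from `named_modules()` and the attribute reference point to the same object, and that same-named layers in different blocks still get distinct qualified names.

```python
import torch.nn as nn


# Hypothetical nested model, only here to mirror the paths discussed above.
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.projection = nn.Linear(8, 4)


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = Head()


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block(), Block()])


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()


model = Model()

# named_modules() yields (qualified_name, module) pairs; names are dotted,
# and ModuleList indices appear as plain integers.
name_to_module = dict(model.named_modules())

by_string = name_to_module["encoder.blocks.1.head.projection"]
by_reference = model.encoder.blocks[1].head.projection
same_object = by_string is by_reference

# Parameter keys in state_dict() share the same dotted prefix.
has_matching_key = "encoder.blocks.1.head.projection.weight" in model.state_dict()

# Both blocks own a layer called "projection", but the qualified names differ.
no_collision = name_to_module["encoder.blocks.0.head.projection"] is not by_reference
```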

Anyway, if this is something better for the community, I'm happy to add support for this!

FrancescoSaverioZuppichini commented 3 years ago

Hey @frgfm , thank you for your reply and the amazing discussion :) I love your enthusiasm!

But I was thinking: if you pass a string, then you need to trace the module to find it, so wouldn't it be more convenient to pass a reference?

Imagine the following situation: I have a model and I know which layer I want to target. Why should I waste computation doing one forward pass? I can just pass my reference and boom, we hook the hook :) What do you think?

Thank you again!

Francesco :)

frgfm commented 3 years ago

My apologies Francesco, I've been quite busy with other projects lately!

I'm not sure I follow the reasoning about wasted computation? :sweat_smile: Currently, there are two cases:

  1. A value for target_layer is provided: the layer is retrieved in O(1) time since it's just a lookup in the name-to-module dictionary
  2. No value is provided: one forward is "wasted" to retrieve the best candidate layer

Now if reference passing is added:

  1. Value provided: the only difference is that the dictionary doesn't have to be built, but that is a single trivial call to torch.nn.Module.named_modules, so I would argue it really makes no difference
  2. No value is provided: same process as before

In my opinion, when a reference is provided, it would even be somewhat important to do a dummy inference to ensure the module indeed belongs to this model :sweat_smile: I guess a warning could be thrown when trying to compute the CAM if the reference was incorrect (but with the string arg, that would have been caught in the constructor itself)
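As an alternative to a dummy inference (a sketch, not what the library does), membership could be checked at construction time by comparing identities against `model.modules()`. The `belongs_to` function below is a hypothetical helper name. Note that it only confirms the layer is registered under the model; unlike a forward pass, it cannot tell whether the layer is actually used during inference.

```python
import torch.nn as nn


def belongs_to(model: nn.Module, layer: nn.Module) -> bool:
    """Return True if `layer` is `model` itself or one of its submodules."""
    # Compare by identity: a separate layer with the same shape must not match.
    return any(module is layer for module in model.modules())


model_a = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
model_b = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

ok = belongs_to(model_a, model_a[0])       # layer taken from the right model
mistake = belongs_to(model_a, model_b[0])  # layer taken from another model
```

A constructor accepting a reference could raise a `ValueError` when this check fails, which would surface the mistake as early as the string interface does.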

My best suggestion would be then:

What do you think? :)

FrancescoSaverioZuppichini commented 2 years ago

Hey @frgfm, what's up! PyTorch recently released a new way to extract features. I was thinking you may find it interesting :)

frgfm commented 2 years ago

Hi @FrancescoSaverioZuppichini :wave:

Sorry about the late reply! Yup, I saw that in the previous release notes, and wanted to try it out. I'll see if that can help this project :+1: