FrancescoSaverioZuppichini closed this issue 2 years ago.
Hi @FrancescoSaverioZuppichini :wave:
Thanks for the suggestion! I thought about it a while ago, but couldn't really see cases where it's better on the user-side than passing the string :thinking:
Currently:
```python
# Define your model
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval()
# Set your CAM extractor
from torchcam.cams import SmoothGradCAMpp
cam_extractor = SmoothGradCAMpp(model, 'layer4')
```
And with your suggestion:
```python
# Define your model
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval()
# Set your CAM extractor
from torchcam.cams import SmoothGradCAMpp
cam_extractor = SmoothGradCAMpp(model, model.layer4)
```
I agree that passing a reference, rather than a string that we resolve into a reference later on, is cleaner from a programming point of view, but I can see at least one drawback: if you pass by mistake a reference to a layer of another model, the CAM won't work. This also introduces some redundancy in the arguments: on top of the target layer, the whole model still has to be passed (to clear gradients, and for the forward passes that some CAM methods require).
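The "reference to another model" drawback can be caught in the constructor without any forward pass. Below is a minimal sketch, assuming a constructor that accepts either a dotted key or a module reference: it only walks `model.named_modules()` (a real `torch.nn.Module` method that yields `(name, module)` pairs) and compares references by identity. `resolve_target_layer` is a hypothetical helper, not part of torchcam's API.

```python
def resolve_target_layer(model, target):
    """Return the dotted name of `target` inside `model`.

    `target` may be a dotted key (e.g. "layer4") or a module reference
    (e.g. model.layer4). Only `model.named_modules()` is used, so no
    forward pass is needed and no torch import is required here.
    """
    for name, module in model.named_modules():
        if isinstance(target, str):
            # String key: compare against the registered dotted names
            if name == target:
                return name
        elif module is target:
            # Reference: identity check (same object, not merely an equal one)
            return name
    # Neither the key nor the reference belongs to this model
    raise ValueError(f"Target layer {target!r} does not belong to the given model")
```

With the ResNet-18 above, both `resolve_target_layer(model, "layer4")` and `resolve_target_layer(model, model.layer4)` should return `"layer4"`, while the `layer4` of a second `resnet18()` instance would raise `ValueError`.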
Could you tell me in which cases the string argument is not enough on your end please? :pray:
Hi @frgfm,
Thank you for your reply. I hope you are doing great ;) IMHO, passing a reference is always better than passing a string.
First of all, if I pass a string, you'll need to trace the model to find out where the layer is. Secondly, maybe I have no idea what my "key" is. What if I want to reference a layer in a list of layers by index? I cannot do that with a string, whereas by reference it is as easy as `model.layers[idx]`. Or what if I have a very nested model, something like `model.encoder.blocks[1].head.projection`? By reference it is dead easy to pass what you need, but by string? And what about collisions? If I have two layers with the same name, as is usually the case in practice, the current implementation will not be able to tell which one is the correct one.
You can always support string keys with something like `.from_key(model, key: str)`.
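The `.from_key` suggestion could be sketched as follows, assuming the main constructor takes a module reference and an alternative classmethod constructor handles string keys. `CAMExtractor` is a hypothetical stand-in, not a torchcam class, and the key lookup relies only on `named_modules()`.

```python
class CAMExtractor:
    """Hypothetical CAM extractor whose constructor takes a layer reference."""

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer  # a module object, e.g. model.layer4

    @classmethod
    def from_key(cls, model, key: str):
        """Alternative constructor: resolve a dotted key such as "layer4" first."""
        modules = dict(model.named_modules())  # maps dotted name -> module
        if key not in modules:
            raise KeyError(f"No submodule named {key!r} in the model")
        return cls(model, modules[key])
```

With a real model, `CAMExtractor(model, model.layer4)` and `CAMExtractor.from_key(model, "layer4")` would end up holding the same `target_layer` object, so the string interface becomes a thin convenience layer on top of the reference one.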
It is not a big deal: I would love to use your `*Cam`s in my projects, but I can just yoink the code and change it.
Thank you again!
Cheers,
Francesco
Being a Python perfectionist, I obviously agree with you that passing a ref is better than a string :ok_hand:
However here there are a few things to consider:
- Since #21, you can use `subpart.1` to point to the second item of the `nn.Sequential` called `subpart` in your model, for instance :smile:
- Name collisions cannot happen within a `torch.nn.Module`, as the second time you assign a value to an attribute will override the reference.

I chose the string interface initially because it matches the way `state_dict` keys are created for a given `Module`. In both your examples, you would get the exact same result with a string argument by passing `"layers.idx"` (with `idx` being the actual index) and `"encoder.blocks.1.head.projection"` respectively. So I'm really curious about why the string argument is a limitation usage-wise; I'd really like to understand so that I can improve the library support :pray:
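To make the equivalence concrete, here is a sketch of how a dotted key maps back onto the attribute path. Each component is fetched with `getattr`, which is how `torch.nn.Module` exposes registered children, including the `"0"`, `"1"`, etc. names that `nn.Sequential` and `nn.ModuleList` give their items; purely numeric components fall back to integer indexing for plain Python sequences (an addition for illustration). `get_module_by_key` is a hypothetical helper; newer PyTorch releases also provide `torch.nn.Module.get_submodule` for a similar purpose.

```python
def get_module_by_key(model, key):
    """Walk a dotted key such as "encoder.blocks.1.head.projection" from `model`."""
    module = model
    for part in key.split("."):
        if hasattr(module, part):
            # Registered children (including "0", "1", ... in nn.Sequential) are attributes
            module = getattr(module, part)
        elif part.isdigit():
            # Fallback for plain Python sequences such as lists
            module = module[int(part)]
        else:
            raise AttributeError(
                f"{type(module).__name__} has no submodule {part!r} (full key: {key!r})"
            )
    return module
```

For the nested example, `get_module_by_key(model, "encoder.blocks.1.head.projection")` returns the same object as `model.encoder.blocks[1].head.projection`, and no forward pass is involved.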
Anyway, if this is something better for the community, I'm happy to add support for this!
Hey @frgfm, thank you for your reply and the amazing discussion :) I love your enthusiasm!
But I was thinking: if you pass a string, you need to trace the module to find it, so wouldn't it be more convenient to pass a reference?
Imagine the following situation: I have a model and I know which layer I want to index, so why should I waste computation doing a forward pass? I can just pass my reference and boom, we hook the hook :) What do you think?
Thank you again!
Francesco :)
My apologies, Francesco, I've been quite busy with other projects lately!
I'm not sure I follow the reasoning about wasted computation? :sweat_smile: Currently, there are two cases:
Now if reference passing is added:
torch.nn.Module.named_modules
So I would argue it really makes no difference. In my opinion, when a reference is provided, it would even be somewhat important to do a dummy inference to ensure the module indeed belongs to this model :sweat_smile: I guess a warning could be thrown when trying to compute the CAM if the reference was incorrect (but that would have been caught in the very constructor with the string argument).
My best suggestion would then be:
What do you think? :)
Hey @frgfm, what's up! PyTorch recently released a new way to extract features. I thought you might find it interesting :)
Hi @FrancescoSaverioZuppichini :wave:
Sorry about the late reply! Yup, I saw that in the previous release notes, and wanted to try it out. I'll see if that can help this project :+1:
🚀 Feature
Hello there,
I would like to pass an `nn.Module` instance to any `*Cam` constructor, but currently I can only use a `str`.
Motivation
Well, if your model has a good design, it is much easier to pass a reference than the key.
Thanks!
Francesco