Hello @caraevangeline, since you seem to be using SPP modules instead of SPPF, I would suggest simply switching to a YOLO model that uses SPPF to get results fastest: I have already implemented relevance propagation for that module class, and the two modules function similarly. However, since I find your question very useful, I will elaborate on how one should go about incorporating custom modules into the relevance propagation framework.
I have included two ways of making custom modules "acceptable" to the `InnvestigateModel` used in the `explain.py` script:
A. Implement the method `propagate` in your custom module
The `propagate` method should take two arguments other than `self`:

- `inverter`: an object of the class `Inverter`, which is responsible for the backward propagation of relevance through already registered modules. You can call the inverter to propagate relevance through basic submodules that may be contained (e.g. `Conv2d` layers); for example, `inverter(self.conv, relevance)` would propagate relevance through a contained `conv` sub-layer.
- `relevance`: a special `LayerRelevance` tensor object that contains the upper-layer relevance.

If you go about it this way, you might need to add some forward hooks to your module for caching any needed intermediate results. The output of the `propagate` function should be the redistributed relevance going towards the lower layer. A rough sketch follows below.
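As a minimal, hypothetical illustration of option A: the module name `MyBlock`, its sub-layers, and the exact `Inverter` call pattern below are assumptions on my part, inferred from the `prop_*` functions in `lrp/common.py`, so treat this as a sketch rather than a verified implementation.

```python
import torch.nn as nn

class MyBlock(nn.Module):
    # Hypothetical custom module: two stacked convolutions.
    def __init__(self, c1, c2):
        super().__init__()
        self.cv1 = nn.Conv2d(c1, c2, 1)
        self.cv2 = nn.Conv2d(c2, c2, 3, padding=1)

    def forward(self, x):
        return self.cv2(self.cv1(x))

    def propagate(self, inverter, relevance):
        # Walk the block backwards, letting the inverter redistribute
        # relevance through each registered sub-layer in turn.
        relevance = inverter(self.cv2, relevance)  # relevance at cv2 input
        relevance = inverter(self.cv1, relevance)  # relevance at block input
        return relevance
```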
B. Provide forward hooks and inversion functions directly to `InnvestigateModel`

This is the easiest way to make a custom module explainable by the `InnvestigateModel` when de-serializing a pre-trained model. One simply has to supply the `InnvestigateModel` with dictionaries mapping each custom module type to a forward hook (optional) and an inverse function that serves the same purpose as the `propagate` method discussed previously. For example, in order to add some module named `customMod`, you simply need to add the necessary entries to the `fwd_hooks` and `inv_funcs` dictionaries as follows:
```python
inn_model = InnvestigateModel(model, ...
                              fwd_hooks={ customMod : customMod_fwd_hook, ... },
                              inv_funcs={ customMod : customMod_inv_func, ... },
                              ... )
```
where `customMod_fwd_hook` and `customMod_inv_func` are callables that perform the desired function (a hypothetical sketch of what they might look like is given below).
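To make that concrete, here is a rough, hypothetical sketch of such callables. The `(inverter, mod, relevance)` argument order for the inverse function is inferred from the `prop_*` examples further down this thread, and the `cv` sub-layer of `customMod` is made up purely for illustration:

```python
def customMod_fwd_hook(module, inputs, output):
    # Standard PyTorch forward-hook signature: cache any intermediate
    # result that the inverse function will need during backward
    # relevance propagation (here, just the input shape as an example).
    module.in_shape = inputs[0].shape

def customMod_inv_func(*args):
    inverter, mod, relevance = args
    # Redistribute the upper-layer relevance through the module's
    # sub-layer(s) by calling back into the inverter.
    return inverter(mod.cv, relevance)
```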
For an example of how to implement your own propagation/inverse functions, see the `lrp/common.py` file included in the repo. More specifically, since you want to add some version of an SPP module, it might be useful for you to examine the `prop_SPPF` function.
I hope you find this useful. If you have any more trouble, I can try patching in `models.common.SPP` myself.
@akarasman Thanks much for the details, I will try it out and if I have any problems I will get back to you
@akarasman I have tried the things you suggested and was successful in getting it to run. I also faced an error for another module:

```
NotImplementedError: Relevance propagation not implemented for layer type <class 'models.common.Focus'>
```

Since the `Focus` layer involves `Conv()`, I added a function `prop_Focus()` similar to `prop_Conv()` and got the below results with my custom model.

Whereas if I use the official `yolov5s.pt`, I get the below results.

Is something going wrong in the way I have written these `prop_Focus()` and `prop_SPP()` functions, or is it just the fact that the heatmap is over-concentrated in that yellow region?
For your reference, the changes I have made:

`lrp/common.py`:
```python
def prop_SPP(*args):
    inverter, mod, relevance = args
    #relevance = torch.cat([r.view(mod.m.out_shape) for r in relevance ], dim=0)
    bs = relevance.size(0)
    relevance = inverter(mod.cv2, relevance)
    msg = relevance.scatter(which=-1)

    ch = msg.size(1) // 4
    r3 = msg[:, 3*ch:4*ch, ...]
    r2 = msg[:, 2*ch:3*ch, ...] + r3
    r1 = msg[:, ch:2*ch, ...] + r2
    rx = msg[:, :ch, ...] + r1

    msg = inverter(mod.cv1, rx)
    relevance.gather([(-1, msg)])

    return relevance
```
```python
def prop_Focus(*args):
    inverter, mod, relevance = args
    return inverter(mod.conv, relevance)
```
`explain.py`:
```python
inn_model = InnvestigateModel(model,
                              contrastive=contrastive,
                              power=power,
                              eps=1e-09,
                              entry_point=model.model.model,
                              pass_not_implemented=True,
                              fwd_hooks={ Concat : Concat_fwd_hook,
                                          SPPF : SPPF_fwd_hook, },
                              inv_funcs={ C3 : prop_C3,
                                          Conv : prop_Conv,
                                          Detect : prop_Detect,
                                          Bottleneck : prop_Bottleneck,
                                          Concat : prop_Concat,
                                          SPPF : prop_SPPF,
                                          SPP : prop_SPP,
                                          Focus : prop_Focus },
                              device=device)
```
Hello, sorry I could not reply yesterday. Judging from your output and looking at your propagation function implementations, I would say the error is in the way you go about backward relevance propagation. You need to write your propagation functions in a way that is sensitive to the calculation performed by the module you wish to explain. If you take a look at `models/common.py`, you can see that the modules you wish to explain perform the following operations:
```python
class Focus(nn.Module):
    ...
    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
        # return self.conv(self.contract(x))

class SPP(nn.Module):
    ...
    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
```
My best advice would be to look carefully at what operations are performed and to consider how relevance should be redistributed, keeping in mind the formal definition of LRP.
Good luck!
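To illustrate the idea for `Focus`, here is a rough, untested sketch of how `prop_Focus` could mirror the forward pass: first invert the convolution, then undo the space-to-depth concatenation by scattering each channel chunk back to the pixel positions it was sliced from. The `LayerRelevance` `scatter`/`gather` semantics are assumed to behave as in the `prop_SPP` attempt above, so treat this as a starting point rather than a verified implementation:

```python
import torch

def prop_Focus(*args):
    inverter, mod, relevance = args

    # Relevance at the output of the convolution, shaped like the
    # concatenated slices, i.e. (b, 4c, w/2, h/2).
    relevance = inverter(mod.conv, relevance)
    msg = relevance.scatter(which=-1)

    b, c4, h, w = msg.shape
    c = c4 // 4

    # Rebuild a (b, c, 2h, 2w) relevance map by placing each channel chunk
    # back at the spatial offsets used in Focus.forward.
    out = torch.zeros(b, c, 2 * h, 2 * w, device=msg.device, dtype=msg.dtype)
    out[..., ::2, ::2] = msg[:, 0 * c:1 * c, ...]
    out[..., 1::2, ::2] = msg[:, 1 * c:2 * c, ...]
    out[..., ::2, 1::2] = msg[:, 2 * c:3 * c, ...]
    out[..., 1::2, 1::2] = msg[:, 3 * c:4 * c, ...]

    relevance.gather([(-1, out)])
    return relevance
```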
@akarasman Thank you for all the explanations and prompt reply!
Hi @akarasman
Can I know how to implement layer relevance if I have to add a new module? I face the below-mentioned error when running with my custom model trained with the yolov5s architecture.
Thanks in advance!