akarasman / yolo-heatmaps

A utility for generating heatmaps for YOLOv8 using Layer-wise Relevance Propagation (LRP/CRP).

NotImplementedError: Relevance propagation not implemented for layer type <class 'models.common.SPP'> #4


caraevangeline commented 2 years ago

Hi @akarasman

Could you explain how to implement relevance propagation if I have to add a new module? I get the error below when running with my custom model, trained on the yolov5s architecture:

NotImplementedError: Relevance propagation not implemented for layer type <class 'models.common.SPP'>

Thanks in advance!

akarasman commented 2 years ago

Hello @caraevangeline, since you seem to be using SPP modules instead of SPPF, the quickest fix is to switch to a YOLO model that uses SPPF: I have already implemented relevance propagation for that class, and the two modules function similarly. However, since I find your question very useful, I will elaborate on how to incorporate custom modules into the relevance propagation framework.

There are two ways to make a custom module "acceptable" to the InnvestigateModel used in the explain.py script:

A. Implement the method propagate in your custom module

The propagate method should take two arguments other than self :

  1. inverter : An object of the Inverter class, responsible for the backward propagation of relevance through already registered modules. You can call the inverter to propagate relevance through basic submodules your module contains (e.g. Conv2d layers); for example, to propagate relevance through a contained Conv submodule cv, call inverter(mod.cv, relevance).

  2. relevance : A special LayerRelevance tensor object that contains the upper-layer relevance.

If you go about it this way, you may need to register forward hooks on your module to cache any intermediate results the backward pass needs.

The propagate method should return the redistributed relevance flowing towards the lower layer; a sketch follows below.
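As a minimal sketch (the module and all names other than propagate are illustrative, and the Inverter/LayerRelevance interfaces are assumed to behave as described above):

import torch.nn as nn

class CustomMod(nn.Module):
    def __init__(self, c1, c2):
        super().__init__()
        self.cv = nn.Conv2d(c1, c2, 1)

    def forward(self, x):
        return self.cv(x)

    def propagate(self, inverter, relevance):
        # Redistribute the upper-layer relevance through the contained
        # Conv2d by delegating to the inverter, which already knows how
        # to handle registered layer types.
        return inverter(self.cv, relevance)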

B. Provide forward hooks and inversion functions directly to InnvestigateModel

This is the easiest way to make a custom module explainable by the InnvestigateModel when de-serializing a pre-trained model. Simply supply the InnvestigateModel with dictionaries mapping each custom module to a forward hook (optional) and to an inverse function that serves the same purpose as the propagate method discussed previously. For example, to add some module named customMod, you add the necessary entries to the fwd_hooks and inv_funcs dictionaries as follows:

inn_model = InnvestigateModel(model, ...
                              fwd_hooks={ customMod :  customMod_fwd_hook, ... },
                              inv_funcs={ customMod :  customMod_inv_func, ... },
                              ... )

where customMod_fwd_hook and customMod_inv_func are callables that perform the desired function; a sketch of such a pair follows below.
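For instance (a rough sketch: the hook follows the standard PyTorch forward-hook signature, the inverse function follows the (inverter, mod, relevance) contract used in lrp/common.py, and mod.cv is a hypothetical contained Conv submodule):

def customMod_fwd_hook(module, inputs, output):
    # Cache whatever the inverse pass will need; here, the input tensor
    # and the output shape (a common pattern, not a prescribed one).
    module.in_tensor = inputs[0]
    module.out_shape = output.shape

def customMod_inv_func(inverter, mod, relevance):
    # Same contract as a propagate method: redistribute the upper-layer
    # relevance to the module's input, delegating contained submodules
    # to the inverter.
    return inverter(mod.cv, relevance)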

For an example of how to implement your own propagation/inverse functions, see the lrp/common.py file included in the repo. More specifically, since you want to add a version of the SPP module, it may be useful to examine the prop_SPPF function.

I hope you find this useful. If you have any more trouble I can try patching in models.common.SPP myself.

caraevangeline commented 2 years ago

@akarasman Thanks a lot for the details; I will try it out, and if I run into any problems I will get back to you.

caraevangeline commented 2 years ago

@akarasman I have tried what you suggested and got it running. I then hit an error for another module:

NotImplementedError: Relevance propagation not implemented for layer type <class 'models.common.Focus'>

Since the Focus layer involves Conv(), I added a prop_Focus() function similar to prop_Conv() and got the results below with my custom model:

[image: explain422_215_471_322]

Whereas with the official yolov5s.pt, I get the results below:

[image: explain422_214_468_322]

Is something going wrong in the way I have written these prop_Focus() and prop_SPP() functions, or is it just that the heatmap is over-concentrated in that yellow region?

For your reference, the changes I have made:

/lrp/common.py

def prop_SPP(*args):

    inverter, mod, relevance = args

    #relevance = torch.cat([r.view(mod.m.out_shape) for r in relevance ], dim=0)
    bs = relevance.size(0)
    relevance = inverter(mod.cv2, relevance)   # invert the output Conv
    msg = relevance.scatter(which=-1)
    ch = msg.size(1) // 4                      # four concatenated branches: x plus three max-pool outputs

    r3 = msg[:, 3*ch:4*ch, ...]
    r2 = msg[:, 2*ch:3*ch, ...] + r3
    r1 = msg[:, ch:2*ch, ...] + r2
    rx = msg[:, :ch, ...] + r1                 # accumulate the relevance of all branches

    msg = inverter(mod.cv1, rx)                # invert the input Conv
    relevance.gather([(-1, msg)])

    return relevance

def prop_Focus(*args):

    inverter, mod, relevance = args
    return inverter(mod.conv, relevance)

explain.py

inn_model = InnvestigateModel(model,
                              contrastive=contrastive,
                              power=power,
                              eps=1e-09,
                              entry_point=model.model.model,
                              pass_not_implemented=True,
                              fwd_hooks={ Concat : Concat_fwd_hook,
                                          SPPF : SPPF_fwd_hook },
                              inv_funcs={ C3 : prop_C3,
                                          Conv : prop_Conv,
                                          Detect : prop_Detect,
                                          Bottleneck : prop_Bottleneck,
                                          Concat : prop_Concat,
                                          SPPF : prop_SPPF,
                                          SPP : prop_SPP,
                                          Focus : prop_Focus },
                              device=device)
akarasman commented 2 years ago

Hello, sorry I could not reply yesterday. Judging from your output, and looking at your propagation function implementations, I would say the error is in how you go about backward relevance propagation. You need to write your propagation functions so that they are sensitive to the calculation performed by the module you wish to explain. If you take a look at common.py, you can see that the modules you wish to explain perform the following operations:

class Focus(nn.Module):
    ...

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
        # return self.conv(self.contract(x))
class SPP(nn.Module):
    ...

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))

My best advice would be to look carefully at what operations are performed and to consider how relevance should be redistributed through each of them, keeping in mind the formal definition of LRP; a sketch for the Focus slicing follows below.
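For instance, your prop_Focus inverts only the Conv and skips the slicing step. Since each of the four concatenated channel groups comes from one stride-2 sub-grid of the input, relevance could be routed back to exactly those pixels. A rough, untested sketch, reusing the LayerRelevance scatter/gather calls from your prop_SPP above:

def prop_Focus(*args):

    inverter, mod, relevance = args
    relevance = inverter(mod.conv, relevance)  # undo the Conv: (b, 4c, h/2, w/2)
    msg = relevance.scatter(which=-1)
    c = msg.size(1) // 4
    b, _, h2, w2 = msg.shape
    out = msg.new_zeros(b, c, 2 * h2, 2 * w2)
    # Route each channel group's relevance back to the stride-2 sub-grid
    # it was sliced from in the forward pass.
    out[..., ::2, ::2] = msg[:, :c]
    out[..., 1::2, ::2] = msg[:, c:2*c]
    out[..., ::2, 1::2] = msg[:, 2*c:3*c]
    out[..., 1::2, 1::2] = msg[:, 3*c:]
    relevance.gather([(-1, out)])

    return relevance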

Good luck!

caraevangeline commented 2 years ago


@akarasman Thank you for all the explanations and prompt reply!