jacobgil / pytorch-grad-cam

Advanced AI Explainability for computer vision. Support for CNNs, Vision Transformers, Classification, Object detection, Segmentation, Image similarity and more.
https://jacobgil.github.io/pytorch-gradcam-book
MIT License

pytorch grad cam for RGBD images? #406

Open shilpamatne opened 1 year ago

shilpamatne commented 1 year ago

Can I use this package with RGBD images? My custom model is a two-stream network that takes the RGB image in one stream and the depth image in the other. How can I specify input_tensor in such a case? Kindly advise.

jacobgil commented 1 year ago

If you want to create the CAM only for the RGB part, I would create a model wrapper that fixes the depth (D) input, so that it behaves like a regular model with a single RGB input.

Something like:


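import torch

class RGBDWrapper(torch.nn.Module):
    # Fixes the depth input so the wrapped model takes only an RGB tensor.
    # Subclassing nn.Module gives the wrapper methods such as .eval(),
    # which the CAM classes call on the model they are given.
    def __init__(self, model, depth_tensor):
        super().__init__()
        self.model = model
        self.depth_tensor = depth_tensor

    def forward(self, rgb_tensor):
        return self.model(rgb_tensor, self.depth_tensor)

wrapped_model = RGBDWrapper(model, depth_tensor)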

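From the point of view of the CAM method, this wrapper is then just a model that accepts a single RGB tensor, so you pass the RGB image as input_tensor and pick target_layers from the RGB stream of the original model.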
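To check the idea end to end without the library, here is a minimal sketch using only PyTorch. It assumes a hypothetical toy two-stream model (`TwoStreamNet`, with layers `rgb_conv`, `depth_conv` and `head`, all invented for illustration) and a small hand-rolled Grad-CAM (`grad_cam`, also hypothetical, not the pytorch-grad-cam implementation). The depth tensor is fixed inside the wrapper, the RGB tensor is the only input, and the target layer sits in the RGB stream. With the library itself, the same wrapper would instead be passed as `GradCAM(model=wrapped_model, target_layers=[...])`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoStreamNet(nn.Module):
    # Hypothetical toy RGB-D classifier: one conv stream per modality,
    # pooled features concatenated before a linear head.
    def __init__(self, num_classes=3):
        super().__init__()
        self.rgb_conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.depth_conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.head = nn.Linear(16, num_classes)

    def forward(self, rgb, depth):
        r = F.relu(self.rgb_conv(rgb)).mean(dim=(2, 3))      # (B, 8)
        d = F.relu(self.depth_conv(depth)).mean(dim=(2, 3))  # (B, 8)
        return self.head(torch.cat([r, d], dim=1))


class RGBDWrapper(nn.Module):
    # Fixes the depth input so the model looks like an RGB-only model.
    def __init__(self, model, depth_tensor):
        super().__init__()
        self.model = model
        self.depth_tensor = depth_tensor

    def forward(self, rgb_tensor):
        return self.model(rgb_tensor, self.depth_tensor)


def grad_cam(model, target_layer, input_tensor, class_idx):
    # Minimal Grad-CAM sketch: capture the target layer's activations and
    # their gradients, weight channels by mean gradient, sum, ReLU, normalize.
    store = {}

    def save_grad(grad):
        store["grad"] = grad

    def fwd_hook(module, inputs, output):
        store["act"] = output
        output.register_hook(save_grad)

    handle = target_layer.register_forward_hook(fwd_hook)
    try:
        logits = model(input_tensor)
        model.zero_grad()
        logits[:, class_idx].sum().backward()
    finally:
        handle.remove()

    act, grad = store["act"], store["grad"]             # (B, C, h, w)
    weights = grad.mean(dim=(2, 3), keepdim=True)       # per-channel weights
    cam = F.relu((weights * act).sum(dim=1))            # (B, h, w)
    cam = cam / (cam.amax(dim=(1, 2), keepdim=True) + 1e-8)
    return cam.detach()


torch.manual_seed(0)
net = TwoStreamNet().eval()
depth = torch.rand(1, 1, 16, 16)   # fixed depth image (same batch size as RGB)
rgb = torch.rand(1, 3, 16, 16)     # the only input the CAM method sees
wrapped = RGBDWrapper(net, depth)

cam = grad_cam(wrapped, net.rgb_conv, rgb, class_idx=0)
print(tuple(cam.shape))  # (1, 16, 16)
```

Because the wrapper only closes over `depth_tensor`, the depth image must have the same batch size as the RGB batch you explain; for a batch of several RGB images you would build the wrapper with the matching batch of depth images.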