pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Tutorial on explainability of Detectron2 Mask R-CNN #872

Open INF800 opened 2 years ago

INF800 commented 2 years ago

Hi, is there a tutorial on Detectron2 Mask R-CNN explainability (specifically for instance segmentation)? I read another issue mentioning that it is difficult. I would like to write a tutorial notebook for it, with guidance, if it is possible.

NarineK commented 2 years ago

@INF800, I'm not aware of any Detectron2 tutorial with Captum. The only tutorial we have on segmentation is the semantic segmentation one (https://captum.ai/tutorials/Segmentation_Interpret), which you have probably already seen. Sure, feel free to propose the tutorial; we will try to help. @vivekmig, @bilalsal, do you know if we have any examples of Detectron2 Mask R-CNN explainability for semantic segmentation? Thank you!
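
The general pattern in that tutorial is to wrap the segmentation model so that it returns one scalar per class and then attribute to those scalars. Roughly (a sketch along the lines of the notebook, not its exact code; `inp` below is just a placeholder input):

import torch
from torchvision.models.segmentation import fcn_resnet101
from captum.attr import LayerGradCam

fcn = fcn_resnet101(pretrained=True).eval()
inp = torch.randn(1, 3, 224, 224)  # placeholder for a normalized input image

# Per-pixel argmax from an unablated forward pass, kept fixed during attribution.
out_max = fcn(inp)['out'].argmax(dim=1, keepdim=True)

def agg_segmentation_wrapper(inp):
    model_out = fcn(inp)['out']  # (N, num_classes, H, W)
    # Keep only the scores of the originally predicted class at each pixel,
    # then sum them so the wrapper returns one scalar per class.
    selected_inds = torch.zeros_like(model_out).scatter_(1, out_max, 1)
    return (model_out * selected_inds).sum(dim=(2, 3))

lgc = LayerGradCam(agg_segmentation_wrapper, fcn.backbone.layer4)
# attributions = lgc.attribute(inp, target=6)  # e.g. class index 6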

INF800 commented 2 years ago

Sure! I've been meaning to do this but didn't because of the lack of guidance. Do you think we can extend the semantic segmentation explanation to instance segmentation?

NarineK commented 2 years ago

Yes, I think we should be able to take a similar approach for instance segmentation as well. If the classifier predicts the instances, we should be able to attribute to individual instances. cc: @vivekmig
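
For instance segmentation the attribution target would be the instance index rather than the class index. As a rough illustration of the wrapper idea (using torchvision's Mask R-CNN rather than Detectron2, and a single-pass method like Saliency, since the set of detections can change when the input is perturbed):

import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from captum.attr import Saliency

model = maskrcnn_resnet50_fpn(pretrained=True).eval()

def instance_mask_wrapper(inp):
    # inp: (1, 3, H, W); in eval mode the model returns one dict per image
    # with per-instance 'masks' of shape (num_instances, 1, H, W).
    out = model([im for im in inp])[0]
    # One scalar per predicted instance: the summed soft mask values.
    return out['masks'].sum(dim=(1, 2, 3)).unsqueeze(0)  # (1, num_instances)

image = torch.rand(1, 3, 480, 640)  # placeholder input in [0, 1]
sal = Saliency(instance_mask_wrapper)
# Attribute to instance 0 (assumes at least one detection on this input):
# attributions = sal.attribute(image, target=0)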

INF800 commented 2 years ago

Can you briefly outline the steps I need to follow to pull this off (specific to Detectron2), along with the best practices?

dorbittonn commented 4 months ago

I would like to join the request above: how can we attribute to an individual instance? For example, with a Mask R-CNN architecture, the instance segmentation head takes as input only the backbone's embeddings after RoI pooling. The learnable segmentation mask is (28, 28), and I want to modify the wrapper from the semantic segmentation tutorial.

I thought of something like the following, but I don't really see how to attribute to only a part of the image, because there may be parts of the image outside the bbox that affect the segmentation...

def agg_segmentation_wrapper(inp):
    # `fcn` is the Detectron2 GeneralizedRCNN (Mask R-CNN) model and `img` the
    # original image tensor; both are defined outside this wrapper.
    images = fcn.preprocess_image([{'image': inp[0], 'height': img.shape[1], 'width': img.shape[2]}])
    features = fcn.backbone(images.tensor)
    proposals, _ = fcn.proposal_generator(images, features, None)
    results, _ = fcn.roi_heads(images, features, proposals, None)

    # choose a specific instance
    instance_i = 0
    model_out = results[0].pred_masks[instance_i]

    # Creates a binary matrix with 1 for the original argmax class for each pixel
    # and 0 otherwise. Note that this may change when the input is ablated,
    # so we use the original argmax predicted above, out_max (precomputed from
    # an unablated forward pass, as in the semantic segmentation tutorial).
    selected_inds = torch.zeros_like(model_out[0:1]).scatter_(1, out_max, 1)
    return (model_out * selected_inds).sum(dim=(1, 2)).reshape(1, -1)
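
If a wrapper along these lines works, one way to use it (just a sketch; `fcn`, `img`, `inp`, and `out_max` are the objects from the snippet above and are not redefined here) would be to attribute over the whole input image, so pixels outside the predicted box can still receive attribution if they influence the mask:

from captum.attr import Saliency

sal = Saliency(agg_segmentation_wrapper)
# The wrapper returns a (1, 1) tensor for the chosen instance, so target=0.
# attributions = sal.attribute(inp, target=0)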

@NarineK @INF800