hila-chefer / Transformer-Explainability

[CVPR 2021] Official PyTorch implementation of Transformer Interpretability Beyond Attention Visualization, a novel method to visualize classifications made by Transformer-based networks.
MIT License

How to accelerate when inferring CAM in parallel #32

Closed Muyun99 closed 2 years ago

Muyun99 commented 2 years ago

Hello there, thanks for your nice work and code sharing.

I currently need to run inference in parallel on an image tensor of shape {B, C, H, W}.

Right now I compute the map for each category and each image individually, which is really slow.

Since I use the resulting maps to design a loss, the current speed is not acceptable.

I would like to know whether you have any suggestions for speeding up CAM inference so it runs in parallel.

Looking forward to your reply, thanks

hila-chefer commented 2 years ago

Hi @Muyun99, thanks for your interest in our work! Currently, we do not support batching for this code, but it can be implemented. To see an example of how batching can be implemented for gradient propagation, please see this issue from our second paper, where I added support for batches to CLIP explainability.
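For reference, here is a minimal sketch of what batched gradient-weighted attention rollout could look like. It is illustrative only: `model.blocks`, `attn.attn_probs`, and the direct gradient access are assumptions rather than this repository's actual API, and the attention tensors are assumed to have had `retain_grad()` called during the forward pass.

```python
import torch

# Illustrative sketch only (not this repository's API): batched
# gradient-weighted attention rollout. Assumes a ViT-style `model` whose
# blocks expose their attention probabilities as `blk.attn.attn_probs`,
# and that `attn_probs.retain_grad()` was called in the forward pass so
# the gradients below are populated by `backward()`.
def compute_batch_relevance(model, images, target_classes):
    B = images.shape[0]
    logits = model(images)                        # (B, num_classes)

    # One-hot over the per-sample target classes: a single backward pass
    # then yields per-sample gradients on every attention map.
    one_hot = torch.zeros_like(logits)
    one_hot[torch.arange(B), target_classes] = 1.0
    model.zero_grad()
    logits.backward(gradient=one_hot, retain_graph=True)

    num_tokens = model.blocks[0].attn.attn_probs.shape[-1]
    # Start the rollout from the identity matrix for every sample.
    R = torch.eye(num_tokens, device=images.device).expand(B, -1, -1).clone()

    for blk in model.blocks:
        attn = blk.attn.attn_probs                # (B, heads, tokens, tokens)
        grad = blk.attn.attn_probs.grad           # same shape
        cam = (grad * attn).clamp(min=0).mean(dim=1)  # average over heads
        R = R + torch.bmm(cam, R)                 # batched rollout update

    # Relevance of the CLS token w.r.t. each patch token, one map per sample.
    return R[:, 0, 1:]
```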

I hope this helps. Best, Hila.

Muyun99 commented 2 years ago

Hi @hila-chefer, thank you for your response.

I tried porting the solution from the new repo. It was very easy, thanks for your awesome work.

However, I found that the two solutions produce CAMs (class activation maps) of different quality.

Details

I ran the code on the VOC2012Aug dataset with ViT-Base to validate the quality of the CAMs.

The solution from Transformer-Explainability reaches an mIoU of 50.438, while Transformer-MM-Explainability only reaches 41.562.

Could you briefly explain the difference, and is there a way to bridge the performance gap?

Looking forward to your reply, thanks

hila-chefer commented 2 years ago

Hi @Muyun99, am I correct in assuming you are estimating semantic segmentation from the explainability maps? If so, could you please elaborate on how you produce the segmentation from the explainability maps (i.e., how you binarize them)? That could affect how well each method does for semantic segmentation.

Muyun99 commented 2 years ago

Hi @hila-chefer, yes, I want to use the explainability maps as ground truth for a semantic segmentation task.

I use this code to evaluate the quality of the CAMs: I infer a map for each class and convert it to a binary map with a threshold. Suitable thresholds are searched over an interval, which is set to [0, 0.6] in this code.

Then I compute the mIoU against the segmentation ground truth as a quantitative measure of the quality of the explainability maps.
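For what it's worth, a rough sketch of the binarization and IoU scoring described above (illustrative only; `cam` and `gt_mask` are placeholder names, not the linked evaluation code):

```python
import numpy as np

# Illustrative sketch only (not the linked evaluation code): binarize a
# per-class relevance map with a threshold searched over [0, 0.6] and score
# it against the binary ground-truth mask with IoU. In the real evaluation,
# a single threshold is typically chosen over the whole dataset before the
# per-class mIoU is computed.
def best_threshold_iou(cam, gt_mask, thresholds=np.arange(0.0, 0.6, 0.05)):
    """cam: (H, W) relevance map scaled to [0, 1]; gt_mask: (H, W) binary."""
    best_iou, best_t = 0.0, 0.0
    for t in thresholds:
        pred = cam >= t
        inter = np.logical_and(pred, gt_mask).sum()
        union = np.logical_or(pred, gt_mask).sum()
        iou = inter / union if union > 0 else 0.0
        if iou > best_iou:
            best_iou, best_t = iou, t
    return best_iou, best_t
```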

hila-chefer commented 2 years ago

I'm not familiar with that code, but we have experimented with both papers and found that the gap is usually not significant. That said, if you look at the ViT notebooks for both methods, our gradient-based method sometimes outputs smaller relevance values, which may require a different binarization method. My guess is that this is the source of the gap you see. In our ICCV paper the Otsu threshold works nicely for DETR; perhaps you could try that?
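For example, scikit-image ships Otsu's method, so the binarization could look roughly like this (a sketch, assuming `relevance` is a 2-D relevance map):

```python
import numpy as np
from skimage.filters import threshold_otsu

# Minimal sketch: binarize a relevance map with Otsu's method instead of a
# fixed, hand-searched threshold. `relevance` is assumed to be a 2-D numpy
# array of per-pixel relevance scores.
def otsu_binarize(relevance: np.ndarray) -> np.ndarray:
    return relevance >= threshold_otsu(relevance)
```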

hila-chefer commented 2 years ago

Closing due to inactivity, please re-open if necessary