jacobgil / pytorch-grad-cam

Advanced AI Explainability for computer vision. Support for CNNs, Vision Transformers, Classification, Object detection, Segmentation, Image similarity and more.
https://jacobgil.github.io/pytorch-gradcam-book
MIT License

Can GradCAM be used in Transformer? #403

Status: Open. WhatAShot opened this issue 1 year ago

WhatAShot commented 1 year ago

GradCAM was originally devised for CNNs. Can it also be used with Transformers, or with other architectures based on self-attention?

marios1861 commented 1 year ago

You can use GradCAM with transformers by reshaping the intermediate activations into CNN-like 4D tensors. I think every method implemented in the library takes a parameter called reshape_transform. You pass it a simple function that turns the batch+2D token tensor (batch, tokens, channels) into a batch+3D tensor (batch, channels, height, width). There is an example in the wiki, I think. I use this:

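def reshape_transform(tensor, height=14, width=14):
    # tensor: (batch, 1 + height * width, channels), e.g. (B, 197, 768) for ViT-B/16 at 224x224.
    # Drop the class token at index 0 and lay the patch tokens out on a height x width grid.
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))

    # Bring the channels to the first dimension, like in CNNs:
    # (batch, height, width, channels) -> (batch, channels, height, width).
    result = result.transpose(2, 3).transpose(1, 2)
    return result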

Edit: You can find this exact function in the wiki
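To show why the 4D layout matters, here is a minimal sketch (not the library's implementation) of a Grad-CAM pass on token activations. Grad-CAM weights each channel by its spatially averaged gradient and takes the ReLU of the weighted channel sum, so it needs the spatial height x width axes that reshape_transform recovers. ToyViT and grad_cam are hypothetical names made up for this illustration; the toy model has no positional embedding and only two encoder layers, and permute(0, 3, 1, 2) is used as an equivalent of the two transposes above.

```python
import torch
import torch.nn as nn


def reshape_transform(tensor, height=14, width=14):
    # Same idea as the function above: drop the class token, put the patch tokens
    # on a height x width grid, move channels to dim 1 -> (batch, channels, height, width).
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    return result.permute(0, 3, 1, 2)


class ToyViT(nn.Module):
    # Hypothetical stand-in for a ViT: 16x16 patch embedding, class token, two encoder
    # layers, linear head on the class token. No positional embedding, for brevity.
    def __init__(self, dim=32, patch=16, num_classes=10):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.blocks = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(dim, nhead=4, dim_feedforward=2 * dim, batch_first=True)
                for _ in range(2)
            ]
        )
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)  # (B, 196, dim) for 224x224 input
        cls = self.cls_token.expand(x.size(0), -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)  # (B, 197, dim)
        for block in self.blocks:
            tokens = block(tokens)
        return self.head(tokens[:, 0])  # classify from the class token only


def grad_cam(model, target_layer, x, class_idx):
    store = {}

    def save_activation(module, inputs, output):
        # Keep the layer output and register a hook that captures its gradient during backward.
        store["act"] = output.detach()
        output.register_hook(lambda grad: store.__setitem__("grad", grad.detach()))

    handle = target_layer.register_forward_hook(save_activation)
    try:
        logits = model(x)
        model.zero_grad()
        logits[:, class_idx].sum().backward()
    finally:
        handle.remove()

    act = reshape_transform(store["act"])    # (B, C, 14, 14)
    grad = reshape_transform(store["grad"])  # (B, C, 14, 14)
    weights = grad.mean(dim=(2, 3), keepdim=True)  # per-channel weight: spatially averaged gradient
    return torch.relu((weights * act).sum(dim=1))  # (B, 14, 14): ReLU of the weighted channel sum


torch.manual_seed(0)
model = ToyViT().eval()
cam = grad_cam(model, model.blocks[0], torch.randn(1, 3, 224, 224), class_idx=3)
print(cam.shape)
```

The sketch hooks the first block rather than the last one. At the output of the last block, every patch token would get a zero gradient, because the head only reads the class token after that point. That is presumably why, for timm ViTs, the library's README suggests model.blocks[-1].norm1 (inside the last block, before its attention mixes the tokens) as the target layer. With the library itself, the function above would be passed as GradCAM(model=model, target_layers=target_layers, reshape_transform=reshape_transform), and the library takes care of the hooks, upsampling and normalization.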

jacobgil commented 1 year ago

There are also many examples with different transformer variants here: https://jacobgil.github.io/pytorch-gradcam-book/HuggingFace.html
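Transformer variants differ in how many extra tokens precede the patch grid. Some have one class token, DeiT's distilled models add a distillation token, and some models have no class token at all. The grid may also not be 14x14. Below is a hedged sketch of a more general reshape_transform, assuming the activations have shape (batch, num_prefix_tokens + height * width, channels). make_reshape_transform and num_prefix_tokens are hypothetical names, not part of the library.

```python
import math

import torch


def make_reshape_transform(num_prefix_tokens=1, height=None, width=None):
    # Hypothetical factory: returns a reshape_transform for token activations of shape
    # (batch, num_prefix_tokens + height * width, channels).
    # num_prefix_tokens: tokens to drop before the patch grid (1 = class token,
    # 2 = class + distillation token, 0 = no extra tokens).
    # height/width: patch-grid size; if omitted, a square grid is inferred.
    def reshape_transform(tensor):
        patches = tensor[:, num_prefix_tokens:, :]
        n = patches.size(1)
        h, w = height, width
        if h is None or w is None:
            side = math.isqrt(n)
            if side * side != n:
                raise ValueError(f"{n} patch tokens do not form a square grid; pass height and width")
            h, w = side, side
        if h * w != n:
            raise ValueError(f"height * width = {h * w} does not match {n} patch tokens")
        # (batch, h, w, channels) -> (batch, channels, h, w), like a CNN feature map.
        return patches.reshape(tensor.size(0), h, w, tensor.size(2)).permute(0, 3, 1, 2)

    return reshape_transform


vit_transform = make_reshape_transform()  # class token + inferred 14x14 grid
print(vit_transform(torch.zeros(2, 197, 768)).shape)
```

Returning a closure keeps the one-argument signature the original function is called with (the tensor), while the token layout is fixed once per model.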