hila-chefer / Transformer-MM-Explainability

[ICCV 2021 Oral] Official PyTorch implementation of "Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers", a novel method to visualize any Transformer-based network, including examples for DETR and VQA.
MIT License

CLIP ViT-B/16 #8

Closed sanjayss34 closed 2 years ago

sanjayss34 commented 2 years ago

Hi Hila, thanks for your great work. I am trying to run the CLIP code from the notebook on the ViT-B/16 model, but I am getting attention maps that don't make sense (I'm not able to get results similar to those in the notebook). For the ViT-B/32 model I can reproduce the results, but for some reason the ViT-B/16 model causes an issue. Do you know why this is? The only things I needed to change in the code are:

hila-chefer commented 2 years ago

Hi @sanjayss34, thanks for your interest in our work! I haven’t tried using ViT-B/16 so I’m not sure. Could you maybe add a PR or attach your code somehow so I can look it over and see if I find something?

louisowen6 commented 2 years ago

Hi @hila-chefer, I'm also experimenting with the ViT-B/16 model and followed the exact same process as @sanjayss34. I'm not sure whether this is an issue or not, but my hypothesis is that the ViT-B/16 model attends more to local behavior since it has a smaller patch size than ViT-B/32.

Could you please help me confirm whether my hypothesis is correct? I've attached the code and some comparisons between ViT-B/16 and ViT-B/32, following this amazing notebook you've prepared.

Code:

import clip
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

# Add the link to ViT-B/16 to the _MODELS dictionary

clip.clip._MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}
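# (Illustrative note, not part of the original comment: once the ViT-B/16 entry is in
#  _MODELS, the model loads with the usual CLIP API, e.g.
#  model, preprocess = clip.load("ViT-B/16", device=device, jit=False),
#  mirroring how the notebook loads ViT-B/32.)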
# Change the reshape in interpret from (1, 1, 7, 7) to (1, 1, dim, dim) 
# where dim = int(image_relevance.numel() ** 0.5)

def interpret(image, text, model, device, index=None):
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()
    if index is None:
        index = np.argmax(logits_per_image.cpu().data.numpy(), axis=-1)
    one_hot = np.zeros((1, logits_per_image.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot.cuda() * logits_per_image)
    model.zero_grad()
    one_hot.backward(retain_graph=True)

    # Roll out gradient-weighted attention across the visual transformer blocks
    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
    R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
    for blk in image_attn_blocks:
        grad = blk.attn_grad
        cam = blk.attn_probs
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.clamp(min=0).mean(dim=0)
        R += torch.matmul(cam, R)
    R[0, 0] = 0
    image_relevance = R[0, 1:]

    # create heatmap from mask on image
    def show_cam_on_image(img, mask):
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
        return cam

    # Side length of the patch grid: 7 for ViT-B/32, 14 for ViT-B/16 at a 224x224 input
    reshape_size = int(image_relevance.numel() ** 0.5)
    image_relevance = image_relevance.reshape(1, 1, reshape_size, reshape_size)
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
    image_relevance = image_relevance.reshape(224, 224).data.cpu().numpy()
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)

    plt.imshow(vis)
    plt.show()

[Attached: side-by-side relevance maps comparing ViT-B/16 and ViT-B/32]
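As a side note (not from the original comment), a minimal sketch of why the reshape has to be dynamic, assuming a 224x224 input; patch_grid is a hypothetical helper, not part of the repo. ViT-B/32 yields a 7x7 patch grid (49 patch tokens plus CLS), while ViT-B/16 yields a 14x14 grid (196 patch tokens plus CLS), so R[0, 1:] has a different length for each model:

def patch_grid(image_size=224, patch_size=32):
    side = image_size // patch_size      # patches along one image side
    return side, side * side             # grid side length, patch-token count (excluding CLS)

for p in (32, 16):
    side, n = patch_grid(patch_size=p)
    # ViT-B/32 -> 7x7 grid, 49 patch tokens; ViT-B/16 -> 14x14 grid, 196 patch tokens
    print(f"ViT-B/{p}: {side}x{side} grid, {n} patch tokens, reshape to (1, 1, {side}, {side})")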

litingfeng commented 2 years ago

@louisowen6 I was wondering if this is because the objects are large, e.g., the highlighted parts in the left column are more than 32 pixels on a side. Maybe try some small objects?

hila-chefer commented 2 years ago

Hi @sanjayss34, @louisowen6, @litingfeng, thank you all for this very lively discussion. I've been looking into what you're seeing with ViT-B/16 and was able to reproduce it quite easily. This is an interesting point. Intuitively, you can think of our method as expanding context: the last attention layer is the most important one, and we use the previous layers to understand the meaning of each token in the last layer, since it has been contextualized. Examining some of the examples for ViT-B/16, I noticed that the correct areas are being highlighted but are overshadowed by other areas in the image. For example (image attached), the elephant pixels do show up, but they are overshadowed by additional artifacts. This is probably due to an "over-expansion" of context, where tokens highlighted in the first layers (which are far less significant to the classification) overpower those in the last layers (which are critical for the classification). What I suggest is to limit the expansion of context as follows, in interpret:

    for i, blk in enumerate(image_attn_blocks):
        if i <= 10:
            continue
        grad = blk.attn_grad
        cam = blk.attn_probs
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.clamp(min=0).mean(dim=0)
        R += torch.matmul(cam, R)
    R[0, 0] = 0
    image_relevance = R[0, 1:]

notice that with these lines we skip the earlier attention layers (all blocks with index <= 10), which contribute least to the classification:

    if i <= 10:
        continue

which yields the following improved results (images attached):

- text: "a man with eyeglasses"
- text: "a man with lipstick"
- text: "a rocket standing on a launchpad"
- elephant & zebra examples: "elephant", "zebra", "lake", "elephant", "water", "dog"
- additional examples: "cat", "dog", "cat"

hila-chefer commented 2 years ago

I'll add the option to control context expansion to the notebook. I hope this solves it for you :) Please let me know if you have any other questions.
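For anyone following along, here is a minimal sketch of what such an option could look like, written against the blk.attn_probs / blk.attn_grad hooks used above; the compute_relevance function and the start_layer argument are illustrative assumptions, not necessarily what the notebook ends up using:

import torch

def compute_relevance(image_attn_blocks, device, start_layer=11):
    # start_layer controls context expansion: blocks with index < start_layer are skipped,
    # so start_layer=11 matches the "if i <= 10: continue" snippet above,
    # while start_layer=0 recovers the original full rollout.
    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
    R = torch.eye(num_tokens, num_tokens,
                  dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
    for i, blk in enumerate(image_attn_blocks):
        if i < start_layer:
            continue
        grad = blk.attn_grad
        cam = blk.attn_probs
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = (grad * cam).clamp(min=0).mean(dim=0)
        R += torch.matmul(cam, R)
    R[0, 0] = 0
    return R[0, 1:]  # relevance of each patch token w.r.t. the CLS token

In interpret, the rollout loop above could then be replaced by image_relevance = compute_relevance(image_attn_blocks, device, start_layer=11).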

louisowen6 commented 2 years ago

Hi @hila-chefer, thanks a lot for your response and suggestion! I see that the results now make much more sense than before. However, may I know the main reason why we have to treat ViT-B/32 and ViT-B/16 differently? Is it due to their patch-size difference? Thanks!

hila-chefer commented 2 years ago

@louisowen6 happy to help :) If you notice, I actually left 10 as the expansion parameter in the notebook too; it's beneficial for ViT-B/32 as well. My best guess about the gap between the two models is that, since 16x16 patches produce more tokens, once an irrelevant token becomes highlighted it naturally spreads the artifact to its neighbors, increasing the impact of artifacts. But, as I mentioned, for CLIP the optimal expansion parameter seems to be 10 in both cases, so the two models can be treated consistently.

louisowen6 commented 2 years ago

Okay, thanks for your answer @hila-chefer! Really appreciate it :)

hila-chefer commented 2 years ago

@louisowen6 sure, feel free to ask if you have any other questions :) I'm closing this issue for now since it seems to be resolved, but @sanjayss34, @litingfeng, if you have any follow-up questions feel free to reopen.

snowkueen commented 2 years ago

> (Quoting @louisowen6's earlier comment above, including the same _MODELS patch, modified interpret code, and comparison images.)

Hi, why can't I get results that make sense with your code, while the author's original source code works for me? The only changes from your version that seem valid on my end are reshape_size = int(image_relevance.numel() ** 0.5) and image_relevance = image_relevance.reshape(1, 1, reshape_size, reshape_size).

hila-chefer commented 2 years ago

Hi @snowkueen, could you please explain the issue in more detail? I wasn't able to reproduce it. Are you using our Colab notebook?