Hi @sanjayss34, thanks for your interest in our work! I haven't tried using ViT-B/16, so I'm not sure. Could you open a PR or attach your code somehow so I can look it over and see if I find something?
Hi @hila-chefer, I'm also experimenting with the ViT-B/16 model and followed the exact same process as @sanjayss34. I'm not sure whether this is an issue, but my hypothesis is that ViT-B/16 attends more to local features, since it has a smaller patch size than ViT-B/32.
Can you please help me confirm whether my hypothesis is correct? Below I've attached the code and some comparisons between ViT-B/16 and ViT-B/32, following this amazing notebook you've prepared.
Code:
# Assumes the notebook's setup: the repo's modified CLIP package is already imported as clip
# (its attention blocks expose attn_probs / attn_grad)
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

# Add the link to ViT-B/16 in the _MODELS dictionary
clip.clip._MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}

# Change the reshape in interpret from (1, 1, 7, 7) to (1, 1, dim, dim)
# where dim = int(image_relevance.numel() ** 0.5)
def interpret(image, text, model, device, index=None):
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()
    if index is None:
        index = np.argmax(logits_per_image.cpu().data.numpy(), axis=-1)
    # Backpropagate from the logit of the chosen text prompt
    one_hot = np.zeros((1, logits_per_image.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot.cuda() * logits_per_image)
    model.zero_grad()
    one_hot.backward(retain_graph=True)

    # Accumulate gradient-weighted attention relevance over all visual blocks
    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
    R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
    for blk in image_attn_blocks:
        grad = blk.attn_grad
        cam = blk.attn_probs
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.clamp(min=0).mean(dim=0)
        R += torch.matmul(cam, R)
    R[0, 0] = 0
    image_relevance = R[0, 1:]

    # create heatmap from mask on image
    def show_cam_on_image(img, mask):
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
        return cam

    # Side of the patch grid: 7 for ViT-B/32, 14 for ViT-B/16 at 224x224 input
    reshape_size = int(image_relevance.numel() ** 0.5)
    image_relevance = image_relevance.reshape(1, 1, reshape_size, reshape_size)
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
    image_relevance = image_relevance.reshape(224, 224).cuda().data.cpu().numpy()
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    plt.imshow(vis)
    plt.show()
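For reference, the side length that replaces the hardcoded 7 follows from the patch grid: a 224x224 input split into 32x32 patches gives a 7x7 grid (49 patch tokens), while 16x16 patches give a 14x14 grid (196 patch tokens). Here is a minimal sketch of that computation in plain Python; the helper names patch_grid_side and relevance_grid_side are hypothetical, and it assumes square inputs whose side is divisible by the patch size, as with CLIP's 224x224 preprocessing.

```python
import math


def patch_grid_side(image_size: int, patch_size: int) -> int:
    """Patches along one side of a square ViT input (hypothetical helper)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    return image_size // patch_size


def relevance_grid_side(num_patch_tokens: int) -> int:
    """Side length for reshaping a flat relevance vector (CLS token already dropped).

    Same idea as dim = int(image_relevance.numel() ** 0.5), but uses an exact
    integer square root and fails loudly if the count is not a perfect square.
    """
    side = math.isqrt(num_patch_tokens)
    if side * side != num_patch_tokens:
        raise ValueError(f"{num_patch_tokens} patch tokens do not form a square grid")
    return side


for name, patch in [("ViT-B/32", 32), ("ViT-B/16", 16)]:
    side = patch_grid_side(224, patch)
    print(name, side, side * side, relevance_grid_side(side * side))
```

The exact integer square root avoids silently truncating a non-square token count, which would otherwise surface later as a confusing reshape error.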
@louisowen6 I was wondering whether this is because the objects are large, e.g., the highlighted parts in the left column are more than 32 pixels on a side. Maybe try some small objects?
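To make the object-size argument concrete, here is a small arithmetic sketch counting how many patches an object spans along one axis, given its starting pixel and its length. The helper patches_spanned is hypothetical, and it assumes non-overlapping patches aligned to the image origin, as in a standard ViT patch embedding.

```python
def patches_spanned(start_px: int, length_px: int, patch_size: int) -> int:
    """Patches along one axis touched by pixels [start_px, start_px + length_px)."""
    if length_px <= 0:
        raise ValueError("length_px must be positive")
    first = start_px // patch_size
    last = (start_px + length_px - 1) // patch_size
    return last - first + 1


# A 64-pixel object at the origin spans 2 patches at 32 px and 4 at 16 px;
# a 20-pixel object starting at pixel 10 fits in one 32-px patch but covers two 16-px patches.
for length, start in [(64, 0), (20, 10)]:
    print(length, patches_spanned(start, length, 32), patches_spanned(start, length, 16))
```

Objects much larger than 32 pixels span several patches at either patch size, so they may not separate the two models well; smaller objects are where the resolution difference should show.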
Hi @sanjayss34, @louisowen6, @litingfeng, thank you all for this very lively discussion.
I've been looking into what you're seeing with ViT-B/16, and I was able to reproduce it quite easily.
This is an interesting point. Intuitively, you can think of our method as expanding context. The best way to think about it is that the last attention layer is the most important one, and we use the previous layers to understand the meaning of each token in the last layer, since by then each token has been contextualized.
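The context-expansion view corresponds to the relevance update in the interpret loop: R starts as the identity, and each layer adds its head-averaged, positively clamped, gradient-weighted attention multiplied into R. Here is a NumPy sketch of that update, decoupled from CLIP so it runs on arbitrary arrays; the function name accumulate_relevance and the random inputs are illustrative assumptions, not the notebook's API.

```python
import numpy as np


def accumulate_relevance(attn_probs, attn_grads):
    """Gradient-weighted relevance accumulation, mirroring the loop in interpret.

    attn_probs, attn_grads: one array per layer, each (num_heads, num_tokens, num_tokens).
    Returns each patch token's relevance for the CLS token (row 0 of R, minus CLS).
    """
    num_tokens = attn_probs[0].shape[-1]
    R = np.eye(num_tokens)
    for cam, grad in zip(attn_probs, attn_grads):
        # Weight attention by its gradient, keep positive contributions, average heads.
        cam_bar = np.clip(grad * cam, 0, None).mean(axis=0)
        # Same as R += torch.matmul(cam, R): each layer mixes in context from below.
        R = R + cam_bar @ R
    # The notebook also sets R[0, 0] = 0, which does not affect this slice.
    return R[0, 1:]


rng = np.random.default_rng(0)
num_layers, num_heads, num_tokens = 12, 12, 1 + 7 * 7  # CLS + 7x7 patches (ViT-B/32 at 224 px)
probs = rng.random((num_layers, num_heads, num_tokens, num_tokens))
probs /= probs.sum(axis=-1, keepdims=True)  # row-normalize, like softmax attention
grads = rng.standard_normal(probs.shape)
relevance = accumulate_relevance(list(probs), list(grads))
print(relevance.shape)  # (49,)
```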
Now, examining some of the examples for ViT-B/16, I noticed that the correct areas are being highlighted but are overshadowed by other areas in the image. For example, in the elephant example the elephant pixels do show up, but they are overshadowed by additional artifacts.
This is probably due to an "over-expansion" of context, where the tokens highlighted in the first layers (which are far less significant to the classification) overpower those in the last layers (which are critical for the classification).
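A toy illustration of this over-expansion under the same additive update: because the clamped maps are non-negative, R can only grow, so anything an early layer highlights stays in R and can rival what the last layer actually relies on. The 3-token matrices below are made up purely for illustration.

```python
import numpy as np

# Toy sequence: index 0 is CLS, 1 is the object patch, 2 is a background patch.
# Hypothetical head-averaged, clamped maps (grad * attn already applied):
early = np.array([[0.0, 0.0, 1.0],   # an early layer: CLS mixes in the background patch
                  [0.0, 0.0, 0.0],
                  [0.0, 0.0, 0.0]])
last = np.array([[0.0, 1.0, 0.0],    # the last layer: CLS relies on the object patch
                 [0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.0]])

R = np.eye(3)
for cam_bar in (early, last):
    R = R + cam_bar @ R

print(R[0, 1:])  # [1. 1.] -> the background patch ties with the object patch
```

Keeping only the last map (R = I + last) would give [1, 0] here, i.e., only the object patch.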
My suggestion is to limit the expansion of context as follows, in interpret:
for i, blk in enumerate(image_attn_blocks):
    # skip layers 0-10, so only the last layer of the 12-layer ViT-B contributes
    if i <= 10:
        continue
    grad = blk.attn_grad
    cam = blk.attn_probs
    cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
    grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
    cam = grad * cam
    cam = cam.clamp(min=0).mean(dim=0)
    R += torch.matmul(cam, R)
R[0, 0] = 0
image_relevance = R[0, 1:]
Notice how these lines skip every attention layer with index up to 10, i.e., all but the last layer of the 12-layer ViT-B models:
if i <= 10:
    continue
which yields the following improved results: [images: text prompts "a man with eyeglasses", "a man with lipstick", "a rocket standing on a launchpad"; elephant & zebra examples with prompts "elephant", "zebra", "lake", "elephant", "water", "dog", "cat", "dog", "cat"]
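The layer-skipping rule can be parameterized so that the amount of context expansion is adjustable. Below is a NumPy sketch with a hypothetical start_layer argument, assuming the per-layer maps have already been head-averaged and clamped as in interpret. With 12 layers and start_layer=11 it keeps only the last layer, which is what the if i <= 10: continue rule does.

```python
import numpy as np


def relevance_from_layer(cams, start_layer):
    """Accumulate relevance only from layers with index >= start_layer.

    cams: head-averaged, positively clamped maps, each (num_tokens, num_tokens).
    start_layer=11 matches `if i <= 10: continue` for a 12-layer ViT-B;
    start_layer=0 is the original loop over all layers.
    """
    R = np.eye(cams[0].shape[-1])
    for i, cam_bar in enumerate(cams):
        if i < start_layer:
            continue  # skip early layers to limit context expansion
        R = R + cam_bar @ R
    return R[0, 1:]  # relevance of each patch token for the CLS token


rng = np.random.default_rng(0)
cams = [rng.random((50, 50)) for _ in range(12)]  # CLS + 49 patches, 12 layers
last_only = relevance_from_layer(cams, start_layer=11)
# With one layer, R = I + cams[-1], so row 0 (without CLS) is just cams[-1][0, 1:].
print(np.allclose(last_only, cams[-1][0, 1:]))  # True
```

Lowering start_layer adds more context from earlier layers; since every added term is non-negative, relevance from the full loop is elementwise at least as large as the last-layer-only relevance.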
I'll add an option to control context expansion to the notebook. I hope this solves the issue for you :) Please let me know if you have any other questions.
Hi @hila-chefer, thanks a lot for your response and suggestion! The results now make much more sense than before. However, may I ask the main reason we have to treat ViT-B/32 and ViT-B/16 differently? Is it due to their difference in patch size? Thanks!
@louisowen6 happy to help :) As you may have noticed, I actually left 10 as the expansion factor in the notebook too; it's beneficial for ViT-B/32 as well. My best guess for why the effect is larger on ViT-B/16 is that there are more tokens with 16x16 patches (a 14x14 grid versus 7x7 for ViT-B/32 at 224x224 input), so once an irrelevant token becomes highlighted, it may naturally spread the artifact to its neighbors, increasing the impact of artifacts. But, as I mentioned, for CLIP the optimal expansion parameter seems to be 10 in both cases, so there is consistency.
Okay, thanks for your answer @hila-chefer! Really appreciate it :)
@louisowen6 sure, feel free to ask if you have any other questions :) I'm closing this issue for now as it seems to be resolved, but @sanjayss34, @litingfeng, if you have any follow-up questions, feel free to reopen.
Hi, why can't I get results that make sense with your code, when I can with the author's source code? Of your changes, only these lines seem valid: reshape_size = int(image_relevance.numel() ** 0.5) and image_relevance = image_relevance.reshape(1, 1, reshape_size, reshape_size).
Hi @snowkueen, could you please explain the issue? I wasn't able to reproduce it. Are you using our Colab notebook?
Hi Hila, thanks for your great work. I am trying to run the CLIP code from the notebook on the ViT-B/16 model, but I am getting attention maps that don't make any sense (I can't get results similar to those in the notebook). With the ViT-B/32 model I'm able to reproduce the results, but for some reason ViT-B/16 causes an issue. Do you know why this is? The only things in the code I needed to change are:
the reshape in interpret, from (1, 1, 7, 7) to (1, 1, dim, dim), where dim = int(image_relevance.numel() ** 0.5)
Thanks!