hila-chefer / Transformer-Explainability

[CVPR 2021] Official PyTorch implementation for Transformer Interpretability Beyond Attention Visualization, a novel method to visualize classifications by Transformer-based networks.
MIT License

confused rgb/bgr? #35

Closed: Tsingularity closed this issue 2 years ago

Tsingularity commented 2 years ago

Hi thanks for the great work!

I am quite confused about this code segment in the Transformer_explainability.ipynb notebook:

# create heatmap from mask on image
def show_cam_on_image(img, mask):
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam

def generate_visualization(original_image, class_index=None):
    transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0).cuda(), method="transformer_attribution", index=class_index).detach()
    # reshape the 14x14 patch-level relevance map and upsample it to 224x224
    transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
    transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
    transformer_attribution = transformer_attribution.reshape(224, 224).cuda().data.cpu().numpy()
    # min-max normalize both the relevance map and the image to [0, 1]
    transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
    image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy() # chw -> hwc (rgb)
    image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())
    vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
    vis =  np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis

image = Image.open('samples/catdog.png')
dog_cat_image = transform(image)

fig, axs = plt.subplots(1, 3)
axs[0].imshow(image);
axs[0].axis('off');

output = model(dog_cat_image.unsqueeze(0).cuda())
print_top_classes(output)

# cat - the predicted class
cat = generate_visualization(dog_cat_image)

# dog 
# generate visualization for class 243: 'bull mastiff'
dog = generate_visualization(dog_cat_image, class_index=243)

axs[1].imshow(cat);
axs[1].axis('off');
axs[2].imshow(dog);
axs[2].axis('off');

As you can see, image_transformer_attribution comes from an image loaded through PIL and transformed by PyTorch, so it is in RGB order.

However, the heatmap is constructed through cv2, so I assume it is in BGR order? So why can we directly add them together, i.e., cam = heatmap + np.float32(img), when they do not share the same channel order?

Please correct me if my understanding is wrong here. Thanks!
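As a quick sanity check of this premise, here is a minimal standalone sketch (my own, not from the notebook) confirming that cv2.applyColorMap returns BGR-ordered output: with COLORMAP_JET, a mask value of 0 maps to dark blue, so the blue intensity should land in channel 0.

import cv2
import numpy as np

# Apply the JET colormap to the lowest possible mask value.
mask = np.zeros((1, 1), dtype=np.uint8)
heatmap = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
# Prints something like [128 0 0]: the dark-blue intensity sits in
# channel 0, i.e. the output is BGR-ordered.
print(heatmap[0, 0])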

hila-chefer commented 2 years ago

Hi @Tsingularity, thanks for your interest!

I apologize for the delay in response. This code was adapted from this great repo. If I recall correctly, it also contains full documentation of all the details.

Tsingularity commented 2 years ago

Thanks for the reply and reference.

But could you please check the notebook code above and my question in detail? I am not saying the attention code is incorrect. It just seems the visualization part is a little bit buggy in terms of 'BGR / RGB'.

Thanks!

hila-chefer commented 2 years ago

I was actually referring to this code; it's exactly the code that applies the heatmap to the image that you copied here. Does this help with your confusion?

Tsingularity commented 2 years ago

Thanks for pointing out the reference code.

But I am afraid that's where the bug comes from. I mean, as you can see in the original code, it uses BGR by default (for both img and mask).

However, in your notebook code above, the input image image_transformer_attribution is in RGB, but the mask heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) is in BGR. And you directly add them together while they are actually in different channel orders: cam = heatmap + np.float32(img). That's the reason why I think it doesn't make sense here.

But I could be wrong anyway. Feel free to correct me if there's something incorrect in my understanding.

Thanks!
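For concreteness, here is a minimal sketch of the kind of fix this implies (the name show_cam_on_image_rgb is hypothetical, not from the notebook): convert the heatmap from BGR to RGB right after applying the colormap, so both operands of the addition share the same channel order and no final cv2.cvtColor is needed.

import cv2
import numpy as np

def show_cam_on_image_rgb(img, mask):
    # img: RGB float image in [0, 1]; mask: relevance map in [0, 1]
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # BGR -> RGB
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)  # both operands are RGB now
    cam = cam / np.max(cam)
    return cam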

hila-chefer commented 2 years ago

I understand your point. I'm really no expert in these BGR/RGB games, but when I try applying the code with the RGB conversion (L. 40 in the reference code), I get identical heatmaps. It is possible that this is working due to this conversion in the notebook code (in generate_visualization):

vis =  np.uint8(255 * vis)
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)

If you wish to use the RGB conversion from L. 40, you can remove the above 2 lines and do the conversion inside show_cam_on_image instead, like this:

def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    The heatmap is converted from BGR to RGB, so 'img' should be in RGB format.
    :param img: The base image in RGB format, np.float32 in the range [0, 1].
    :param mask: The cam mask.
    :param colormap: The OpenCV colormap to be used.
    :returns: The image with the cam overlay.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should be np.float32 in the range [0, 1]")

    cam = heatmap + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

def generate_visualization(original_image, class_index=None):
    transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0).cuda(), method="transformer_attribution", index=class_index).detach()
    transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
    transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
    transformer_attribution = transformer_attribution.reshape(224, 224).cuda().data.cpu().numpy()
    transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
    image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
    image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())
    vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
    return vis

Here are the heatmaps in both cases.

Original: [heatmap images]

With RGB conversion: [heatmap images]

Since the heatmaps are identical in both cases, I guess this does not make much of a difference, but thanks for pointing this out!
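A plausible reason the two outputs match (my reading, not stated explicitly above): the final cv2.cvtColor(..., COLOR_RGB2BGR) in the original notebook swaps the channels of the whole overlay, which corrects the heatmap's BGR order; the only remaining difference is that the much fainter base image ends up channel-swapped. A small sketch of the linear part of this argument (ignoring the per-image rescaling):

import numpy as np

rng = np.random.default_rng(0)
img_rgb = rng.random((4, 4, 3)).astype(np.float32)   # stand-in RGB image in [0, 1]
heat_bgr = rng.random((4, 4, 3)).astype(np.float32)  # stand-in BGR heatmap in [0, 1]

def swap(x):
    # BGR <-> RGB: reverse the channel axis
    return x[..., ::-1]

original = swap(heat_bgr + img_rgb)  # notebook: add first, then RGB2BGR at the end
fixed = swap(heat_bgr) + img_rgb     # fix: convert the heatmap first, then add

# The heatmap term is identical in both pipelines; only the base image
# is channel-swapped in the original.
np.testing.assert_allclose(original - swap(img_rgb), fixed - img_rgb, atol=1e-6)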

Tsingularity commented 2 years ago

Thanks for looking into this issue and re-running the visualization code.

I also just wrote a corrected version on my end and got the same results as you showed above. It's interesting that this issue (seemingly) doesn't affect the final image output at all.

Since the code lives in a notebook and I find it hard to make a pull request for Jupyter notebooks, I'll just leave the modification work to you :)

Thanks for the help again!