hila-chefer / Transformer-Explainability

[CVPR 2021] Official PyTorch implementation for Transformer Interpretability Beyond Attention Visualization, a novel method to visualize classifications by Transformer based networks.
MIT License

what should i do when trying this visualization method on cifar dataset #20

Closed by caposerenity 3 years ago

caposerenity commented 3 years ago

Hi, I'm wondering how to modify the generate_visualization() method in the demo when running this on CIFAR-100 with img_size=32×32 instead of 224×224. Due to my GPU memory limit, I failed to set the input size of my ViT model to 224.

Thanks!

hila-chefer commented 3 years ago

Hi @caposerenity, thanks for your interest in our work! Apologies for the delayed response. Of course you can: I think all you need to do is load different weights into our ViT implementation, see this function for example. After you load your own pretrained weights, just modify the reshapes in the generate_visualization() method. The line `transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)` reshapes the relevance vector to the patch-grid dimensions (14×14 for 224×224 images with 16×16 patches), so change it to match your patch grid and then interpolate to 32×32 instead of 224×224.
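As an illustration of the change described above, here is a minimal sketch of the reshape-and-interpolate step adapted to CIFAR-sized inputs. It assumes a hypothetical ViT configuration with 32×32 images and 4×4 patches (giving an 8×8 patch grid of 64 relevance scores); the function name `generate_visualization_cifar` and the dummy relevance tensor are stand-ins, not part of the repo:

```python
import torch
import torch.nn.functional as F

def generate_visualization_cifar(transformer_attribution, img_size=32, patch_size=4):
    # Assumed config: 32x32 CIFAR images with 4x4 patches gives an
    # (32 // 4) x (32 // 4) = 8 x 8 patch grid, i.e. 64 relevance
    # scores instead of the 14 * 14 = 196 used for 224x224 inputs.
    grid = img_size // patch_size
    attribution = transformer_attribution.reshape(1, 1, grid, grid)
    # Upsample the patch-level relevance map back to pixel resolution
    # (the original demo interpolates to 224x224; here we use 32x32).
    attribution = F.interpolate(
        attribution, size=(img_size, img_size),
        mode="bilinear", align_corners=False,
    )
    attribution = attribution.squeeze()
    # Normalize to [0, 1] so it can be overlaid as a heatmap.
    attribution = (attribution - attribution.min()) / (
        attribution.max() - attribution.min()
    )
    return attribution

# Dummy relevance vector standing in for the generate_LRP output:
relevance = torch.rand(1, 64)
heatmap = generate_visualization_cifar(relevance)
print(heatmap.shape)  # torch.Size([32, 32])
```

The rest of the demo's overlay code should work unchanged, since it only expects a normalized heatmap the same size as the input image.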

I hope this helps, let me know if you require any clarifications. Also, if this works, I encourage you to add a PR with your changes for other people to use :)