hila-chefer / Transformer-Explainability

[CVPR 2021] Official PyTorch implementation of Transformer Interpretability Beyond Attention Visualization, a novel method to visualize classifications made by Transformer-based networks.

On DEIT distilled network #30

Closed kspruthviraj closed 2 years ago

kspruthviraj commented 2 years ago

Hi Hila,

Thanks for sharing the script to visualize the attention maps.

I am trying to run your DeiT example on a custom model (a DeiT-base distilled network with 19 classes), but so far I have been unsuccessful. I keep getting this error: "AttributeError: 'VisionTransformer' object has no attribute 'relprop'"

Here is my saved model weights and biases: https://drive.switch.ch/index.php/s/dimybgHdzyE90gB

This is how I first load the model from Timm:

import timm

# DeiT-base distilled from timm, with the classifier head replaced for 19 classes
basemodel = timm.create_model('deit_base_distilled_patch16_224', pretrained=True, num_classes=19)
model = basemodel

Then I load trained weights and biases from my model

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model)
model.to(device)
criterion = nn.CrossEntropyLoss()
torch.cuda.set_device(0)
model.cuda(0)
criterion = criterion.cuda(0)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=3e-5)

PATH = checkpoint_path + '/trained_model.pth'  # Saved model path -- shared in the link earlier.
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

When I try to use this model in your example code, I keep getting this error "AttributeError: 'VisionTransformer' object has no attribute 'relprop'"

I was wondering if you would have time to take a look at the uploaded model and see whether you can generate an attention map?

Thanks a lot

hila-chefer commented 2 years ago

Hi @kspruthviraj, thanks for your interest in our work! Please note that I do not use the out-of-the-box ViT implementation. My code contains a modified implementation that adds a relevance propagation (relprop) function to each layer in the network. So when you load your weights, load them into the model implemented in this repo to get LRP working.
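
For example, a minimal sketch of loading custom weights into the modified implementation (untested; the exact import path, constructor name, and num_classes support are assumptions to check against the repo):

from baselines.ViT.ViT_LRP import deit_base_patch16_224  # repo's modified model with relprop
import torch

# build the repo's version of the architecture instead of timm's VisionTransformer
model = deit_base_patch16_224(pretrained=True, num_classes=19)

checkpoint = torch.load('trained_model.pth', map_location='cpu')
state_dict = checkpoint['model_state_dict']
# the checkpoint was saved from an nn.DataParallel wrapper, so strip the 'module.' prefix
state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=False)
model.eval()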

Best, Hila

kspruthviraj commented 2 years ago

Hi @hila-chefer ,

Thanks for getting back to me. Okay, now I get it.

I guess the models implemented in this repo are the DeiT-small and DeiT-base networks, and there is no DeiT-base-distilled. Since I am using the DeiT-base distilled network, I might not be able to load my weights directly into the models implemented in this repo.
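
Maybe I could try adapting the checkpoint by dropping the distillation-specific parameters before loading; here is a rough, untested sketch (the key names follow timm's distilled DeiT layout and the repo import/constructor name are assumptions):

import torch
from baselines.ViT.ViT_LRP import deit_base_patch16_224  # assumed name of the repo's modified DeiT

# non-distilled DeiT-base from this repo, with a 19-class head
model = deit_base_patch16_224(pretrained=False, num_classes=19)

checkpoint = torch.load('trained_model.pth', map_location='cpu')
# strip the nn.DataParallel 'module.' prefix from every key
state_dict = {k.replace('module.', '', 1): v for k, v in checkpoint['model_state_dict'].items()}

cleaned = {}
for k, v in state_dict.items():
    if k.startswith(('dist_token', 'head_dist')):
        continue  # distillation-only parameters have no counterpart in the non-distilled model
    if k == 'pos_embed':
        # timm's distilled DeiT orders positions as [cls, dist, patch_1, ...]; drop the dist slot
        v = torch.cat([v[:, :1], v[:, 2:]], dim=1)
    cleaned[k] = v

# strict=False reports anything that still could not be matched
missing, unexpected = model.load_state_dict(cleaned, strict=False)
print('missing:', missing, 'unexpected:', unexpected)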

Best, Sreenath