hila-chefer / Transformer-Explainability

[CVPR 2021] Official PyTorch implementation for Transformer Interpretability Beyond Attention Visualization, a novel method to visualize classifications by Transformer-based networks.

Question regarding fine tuning #40

Closed: MichaelMMeskhi closed this issue 2 years ago

MichaelMMeskhi commented 2 years ago

I am trying to fine-tune on a custom image dataset. Using the base 224 patch-16 model, I set requires_grad=False on all parameters and change the head's output dimension to 2 (binary classification). When I try to train this model, I get the error "cannot register a hook on a tensor that doesn't require gradient".
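Roughly, my setup looks like this (a sketch; the import path assumes this repo's baselines/ViT/ViT_LRP.py, so adjust to your copy):

import torch.nn as nn
from baselines.ViT.ViT_LRP import vit_base_patch16_224  # LRP-enabled ViT from this repo

model = vit_base_patch16_224(pretrained=True)
for p in model.parameters():
    p.requires_grad = False  # freeze the pretrained backbone
# replace the classifier with a trainable 2-class head
# (nn.Linear is fine for fine-tuning; propagating relevance through the new head
# would need the repo's LRP-aware Linear layer instead)
model.head = nn.Linear(model.head.in_features, 2)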

I simply added a conditional statement in vit_lrp.py:

if x.requires_grad:
    # register the input-gradient hook only when the tensor can actually receive gradients
    x.register_hook(self.save_inp_grad)

if attn.requires_grad:
    # same guard for the attention-gradient hook (the callbacks are the repo's original ones)
    attn.register_hook(self.save_attn_gradients)

Just making sure this isn't breaking anything, since it does fix my fine-tuning issue.

hila-chefer commented 2 years ago

Hi @MichaelMMeskhi, thanks for your interest! The second hook you mention is originally used to collect the gradients of each attention head (our method averages across the heads, using each head's gradient as its weight). Without those gradients the method will not work (I think it would return None for the gradient information and then raise an exception). However, since you do not need the relevance maps while fine-tuning, skipping the hooks should have no effect on your visualization results once fine-tuning is complete. In fact, you can even fine-tune your model with the plain, non-LRP code from the original implementation, save your weights, and then replace the checkpoint URL in our visualization code with your own checkpoint.
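Roughly, that workflow would look like this (a sketch; the module and function names follow our example notebook, the checkpoint path is a placeholder, and state-dict keys may need remapping if you fine-tuned with a different ViT implementation):

import torch
from baselines.ViT.ViT_LRP import vit_base_patch16_224
from baselines.ViT.ViT_explanation_generator import LRP

# build the LRP-enabled model with the same 2-class head as the fine-tuned checkpoint
model = vit_base_patch16_224(pretrained=False, num_classes=2).cuda()
state_dict = torch.load("finetuned_vit.pth", map_location="cpu")  # placeholder checkpoint path
model.load_state_dict(state_dict)
model.eval()

# generate a relevance map for class index 1 of a preprocessed (1, 3, 224, 224) image
attribution_generator = LRP(model)
input_image = torch.randn(1, 3, 224, 224).cuda()  # stand-in for a normalized input image
cam = attribution_generator.generate_LRP(input_image, method="transformer_attribution", index=1)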

I hope this helps, but please let me know if you have any follow-up questions.

Best, Hila.

MichaelMMeskhi commented 2 years ago

Thank you for your reply @hila-chefer! Yes, that is what I originally did for fine-tuning, and I have managed to get things working.

Best, Michael