hila-chefer / Transformer-Explainability

[CVPR 2021] Official PyTorch implementation for Transformer Interpretability Beyond Attention Visualization, a novel method to visualize classifications by Transformer-based networks.
MIT License

Training process #57

Closed ascdqz closed 1 year ago

ascdqz commented 1 year ago

Hello! I'm trying to train a model (ViT or DeiT) and then apply this method. When I run a standard training loop, I noticed that two hook functions are registered inside `def forward`, and they raise an error during validation (presumably because validation runs under `model.eval()` / `torch.no_grad()`, so the hooks can't be registered). I tried removing these two functions, but that causes other errors, so I assume they're necessary? I also noticed that even though this model has almost the same structure as the standard PyTorch models, it can't directly load the state dict of a PyTorch ViT model I trained, since the parameter names are defined differently. The code seems tightly coupled, so I'm not sure how to reshape a regular model to match this one. The Colab demos are very good. If you could give me a hint on how to use this method after training on my own dataset, that would be very helpful. Thank you!
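For the validation error: one common workaround is not to remove the hooks but to guard them, since `Tensor.register_hook` raises a `RuntimeError` on tensors that don't require grad (which is the case inside `torch.no_grad()`). The sketch below is a minimal, hypothetical attention module (not the repo's actual class) showing the idea:

```python
import torch

class Attention(torch.nn.Module):
    """Minimal sketch (NOT the repo's actual Attention class) showing
    how a gradient hook registered in forward() can be guarded so that
    eval/no_grad validation passes don't fail."""

    def __init__(self):
        super().__init__()
        self.attn_gradients = None  # filled by the hook on backward

    def save_attn_gradients(self, grad):
        self.attn_gradients = grad

    def forward(self, x):
        # Attention computation elided; `attn` stands in for the
        # attention map the explainability method needs gradients of.
        attn = torch.softmax(x @ x.transpose(-2, -1), dim=-1)
        # register_hook raises a RuntimeError on tensors that do not
        # require grad (e.g. under torch.no_grad() during validation),
        # so guard it instead of deleting it:
        if attn.requires_grad:
            attn.register_hook(self.save_attn_gradients)
        return attn @ x
```

With this guard, validation under `torch.no_grad()` simply skips the hook, while training (and the relevancy computation, which needs the stored gradients) still registers it.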
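For the checkpoint-loading problem: when two models share structure but name their parameters differently, one option is to rename the keys of the saved state dict before loading. The helper below is an illustrative sketch, not code from the repo; the rename rules shown are placeholders, and the real mapping has to be built by comparing `state_dict().keys()` of both models:

```python
import torch

def remap_checkpoint(state_dict, rename_rules):
    """Return a copy of `state_dict` with every occurrence of each
    old substring in a key replaced by its new counterpart.

    `rename_rules` is a dict {old_substring: new_substring}; build it
    by printing and comparing the .state_dict().keys() of the source
    model and the target model."""
    new_state = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in rename_rules.items():
            new_key = new_key.replace(old, new)
        new_state[new_key] = value
    return new_state

# Hypothetical rules -- the actual differing prefixes must be read off
# the two models' state dicts:
rules = {"mlp.0": "mlp.fc1", "mlp.3": "mlp.fc2"}
```

Usage would then look like `model.load_state_dict(remap_checkpoint(checkpoint, rules))`; passing `strict=False` can help surface which keys still fail to match while debugging the mapping.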