clemsgrs / hipt

Re-implementation of HIPT

For your paper Masked Attention #17

Open AlexNmSED opened 1 month ago

AlexNmSED commented 1 month ago

Hi, I noticed that you submitted a paper titled “Masked Attention as a Mechanism for Improving Interpretability of Vision Transformers” to Medical Imaging with Deep Learning 2024. Do you plan to integrate the code here?

clemsgrs commented 1 month ago

hi, the code is already here, but it might be difficult to figure out what is needed to make it run.

the masked attention class is defined in vision_transformer.py:

https://github.com/clemsgrs/hipt/blob/ffb75b3d3f59a033cdfecc401cac7270a7fccdc1/source/vision_transformer.py#L141-L193
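for readers who don't want to dig through the linked file: the core idea is attention where background tokens are excluded before the softmax, so they contribute nothing to the output. here is a minimal numpy sketch of that mechanism — it is illustrative only, not the repo's actual MaskedAttention class (which is multi-head and written in PyTorch), and all names and shapes are assumptions:

```python
import numpy as np

def masked_attention(q, k, v, mask):
    """Single-head attention where masked-out (background) tokens get
    a score of -inf before the softmax, so their values are ignored.
    Hypothetical sketch, not the repo's exact implementation.

    q, k, v: (n, d) arrays; mask: (n,) boolean, True = keep token.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                       # (n, n) raw scores
    scores = np.where(mask[None, :], scores, -np.inf)   # hide masked keys
    scores -= scores.max(axis=-1, keepdims=True)        # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)      # masked keys get weight 0
    return weights @ v
```

with this, changing the value vector of a masked token leaves the output unchanged, which is exactly what makes the resulting attention maps easier to interpret.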

models in models.py accept one (or two) new arguments:

mask_attn_patch: bool = False # whether to replace attention blocks with masked attention blocks in the patch-level Transformer
mask_attn_region: bool = False # whether to replace attention blocks with masked attention blocks in the region-level Transformer

these need to be added to your config file, under model (as in this file)
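as a concrete illustration, the addition might look like this (the key nesting under model follows the linked config; the values shown are just an example, not a recommended setting):

```yaml
model:
  # ... existing model options ...
  mask_attn_patch: false   # masked attention in the patch-level Transformer
  mask_attn_region: true   # masked attention in the region-level Transformer
```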

to make it work, you also need to provide some extra inputs so that the masks can be computed from the regions & the tissue mask:

https://github.com/clemsgrs/hipt/blob/ffb75b3d3f59a033cdfecc401cac7270a7fccdc1/config/training/classification/single.yaml#L11-L23
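the per-token mask can be derived from the tissue segmentation, e.g. by keeping a patch token only if enough of its pixels fall on tissue. a hypothetical sketch of that step — the function name, patch size, and threshold are illustrative, not the repo's actual preprocessing:

```python
import numpy as np

def patch_tokens_mask(tissue_mask, patch_size=256, min_tissue_frac=0.1):
    """Derive a boolean keep-mask for patch tokens from a binary tissue
    mask of one region. Hypothetical helper, for illustration only.

    tissue_mask: (H, W) array of {0, 1}; H and W divisible by patch_size.
    Returns a flat boolean array with one entry per patch token, True
    where the patch contains at least `min_tissue_frac` tissue.
    """
    h, w = tissue_mask.shape
    # reshape into a (rows, patch, cols, patch) grid of patches
    grid = tissue_mask.reshape(h // patch_size, patch_size,
                               w // patch_size, patch_size)
    frac = grid.mean(axis=(1, 3))            # tissue fraction per patch
    return (frac >= min_tissue_frac).ravel()
```

the resulting boolean vector is what a masked-attention block like the one in vision_transformer.py would consume, one entry per patch token in the region.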

if that's not enough, i can try working on more detailed instructions.