fundamentalvision / Deformable-DETR

Deformable DETR: Deformable Transformers for End-to-End Object Detection.
Apache License 2.0

Attention Map #156

Open amindehnavi opened 2 years ago

amindehnavi commented 2 years ago

Hi, is there any way to generate the output attention maps of the model.transformer.decoder.layers[i].cross_attn layer? When I follow the referenced functions, I eventually get stuck at the MSDA.ms_deform_attn_forward call inside the forward method of the MSDeformAttnFunction class, located in ./models/ops/functions/ms_deform_attn_func.py, and I couldn't find any argument I could set to True to get the attention map in the output.

./models/deformable_transformer.py -> [Class] DeformableTransformerDecoderLayer (screenshot)

./models/ops/modules/ms_deform_attn.py -> [Class] MSDeformAttn -> forward function (screenshot)

./models/ops/functions/ms_deform_attn_func.py (screenshot)
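Since MSDeformAttnFunction exposes no flag for returning the weights, one workaround is to capture them one step earlier: in MSDeformAttn.forward the weights come from the module's own attention_weights linear layer followed by a softmax, so wrapping forward and recomputing them there gives the same tensor the CUDA kernel consumes. This is a hedged sketch: the attribute names (attention_weights, n_heads, n_levels, n_points) follow the repo's MSDeformAttn but are not guaranteed here, and the ToyDeformAttn class below is a stand-in just to demonstrate the wrapper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

captured = {}  # layer index -> softmaxed attention weights

def capture_attention(module, layer_idx):
    """Wrap a (deformable) cross-attention module so its attention weights,
    recomputed from the module's own `attention_weights` linear layer the way
    MSDeformAttn.forward does, are stashed before each call."""
    orig_forward = module.forward
    def forward(query, *args, **kwargs):
        aw = module.attention_weights(query)  # (N, Len_q, n_heads * n_levels * n_points)
        n, lq, _ = aw.shape
        aw = aw.view(n, lq, module.n_heads, module.n_levels * module.n_points)
        captured[layer_idx] = F.softmax(aw, -1).detach()
        return orig_forward(query, *args, **kwargs)
    module.forward = forward

# Toy stand-in for MSDeformAttn, only to show the wrapper in action.
class ToyDeformAttn(nn.Module):
    def __init__(self, d_model=16, n_heads=2, n_levels=4, n_points=4):
        super().__init__()
        self.n_heads, self.n_levels, self.n_points = n_heads, n_levels, n_points
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.proj = nn.Linear(d_model, d_model)
    def forward(self, query, *args, **kwargs):
        return self.proj(query)

attn = ToyDeformAttn()
capture_attention(attn, 0)
out = attn(torch.randn(2, 10, 16))
print(captured[0].shape)  # torch.Size([2, 10, 2, 16])
```

On the real model you would loop over model.transformer.decoder.layers and wrap each layer.cross_attn; the captured weights still need to be paired with the sampling locations to be drawn as a spatial map.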

GivanTsai commented 2 years ago

Have you figured out how to draw attention map of encoder and decoder?

owen24819 commented 2 years ago

Any update on this?

amindehnavi commented 2 years ago

Any update on this?

Unfortunately no. As mentioned in the paper, the self-attention blocks are the same as in DETR and are instances of the torch.nn.MultiheadAttention class, but the cross-attention block is built on MSDeformAttnFunction(), and I have not figured out how to reach the attention weights from there. If you find out how, please tell us :)
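For the self-attention side at least, nn.MultiheadAttention returns the attention weights as the second element of its output when called with need_weights=True (the default), so a plain forward hook can collect them without modifying the repo's code. A minimal self-contained sketch (a standalone MultiheadAttention module here; on the real model you would hook layer.self_attn):

```python
import torch
import torch.nn as nn

maps = []  # collected attention maps, one per forward pass

mha = nn.MultiheadAttention(embed_dim=16, num_heads=2)

def hook(module, inputs, output):
    # output is (attn_output, attn_output_weights); weights are averaged
    # over heads by default and are None if need_weights=False was passed.
    if output[1] is not None:
        maps.append(output[1].detach())

mha.register_forward_hook(hook)
x = torch.randn(5, 2, 16)  # (seq_len, batch, embed_dim)
y, _ = mha(x, x, x)
print(maps[0].shape)  # torch.Size([2, 5, 5])
```

Each captured map is (batch, target_len, source_len) and rows sum to 1, so it can be reshaped or plotted directly as a heatmap.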

ZhixiongSun commented 1 year ago

The authors provide a PyTorch version for debugging purposes. Just use the function ms_deform_attn_core_pytorch in ./models/ops/functions/ms_deform_attn_func.py instead of the CUDA version. You can call this function as follows (in class MSDeformAttn):

    # output = MSDeformAttnFunction.apply(
    #     value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
    output = ms_deform_attn_core_pytorch(
        value, input_spatial_shapes, sampling_locations, attention_weights)

Just comment out the CUDA-version MSDeformAttnFunction call in class MSDeformAttn(nn.Module) and use the PyTorch version instead.
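Once on the PyTorch path, the attention weights and sampling locations are ordinary tensors you can save or plot. To make clear what that path computes, here is a simplified single-level sketch in the spirit of ms_deform_attn_core_pytorch (the real function loops over all feature levels): bilinearly sample the per-head value map at the predicted locations with F.grid_sample, then take the attention-weighted sum. Shapes follow the repo's convention, but this is an illustration, not the repo's code.

```python
import torch
import torch.nn.functional as F

def deform_attn_single_level(value, spatial_shape, sampling_locations, attention_weights):
    # value: (N, H*W, n_heads, head_dim)
    # sampling_locations in [0, 1]: (N, Len_q, n_heads, n_points, 2)
    # attention_weights: (N, Len_q, n_heads, n_points), softmaxed over points
    N, _, M, D = value.shape
    _, Lq, _, P, _ = sampling_locations.shape
    H, W = spatial_shape
    # one (D, H, W) feature map per (batch, head)
    value_ = value.flatten(2).transpose(1, 2).reshape(N * M, D, H, W)
    # grid_sample expects coordinates in [-1, 1]
    grid = (2 * sampling_locations - 1).transpose(1, 2).reshape(N * M, Lq, P, 2)
    sampled = F.grid_sample(value_, grid, mode='bilinear',
                            padding_mode='zeros', align_corners=False)  # (N*M, D, Lq, P)
    weights = attention_weights.transpose(1, 2).reshape(N * M, 1, Lq, P)
    out = (sampled * weights).sum(-1)           # (N*M, D, Lq)
    return out.view(N, M * D, Lq).transpose(1, 2)  # (N, Lq, M*D)

value = torch.randn(2, 6 * 8, 2, 4)                      # 6x8 map, 2 heads, dim 4
locs = torch.rand(2, 5, 2, 3, 2)                         # 5 queries, 3 points each
w = torch.softmax(torch.randn(2, 5, 2, 3), -1)
out = deform_attn_single_level(value, (6, 8), locs, w)
print(out.shape)  # torch.Size([2, 5, 8])
```

The weights tensor here is exactly the per-query attention map people in this thread want: scatter each query's weights at its sampling locations on the (H, W) grid to visualize where the query attends.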