lyuwenyu / RT-DETR

[CVPR 2024] Official RT-DETR (RTDETR paddle pytorch), Real-Time DEtection TRansformer, DETRs Beat YOLOs on Real-time Object Detection. 🔥 🔥 🔥
Apache License 2.0

Spatial attention visualization #472

Open sepatin opened 1 week ago

sepatin commented 1 week ago

Hello

First of all, thank you for your great work. I would like to extract the cross-attention maps to visualize spatial attention overlaid on my images, both during training (on my validation set) and at inference (on my batches).

I have set up hooks to capture them, but I have trouble selecting the right tensor (among the arguments passed to the hook) and decomposing it correctly. In particular, I can't determine value_spatial_shapes, which seems to vary even though my images are of fixed size. Could you give me any pointers on this, or on the right way to do it?

Regards

sepatin commented 1 week ago

Update

I found a way to collect the attention layers and attention levels for each inference on my images (cross_attn, to build an image that helps "see" my detections).

**Inference**

```python
with torch.no_grad():
    x = self.model.backbone(image)
    x = self.model.encoder(x)
    # get the spatial shapes for tensor decomposition
    _, spatial_shapes = self.model.decoder._get_encoder_input(x)
    _ = self.model.decoder(x, goals)
```

**Hook**

```python
def register_hooks(self):
    for name, module in self.model.named_modules():
        if 'cross_attn' in name:
            module.register_forward_hook(self.get_attention_hook)
```
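The `get_attention_hook` callback referenced above isn't shown; here is a minimal, self-contained sketch of how such a collector could look. The class and attribute names (`AttentionCollector`, `attention_maps`) are my own assumptions, not part of RT-DETR, and the tiny stand-in model only exists to make the example runnable:

```python
import torch
import torch.nn as nn

class AttentionCollector:
    """Hypothetical helper: stores the output of every module whose
    name contains 'cross_attn'. Names here are assumptions, not RT-DETR API."""

    def __init__(self, model):
        self.model = model
        self.attention_maps = []  # one entry per hooked forward call
        self.handles = []

    def get_attention_hook(self, module, inputs, output):
        # Detach and move to CPU so stored tensors don't hold the graph or GPU memory.
        self.attention_maps.append(output.detach().cpu())

    def register_hooks(self):
        for name, module in self.model.named_modules():
            if 'cross_attn' in name:
                self.handles.append(
                    module.register_forward_hook(self.get_attention_hook))

    def remove_hooks(self):
        # Always remove handles when done, or the hooks fire on every forward.
        for h in self.handles:
            h.remove()
        self.handles.clear()

# Tiny stand-in model: the submodule name contains 'cross_attn',
# so the hook fires once per forward pass.
model = nn.Sequential()
model.add_module('cross_attn', nn.Linear(8, 8))

collector = AttentionCollector(model)
collector.register_hooks()
with torch.no_grad():
    model(torch.randn(2, 8))
collector.remove_hooks()
print(len(collector.attention_maps))  # 1
```

Removing the handles after inference matters; otherwise the list keeps growing across batches.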

For each inference with a batch of one image, I get 4 groups of attention layers (matching RTDETRTransformerv2::num_layers). The first tensor of the first layer can be decomposed with num_levels (RTDETRTransformerv2::num_levels).
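Given the `spatial_shapes` returned by `_get_encoder_input`, a tensor flattened over all feature levels can be split back into per-level 2-D maps with `torch.split`. A sketch under assumed shapes: I assume a `(batch, sum(h*w), channels)` layout and illustrate with the three levels a 640x640 input would produce at strides 8/16/32; `split_by_levels` is a hypothetical helper, not part of the repo:

```python
import torch

def split_by_levels(flat, spatial_shapes):
    """Split a tensor flattened over all feature levels into per-level maps.

    flat: (batch, sum(h_i * w_i), channels) -- assumed layout.
    spatial_shapes: list of (h_i, w_i) pairs, one per level.
    Returns a list of (batch, h_i, w_i, channels) tensors.
    """
    sizes = [h * w for h, w in spatial_shapes]
    chunks = torch.split(flat, sizes, dim=1)
    return [c.reshape(c.shape[0], h, w, -1)
            for c, (h, w) in zip(chunks, spatial_shapes)]

# Example: three levels for a 640x640 input with strides 8/16/32.
shapes = [(80, 80), (40, 40), (20, 20)]
total = sum(h * w for h, w in shapes)  # 8400 flattened positions
flat = torch.randn(1, total, 256)
maps = split_by_levels(flat, shapes)
print([tuple(m.shape) for m in maps])
# [(1, 80, 80, 256), (1, 40, 40, 256), (1, 20, 20, 256)]
```

Each entry of `maps` can then be upsampled to the input resolution and overlaid on the image as a heatmap.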

I have new questions about this

Thanks for your advice S.