facebookresearch / detr

End-to-End Object Detection with Transformers
Apache License 2.0

[BUG] Tiny detail in the decoder layer is overlooked #528

Open JosephAssaker opened 1 year ago

JosephAssaker commented 1 year ago

First off, thank you for this inspiring contribution to the field!


Running DETR's code as-is will never trigger this bug; however, it remains incorrect behavior that I think needs fixing.

The bug resides in the forward_post and forward_pre methods of the TransformerDecoderLayer class. More precisely, it is in the call to the multihead_attn layer shown below:

https://github.com/facebookresearch/detr/blob/8a144f83a287f4d3fece4acdf073f387c5af387d/models/transformer.py#L224-L227

The attn_mask argument is set to memory_mask. However, referring back to the official documentation, attn_mask is intended either to prevent a set of queries from attending to certain keys/values or to manipulate the attention weights. Consequently, the shape of this argument must align with the shape of the query argument, which in this case is the tgt variable (the object queries in DETR's case). Thus, the appropriate value for attn_mask here would be tgt_mask.
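To make the shape constraint concrete, here is a minimal sketch (with hypothetical dimensions: 100 object queries over a flattened feature map of 600 positions) showing that PyTorch's nn.MultiheadAttention requires a 2D attn_mask of shape (query_len, key_len), so its leading dimension must match the query sequence, not the memory:

```python
import torch
import torch.nn as nn

# Hypothetical dimensions mirroring DETR's decoder cross-attention.
d_model, n_heads = 256, 8
num_queries, memory_len, batch = 100, 600, 2

mha = nn.MultiheadAttention(d_model, n_heads)
tgt = torch.randn(num_queries, batch, d_model)    # queries (object queries)
memory = torch.randn(memory_len, batch, d_model)  # keys/values (encoder output)

# A 2D attn_mask must be (query_len, key_len) = (100, 600):
# its first dimension aligns with the query sequence.
good_mask = torch.zeros(num_queries, memory_len, dtype=torch.bool)
out, _ = mha(tgt, memory, memory, attn_mask=good_mask)
print(out.shape)  # (num_queries, batch, d_model)

# A mask shaped only by the memory length, (600, 600), is rejected.
bad_mask = torch.zeros(memory_len, memory_len, dtype=torch.bool)
try:
    mha(tgt, memory, memory, attn_mask=bad_mask)
except Exception as e:
    print("rejected:", e)
```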

That being said, with the current implementation the bug will never manifest: the attn_mask argument is never passed in the call to the decoder layer, so the attn_mask variable always defaults to None.


On another related note, not bug related though, why do you not set the attn_mask argument in the Encoder layer?

https://github.com/facebookresearch/detr/blob/8a144f83a287f4d3fece4acdf073f387c5af387d/models/transformer.py#L154-L156

You already have access to the mask of the source image(s), so why not set the attn_mask argument and prevent the padding positions from attending? Is it just for the sake of simplicity that you kept it None? As I understand it, this shouldn't really affect the output, since these positions are never used as keys/values thanks to the key_padding_mask argument, right? If that's not the case, could you please clarify the reason for it being None?
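The reasoning above can be checked with a small sketch (hypothetical dimensions; the last two positions play the role of padding): key_padding_mask alone already drives the attention weights onto padded *keys* to exactly zero, so padded positions never contribute to any output. Padded *queries* still produce outputs, but those rows are simply never consumed downstream:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads, seq_len, batch = 32, 4, 6, 1
mha = nn.MultiheadAttention(d_model, n_heads)
src = torch.randn(seq_len, batch, d_model)

# Mark the last two positions as padding (True = ignore as key).
key_padding_mask = torch.tensor(
    [[False, False, False, False, True, True]]
)

out, attn = mha(src, src, src, key_padding_mask=key_padding_mask)

# attn has shape (batch, query_len, key_len); every query assigns
# exactly zero weight to the padded keys (columns 4 and 5).
print(attn[0, :, 4:].abs().max())
```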