facebookresearch / MaskFormer

Per-Pixel Classification is Not All You Need for Semantic Segmentation (NeurIPS 2021, spotlight)

Question about transformer decoder #63


CoinCheung commented 2 years ago

Hi,

I am trying to learn the code, and I found the following line: https://github.com/facebookresearch/MaskFormer/blob/da3e60d85fdeedcb31476b5edd7d328826ce56cc/mask_former/modeling/transformer/transformer.py#L70 The input tgt of the decoder is all zeros, and I see this all-zeros tensor used as the input to the decoder layer: https://github.com/facebookresearch/MaskFormer/blob/da3e60d85fdeedcb31476b5edd7d328826ce56cc/mask_former/modeling/transformer/transformer.py#L272

Here tgt is all zeros and query_pos is a learnable embedding, so q and k are non-zero tensors (since tgt is zero, they equal query_pos in value), but tgt itself is still all zeros and is used as v. According to the computation rule of QKV attention, if v is all zeros, the attention output is also all zeros. Thus the self-attention module does not contribute to the model. Am I correct about this?
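For reference, here is a tiny sketch of my reasoning (not MaskFormer code; the shapes and the random query_pos are just placeholders): with q = k = tgt + query_pos and v = tgt, an all-zero tgt makes every attention output a weighted average of zero vectors, i.e. zero.

```python
import torch

num_queries, batch, dim = 100, 2, 256
tgt = torch.zeros(batch, num_queries, dim)        # decoder input, all zeros
query_pos = torch.randn(batch, num_queries, dim)  # stand-in for the learnable query embedding

q = k = tgt + query_pos                           # non-zero thanks to query_pos
v = tgt                                           # still all zeros

attn = torch.softmax(q @ k.transpose(-2, -1) / dim ** 0.5, dim=-1)
out = attn @ v
print(out.abs().max())  # tensor(0.) -- any weighted average of zero vectors is zero
```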

bowenc0221 commented 2 years ago

That is correct, but only for the first self-attention layer: tgt is no longer a zero vector after the cross-attention.
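Roughly like this (a sketch, not the exact MaskFormer layer; norms, dropout, and the FFN are omitted, and random tensors stand in for learned inputs):

```python
import torch
import torch.nn as nn

dim, heads, num_queries, hw, batch = 256, 8, 100, 400, 2
self_attn = nn.MultiheadAttention(dim, heads)
cross_attn = nn.MultiheadAttention(dim, heads)

tgt = torch.zeros(num_queries, batch, dim)        # all-zero input to the decoder
query_pos = torch.randn(num_queries, batch, dim)  # learnable query embedding (random stand-in)
memory = torch.randn(hw, batch, dim)              # image features from the encoder
pos = torch.randn(hw, batch, dim)                 # positional encoding of memory

# layer 1
q = k = tgt + query_pos
tgt = tgt + self_attn(q, k, value=tgt)[0]                          # v = 0: contributes at most a constant offset
tgt = tgt + cross_attn(tgt + query_pos, memory + pos, memory)[0]   # residual makes tgt non-zero from here on

# layer 2: self-attention now operates on a non-zero tgt
q = k = tgt + query_pos
out = self_attn(q, k, value=tgt)[0]
print(out.abs().max() > 0)  # True
```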

CoinCheung commented 2 years ago

Thanks for replying! There is another part of the code that I cannot understand: https://github.com/facebookresearch/MaskFormer/blob/da3e60d85fdeedcb31476b5edd7d328826ce56cc/mask_former/modeling/transformer/transformer.py#L75

With the default setting batch_first=False for nn.MultiheadAttention, the hs tensor above should have shape (L, N, E), where L is the sequence length (the number of queries here), N is the batch size, and E is the feature dimension. After transpose(1, 2), hs would become (L, E, N), with the batch size as the last dimension. However, according to this line: https://github.com/facebookresearch/MaskFormer/blob/da3e60d85fdeedcb31476b5edd7d328826ce56cc/mask_former/modeling/transformer/transformer_predictor.py#L130 the output hs should be a 4-D tensor? Could you please tell me what I am missing here?
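For reference, here is the shape bookkeeping I tried, assuming the decoder is built with return_intermediate=True so that it stacks the output of every decoder layer; in that case hs would already be 4-D before the transpose, and transpose(1, 2) would swap the query and batch dimensions rather than batch and feature:

```python
import torch

num_layers, num_queries, batch, dim = 6, 100, 2, 256

# stacked per-layer decoder outputs: (num_layers, num_queries, batch, dim)
hs = torch.randn(num_layers, num_queries, batch, dim)

hs = hs.transpose(1, 2)
print(hs.shape)  # torch.Size([6, 2, 100, 256]) -> (num_layers, batch, num_queries, dim)
```

Is that what happens here, or am I misreading which tensor the transpose is applied to?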