xb534 / SED

[CVPR2024] Official Pytorch Implementation of SED: A Simple Encoder-Decoder for Open-Vocabulary Semantic Segmentation.
Apache License 2.0

Question about CER #25

Open Kimsure opened 1 day ago

Kimsure commented 1 day ago

Thanks for your impressive work!

I'm confused about the top-K strategy in Category Early Rejection. Does it change the category dimension N, e.g. from N to N_{l1}? If so, the dimension N is reduced. How is this handled in the next layer, and how does the dimension change at each layer?

xb534 commented 6 hours ago

Yes, we employ Category Early Rejection to reduce the category dimension N. To restore the original number of categories, we scatter the predictions for the surviving categories back into a full-size mask after the final stage:

# https://github.com/xb534/SED/blob/main/sed/modeling/transformer/model.py#L694
        # Compose the per-stage keep-masks back to the original T categories
        valid_idx = aux_valid.clone()
        valid_idx1 = aux0_valid.clone()
        valid_idx1[aux0_valid] = aux1_valid
        valid_idx[aux_valid] = valid_idx1
        _, _, H, W = corr_embed.size()
        # Scatter the surviving categories' predictions back; rejected ones stay 0
        output = torch.zeros([B, T, H, W]).cuda()
        output[:, valid_idx] = corr_embed.sigmoid()
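To see how the nested index restoration composes, here is a minimal self-contained sketch; the mask values, T, and shapes are illustrative, not taken from the model:

```python
import torch

B, T, H, W = 1, 6, 2, 2  # T original categories

# Per-stage keep-masks (illustrative values): 6 -> 4 -> 3 -> 2 categories
aux_valid = torch.tensor([True, True, False, True, True, False])
aux0_valid = torch.tensor([True, False, True, True])
aux1_valid = torch.tensor([True, True, False])

# Compose the nested masks back to the original T-category indexing
valid_idx1 = aux0_valid.clone()
valid_idx1[aux0_valid] = aux1_valid   # which of the 4 kept by stage 0 survive
valid_idx = aux_valid.clone()
valid_idx[aux_valid] = valid_idx1     # which of the original 6 survive overall

# Scatter the final-stage logits back; rejected categories stay at 0
corr_embed = torch.randn(B, int(valid_idx.sum()), H, W)
output = torch.zeros(B, T, H, W)
output[:, valid_idx] = corr_embed.sigmoid()
print(valid_idx.tolist())  # [True, False, False, True, False, False]
```

Each boolean-index assignment maps a mask over the current (already reduced) set of categories back onto the indexing of the previous stage, so chaining them recovers positions in the original T categories.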
Kimsure commented 55 minutes ago

Thanks a lot for your reply.

I'm still confused about the code below:

1. `mask = mask[:, :topK].reshape(-1)` already reduces the dimension via topK, so what is `torch.bincount()` used for?
2. How is the channel dimension controlled at each FAM+SFM decoder block?
3. Is K set to 8 at every decoder block?

# https://github.com/xb534/SED/blob/main/sed/modeling/transformer/model.py#L703
def get_valid_idx_from_mask(self, mask, topK=2):
    T = mask.size(1)
    # Sort category indices by score at every spatial location
    mask = torch.argsort(mask, dim=1, descending=True)
    # Keep each location's top-K category indices
    mask = mask[:, :topK].reshape(-1)
    # A category is valid if it appears in the top-K of at least one location
    valid_idx = torch.bincount(mask, minlength=T) > 0
    return valid_idx
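On question 1: the slicing picks each pixel's top-K category indices, and `torch.bincount` then unions them into a single boolean mask over all T categories — a category survives if any pixel ranked it in its top-K. A toy run (scores are made up):

```python
import torch

def get_valid_idx_from_mask(mask, topK=2):
    # mask: (B, T, H, W) per-category scores
    T = mask.size(1)
    idx = torch.argsort(mask, dim=1, descending=True)  # category indices, best first
    idx = idx[:, :topK].reshape(-1)                    # every pixel's top-K indices
    # bincount unions them: keep a category if any pixel ranked it top-K
    return torch.bincount(idx, minlength=T) > 0

B, T, H, W = 1, 5, 2, 2
scores = torch.zeros(B, T, H, W)
scores[0, 3] = 2.0  # category 3 wins at every pixel
scores[0, 1] = 1.0  # category 1 is runner-up everywhere
valid = get_valid_idx_from_mask(scores, topK=2)
print(valid.tolist())  # [False, True, False, True, False]
```

So `bincount` does not reduce anything further; it converts the per-pixel top-K index lists into one image-level keep/reject decision per category.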

def fast_conv_decoder(self, x, text_guidance, guidance, corr_guidance, topK):
    B = x.shape[0]
    T = x.shape[2]
    # Fold the category axis into the batch axis: (B, C, T, H, W) -> (B*T, C, H, W)
    corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
    mask_aux = self.head0(corr_embed)
    mask_aux = rearrange(mask_aux, '(B T) () H W -> B T H W', B=B)
    # Reject categories outside every location's top-K before the next decoder block
    aux_valid = self.get_valid_idx_from_mask(mask_aux, topK)
    corr_embed = self.decoder1(corr_embed[aux_valid], text_guidance[0], guidance[0], corr_guidance[0][aux_valid])
    mask_aux0 = self.head1(corr_embed)
    ...
    # Compose the nested keep-masks back to the original category indexing
    valid_idx = aux_valid.clone()
    valid_idx1 = aux0_valid.clone()
    valid_idx1[aux0_valid] = aux1_valid
    valid_idx[aux_valid] = valid_idx1
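On question 2, a minimal sketch of how the folded category axis shrinks at one block. It assumes batch size B = 1, so the (B*T, C, H, W) tensor can be indexed directly with the length-T mask; all shapes and the stand-in score head are illustrative:

```python
import torch

# Assuming B = 1, this tensor plays the role of (B*T, C, H, W)
T, C, H, W = 10, 16, 4, 4
corr_embed = torch.randn(T, C, H, W)

scores = torch.randn(1, T, H, W)  # stand-in for the aux head output
topK = 3
idx = torch.argsort(scores, dim=1, descending=True)[:, :topK].reshape(-1)
aux_valid = torch.bincount(idx, minlength=T) > 0

corr_embed = corr_embed[aux_valid]  # category axis shrinks: T -> N1
N1 = corr_embed.shape[0]            # topK <= N1 <= T
```

The per-pixel channel width C is unchanged; only the number of category slices carried through the next decoder block shrinks, and its size N1 is data-dependent (bounded below by topK, above by T).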