Kimsure opened this issue 1 day ago
Yes, we employ Category Early Rejection to reduce the category dimension N. To restore the original number of categories, we post-process the mask after the final stage:
```python
# https://github.com/xb534/SED/blob/main/sed/modeling/transformer/model.py#L694
valid_idx = aux_valid.clone()
valid_idx1 = aux0_valid.clone()
valid_idx1[aux0_valid] = aux1_valid
valid_idx[aux_valid] = valid_idx1
_, _, H, W = corr_embed.size()
output = torch.zeros([B, T, H, W]).cuda()
output[:, valid_idx] = corr_embed.sigmoid()
```
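For intuition, here is a toy plain-Python sketch (not the repo's tensors; the mask sizes and values below are made up) of how the two nested boolean masks are composed back onto the original T categories before the final `output[:, valid_idx]` scatter:

```python
# Toy illustration: three stages of Category Early Rejection over T=6 categories.
T = 6
aux_valid  = [True, False, True, True, False, True]   # stage 0 keeps 4 of 6
aux0_valid = [True, True, False, True]                # stage 1 keeps 3 of those 4
aux1_valid = [False, True, True]                      # stage 2 keeps 2 of those 3

# Re-expand stage-2 decisions onto stage-0's 4 survivors
# (mirrors valid_idx1[aux0_valid] = aux1_valid).
valid_idx1 = list(aux0_valid)
j = 0
for i, keep in enumerate(aux0_valid):
    if keep:
        valid_idx1[i] = aux1_valid[j]
        j += 1

# Re-expand onto all T original categories
# (mirrors valid_idx[aux_valid] = valid_idx1).
valid_idx = list(aux_valid)
j = 0
for i, keep in enumerate(aux_valid):
    if keep:
        valid_idx[i] = valid_idx1[j]
        j += 1

print(valid_idx)  # [False, False, True, False, False, True]
```

The composition works outside-in: `aux1_valid` is first re-expanded to stage-0's survivors via `aux0_valid`, then the result is re-expanded to all T categories via `aux_valid`. Categories rejected at any stage end up `False`, and the final scatter leaves their output channels at zero.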
Thanks a lot for your reply. I'm still confused about the code below.

1. `mask = mask[:, :topK].reshape(-1)` already reduces the dimension via topK; what is `torch.bincount()` used for?
2. How is the channel dimension controlled at each FAM+SFM decoder block?
3. Is K set to 8 at every decoder block?
```python
# https://github.com/xb534/SED/blob/main/sed/modeling/transformer/model.py#L703
def get_valid_idx_from_mask(self, mask, topK=2):
    T = mask.size(1)
    mask = torch.argsort(mask, dim=1, descending=True)
    mask = mask[:, :topK].reshape(-1)
    valid_idx = torch.bincount(mask, minlength=T) > 0
    return valid_idx
```
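To make the argsort / top-K / bincount pipeline concrete, here is a dependency-free plain-Python re-implementation of the same logic (illustrative only; the scores are made up, and each row stands in for one flattened batch/spatial position):

```python
def get_valid_idx_from_mask(scores, topK=2):
    """scores: list of per-position score lists, each of length T.
    Returns a length-T boolean list: True iff the category lands in the
    top-K at ANY position -- the union that `bincount(...) > 0` computes."""
    T = len(scores[0])
    counts = [0] * T
    for row in scores:
        # indices sorted by descending score == torch.argsort(..., descending=True)
        order = sorted(range(T), key=lambda i: -row[i])
        for idx in order[:topK]:       # == mask[:, :topK]
            counts[idx] += 1           # == torch.bincount over the flattened ids
    return [c > 0 for c in counts]

scores = [[0.9, 0.1, 0.3, 0.8],   # top-2 here: categories 0, 3
          [0.2, 0.1, 0.6, 0.7]]   # top-2 here: categories 3, 2
print(get_valid_idx_from_mask(scores))  # [True, False, True, True]
```

So top-K alone only ranks categories per position; `bincount` is what merges the per-position top-K picks into a single boolean mask over all T categories, keeping a category if any position selected it.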
```python
def fast_conv_decoder(self, x, text_guidance, guidance, corr_guidance, topK):
    B = x.shape[0]
    T = x.shape[2]
    corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
    mask_aux = self.head0(corr_embed)
    mask_aux = rearrange(mask_aux, '(B T) () H W -> B T H W', B=B)
    aux_valid = self.get_valid_idx_from_mask(mask_aux, topK)
    corr_embed = self.decoder1(corr_embed[aux_valid], text_guidance[0], guidance[0], corr_guidance[0][aux_valid])
    mask_aux0 = self.head1(corr_embed)
    ...
    valid_idx = aux_valid.clone()
    valid_idx1 = aux0_valid.clone()
    valid_idx1[aux0_valid] = aux1_valid
    valid_idx[aux_valid] = valid_idx1
```
Thanks for your impressive work!
I'm confused about the top-k strategy in Category Early Rejection. Does it change the category dimension N, e.g. from N to N_{l1}? If so, N is reduced. How does the next layer handle this reduced dimension, and how does the dimension change at each layer?