isl-org / DPT

Dense Prediction Transformers
MIT License

Question about the pos_embed #68

Closed: kamiLight closed this issue 2 years ago

kamiLight commented 2 years ago

I want to know why we need to resize the pos_embed to (1, num_patches + self.num_tokens, embed_dim), rather than just using nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)). Thanks!

# Requires: import math, torch, and torch.nn.functional as F.
# Method of the ViT backbone; self.start_index is the number of leading
# special (non-patch) tokens, e.g. 1 for a class token.
def _resize_pos_embed(self, posemb, gs_h, gs_w):
    # Split the special-token embeddings from the patch-grid embeddings.
    posemb_tok, posemb_grid = (
        posemb[:, : self.start_index],
        posemb[0, self.start_index :],
    )
    # Side length of the original (square) token grid.
    gs_old = int(math.sqrt(len(posemb_grid)))

    # Reshape to (1, C, gs_old, gs_old) and bilinearly interpolate to (gs_h, gs_w).
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    # Flatten back into a sequence of gs_h * gs_w patch tokens.
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

    return posemb
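
For reference, here is a minimal self-contained sketch of the same logic written as a free function, so the shape change is easy to see. The 768-dim embedding, the 24x24 pretraining grid (ViT-Base/16 at 384x384), and the 480x640 example input are illustrative assumptions, not values taken from the repo:

import math
import torch
import torch.nn.functional as F

# Pretrained pos_embed: 1 class token + 24*24 patch tokens, embed_dim 768.
start_index = 1
posemb = torch.randn(1, 1 + 24 * 24, 768)

def resize_pos_embed(posemb, gs_h, gs_w, start_index=1):
    # Same steps as _resize_pos_embed above, without the class wrapper.
    posemb_tok, posemb_grid = posemb[:, :start_index], posemb[0, start_index:]
    gs_old = int(math.sqrt(len(posemb_grid)))
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    return torch.cat([posemb_tok, posemb_grid], dim=1)

# A 480x640 input with patch size 16 gives a 30x40 token grid, so the
# embedding has to grow from 577 tokens to 1201 tokens.
resized = resize_pos_embed(posemb, 480 // 16, 640 // 16)
print(resized.shape)  # torch.Size([1, 1201, 768])

Note that interpolating carries the pretrained positional information over to the new grid size, whereas a zero-initialized nn.Parameter of the target shape would start without it.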