kamiLight (issue closed 2 years ago):
I want to know why we need to resize the pos_embed to (1, num_patches + self.num_tokens, embed_dim), rather than simply using nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)). Thanks!
```python
# Requires: import math, torch, and torch.nn.functional as F
def _resize_pos_embed(self, posemb, gs_h, gs_w):
    # Split the embedding into the token part (class/extra tokens)
    # and the grid part (one embedding per patch position).
    posemb_tok, posemb_grid = (
        posemb[:, : self.start_index],
        posemb[0, self.start_index :],
    )

    # Recover the side length of the original (square) patch grid.
    gs_old = int(math.sqrt(len(posemb_grid)))

    # Reshape to (1, C, gs_old, gs_old), bilinearly resample the grid
    # to the new (gs_h, gs_w) resolution, then flatten back.
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    # Reattach the token embeddings in front of the resampled grid.
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

    return posemb
```
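For reference, here is a small self-contained sketch contrasting the two options (the shapes and the 14×14 → 20×20 grid change are hypothetical, not taken from this repo): interpolating the pretrained pos_embed resamples the spatial prior learned during pretraining onto the new grid, whereas a fresh nn.Parameter(torch.zeros(...)) would have the right shape but would start with no positional information at all.

```python
import math

import torch
import torch.nn.functional as F

embed_dim = 768
start_index = 1  # one class token ahead of the grid positions

# Hypothetical pretrained positional embedding for a 14x14 patch grid:
# shape (1, 1 + 14*14, 768).
posemb = torch.randn(1, start_index + 14 * 14, embed_dim)

# Option A: resample the pretrained grid to 20x20 (e.g. a larger input).
posemb_tok, posemb_grid = posemb[:, :start_index], posemb[0, start_index:]
gs_old = int(math.sqrt(len(posemb_grid)))  # 14
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(20, 20), mode="bilinear")
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, 20 * 20, -1)
resized = torch.cat([posemb_tok, posemb_grid], dim=1)  # (1, 401, 768)

# Option B: a zero-initialized parameter of the target shape.
# The shape matches, but every pretrained positional value is discarded,
# so the embedding would have to be learned again from scratch.
fresh = torch.nn.Parameter(torch.zeros(1, start_index + 20 * 20, embed_dim))

print(resized.shape, fresh.shape)  # both torch.Size([1, 401, 768])
```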