xxxxyliu opened this issue 9 months ago
@xxxxyliu Same situation for me - let me know if you figured out the solution.
@xxxxyliu Here we go:
```python
import torch
from mmcv.runner import get_dist_info  # or: from mmengine.dist import get_dist_info


def resize_abs_pos_embed(self, checkpoint):
    # Method of the model class: interpolate the checkpoint's absolute
    # position embedding to the current model's bucket size.
    # Check for the correct key in the checkpoint.
    pos_embed_key = ('backbone.image_adapter.pos_embed'
                     if 'backbone.image_adapter.pos_embed' in checkpoint
                     else 'image_adapter.pos_embed')
    pos_embed_checkpoint = checkpoint[pos_embed_key]  # assumed shape: (num_tokens, embed_dim)
    embedding_size = pos_embed_checkpoint.shape[-1]
    bucket_size = self.image_adapter.bucket_size
    num_patches = bucket_size ** 2
    num_extra_tokens = self.image_adapter.pos_embed.shape[-2] - num_patches

    # Calculate the original and new grid sizes of the position embedding.
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)

    rank, _ = get_dist_info()
    if orig_size != new_size:
        if rank == 0:
            print(f"Position interpolate from {orig_size}x{orig_size} "
                  f"to {new_size}x{new_size}")
        # Keep the extra tokens (class_token and dist_token) unchanged.
        extra_tokens = pos_embed_checkpoint[:num_extra_tokens]
        # Only the patch position tokens are interpolated.
        pos_tokens = pos_embed_checkpoint[num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
                                        embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size),
            mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
        checkpoint[pos_embed_key] = new_pos_embed
    return checkpoint
```
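For reference, this is roughly how I wire it into loading the pretrained weights. Treat it as a sketch: `model` and `'checkpoint.pth'` are placeholders, not the repo's actual names.

```python
import torch

# Sketch only: `model` is an instance of the class that defines
# resize_abs_pos_embed above, and 'checkpoint.pth' is a placeholder path.
state_dict = torch.load('checkpoint.pth', map_location='cpu')

# Some checkpoints nest the weights under 'state_dict' or 'model'.
for key in ('state_dict', 'model'):
    if key in state_dict:
        state_dict = state_dict[key]
        break

state_dict = model.resize_abs_pos_embed(state_dict)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)
```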
I haven't changed the model structure, but I'm encountering an error when using pre-trained weights.
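A quick way to see whether that error is the pos_embed shape mismatch discussed above is to compare the `pos_embed` tensors in the checkpoint with the ones the freshly built model expects. A minimal sketch, reusing the placeholder `model` and checkpoint path from the previous snippet:

```python
import torch

# Hypothetical diagnostic: print every pos_embed shape in the checkpoint
# and in the model, so a size mismatch is easy to spot.
ckpt = torch.load('checkpoint.pth', map_location='cpu')  # placeholder path
ckpt = ckpt.get('state_dict', ckpt)

for key, tensor in ckpt.items():
    if key.endswith('pos_embed'):
        print(f'checkpoint  {key}: {tuple(tensor.shape)}')

for name, param in model.named_parameters():
    if name.endswith('pos_embed'):
        print(f'model       {name}: {tuple(param.shape)}')
```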