Sense-X / UniFormer

[ICLR2022] official implementation of UniFormer
Apache License 2.0

KeyError: 'pos_embed' when fine-tuning #88

Closed · RooKichenn closed this issue 1 year ago

RooKichenn commented 1 year ago

When I use the uniformer_base_ls model to fine-tune on my own dataset, the code raises an error:

Traceback (most recent call last):
  File "main.py", line 502, in <module>
    main(args)
  File "main.py", line 278, in main
    pos_embed_checkpoint = checkpoint_model['pos_embed']
KeyError: 'pos_embed'

Then I checked the pre-trained weights and found that there is no standalone pos_embed key; the position-embedding parameters are stored per block instead, e.g. 'blocks1.0.pos_embed.weight' and 'blocks1.0.pos_embed.bias'. Finally, I commented out the following lines of code and the program worked. What is causing this? Can you help me, please?

# interpolate position embedding
print(checkpoint_model.keys())
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
    pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
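
For anyone who hits the same KeyError, here is a minimal sketch of a safer pattern (an illustration, not the repo's official fix; the checkpoint filename below is a placeholder): first list which position-embedding keys the checkpoint actually contains, then run the interpolation above only when a standalone pos_embed exists.

import torch

# Sketch only: 'uniformer_base_ls.pth' is a placeholder for whatever
# checkpoint file is being loaded.
ckpt = torch.load('uniformer_base_ls.pth', map_location='cpu')
checkpoint_model = ckpt.get('model', ckpt)  # DeiT-style scripts nest weights under 'model'

print([k for k in checkpoint_model if 'pos_embed' in k])
# For UniFormer this prints per-block keys such as 'blocks1.0.pos_embed.weight'
# and 'blocks1.0.pos_embed.bias', and no top-level 'pos_embed'.

if 'pos_embed' in checkpoint_model:
    ...  # ViT/DeiT-style checkpoint: run the interpolation snippet above
else:
    print("No absolute 'pos_embed' in checkpoint; skipping interpolation.")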
Andy1621 commented 1 year ago

Sorry for the late reply. What you did is right! The fine-tuning code above is copied from DeiT, where it interpolates the absolute position embedding used in ViT when the input resolution changes. Since UniFormer uses a dynamic position embedding instead, you can simply comment out that code!
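
To make the difference concrete, here is a minimal sketch of a convolutional (dynamic) position embedding in the spirit of UniFormer's DPE (the repo's actual module may differ in details): because it is a depthwise convolution rather than a learned table of fixed length, it adapts to any input resolution, so fine-tuning at a new image size needs no interpolation step.

import torch
import torch.nn as nn

class DynamicPosEmbed(nn.Module):
    # Sketch of a dynamic position embedding: a depthwise 3x3 convolution
    # whose output is added to the feature map. Being convolutional, it
    # works at any spatial size, unlike a fixed-length embedding table.
    def __init__(self, dim):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

    def forward(self, x):  # x: (B, C, H, W)
        return x + self.pos_embed(x)

# The parameter names mirror the per-block checkpoint keys seen above:
dpe = DynamicPosEmbed(64)
print([name for name, _ in dpe.named_parameters()])
# -> ['pos_embed.weight', 'pos_embed.bias']

This also explains the checkpoint layout: each block carries its own pos_embed.weight and pos_embed.bias, and there is no global pos_embed tensor to interpolate.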

RooKichenn commented 1 year ago

Sorry, I haven't read the DeiT source code. Thank you very much for answering my question!

Andy1621 commented 1 year ago

As there is no more activity, I am closing the issue. Don't hesitate to reopen it if necessary.