OFA-Sys / ONE-PEACE

A general representation model across vision, audio, and language modalities. Paper: ONE-PEACE: Exploring One General Representation Model Toward Unlimited Modalities
Apache License 2.0

KeyError: 'image_adapter.pos_embed' #48

Open xxxxyliu opened 9 months ago

xxxxyliu commented 9 months ago
2024-01-28 15:51:36,396 - mmseg - INFO - load checkpoint from local path: /data/onepeace_seg_cocostuff2ade20k.pth
Traceback (most recent call last):
  File "/root/ONE-PEACE/one_peace_vision/seg/train.py", line 243, in <module>
    main()
  File "/root/ONE-PEACE/one_peace_vision/seg/train.py", line 203, in main
    model.init_weights()
  File "/opt/conda/envs/onepeace/lib/python3.8/site-packages/mmcv/runner/base_module.py", line 116, in init_weights
    m.init_weights()
  File "/root/ONE-PEACE/one_peace_vision/seg/mmseg_custom/models/backbones/onepeace.py", line 571, in init_weights
    state_dict = self.resize_abs_pos_embed(model)
  File "/root/ONE-PEACE/one_peace_vision/seg/mmseg_custom/models/backbones/onepeace.py", line 467, in resize_abs_pos_embed
    pos_embed_checkpoint = checkpoint['image_adapter.pos_embed']
KeyError: 'image_adapter.pos_embed'

I haven't changed the model structure, but I'm getting this error when loading the pre-trained weights.
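A quick way to diagnose this kind of KeyError is to list the keys actually stored in the checkpoint: whether a "backbone." prefix is present depends on how the weights were exported. A minimal sketch, with made-up checkpoint contents standing in for the real .pth file:

```python
import torch

# Stand-in for torch.load('<checkpoint>.pth'): a checkpoint whose keys
# carry a 'backbone.' prefix (shapes and keys here are illustrative only).
checkpoint = {
    'backbone.image_adapter.pos_embed': torch.zeros(577, 1536),
    'backbone.image_adapter.patch_embed.weight': torch.zeros(1536, 3, 16, 16),
}

# List the keys that mention the adapter to see which prefix was used.
adapter_keys = [k for k in checkpoint if 'image_adapter' in k]
print(adapter_keys)

# The unprefixed key the loader expects is absent...
print('image_adapter.pos_embed' in checkpoint)           # False
# ...but the 'backbone.'-prefixed variant is there.
print('backbone.image_adapter.pos_embed' in checkpoint)  # True
```

If the prefixed key shows up, the fix below (checking both names) applies; if neither key is present, the checkpoint is likely the wrong file for this backbone.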

AndrewTKent commented 8 months ago

@xxxxyliu Same situation for me - let me know if you figured out the solution.

AndrewTKent commented 1 month ago

@xxxxyliu Here we go: the checkpoint stores the positional embedding under the key backbone.image_adapter.pos_embed rather than image_adapter.pos_embed, so resize_abs_pos_embed needs to check for both names:


def resize_abs_pos_embed(self, checkpoint):
    # The checkpoint may store the embedding under a 'backbone.'-prefixed key,
    # depending on how the weights were exported, so check for both names.
    if 'backbone.image_adapter.pos_embed' in checkpoint:
        pos_embed_key = 'backbone.image_adapter.pos_embed'
    else:
        pos_embed_key = 'image_adapter.pos_embed'

    pos_embed_checkpoint = checkpoint[pos_embed_key]
    embedding_size = pos_embed_checkpoint.shape[-1]
    bucket_size = self.image_adapter.bucket_size
    num_patches = bucket_size ** 2
    num_extra_tokens = self.image_adapter.pos_embed.shape[-2] - num_patches

    # Calculate original and new sizes for position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)

    rank, _ = get_dist_info()

    if orig_size != new_size:
        if rank == 0:
            print(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size}")

        # Keep class_token and dist_token unchanged
        extra_tokens = pos_embed_checkpoint[:num_extra_tokens]

        # Only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)

        pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)

        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
        checkpoint[pos_embed_key] = new_pos_embed

    return checkpoint
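The interpolation step can be sanity-checked outside the backbone by running the same resizing logic on a dummy embedding. The sizes below are invented for illustration; ONE-PEACE's real bucket size and embedding width will differ:

```python
import torch
import torch.nn.functional as F

# Dummy 2-D positional embedding: 1 extra (class) token plus a 16x16 patch
# grid, embedding width 64 (all sizes invented for illustration).
num_extra_tokens, orig_size, new_size, dim = 1, 16, 24, 64
pos_embed = torch.randn(num_extra_tokens + orig_size ** 2, dim)

# Split off the extra tokens, which pass through unchanged.
extra_tokens = pos_embed[:num_extra_tokens]
pos_tokens = pos_embed[num_extra_tokens:]

# Reshape the patch tokens to an image-like grid and resize bicubically,
# mirroring the resize_abs_pos_embed method above.
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, dim).permute(0, 3, 1, 2)
pos_tokens = F.interpolate(pos_tokens, size=(new_size, new_size),
                           mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)

new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
print(new_pos_embed.shape)  # torch.Size([577, 64]), i.e. 1 + 24*24 tokens
```

This also shows why the function indexes dim 0 directly: the checkpoint's pos_embed here is a 2-D (tokens, dim) tensor, not the batched (1, tokens, dim) layout some ViT implementations use.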