microsoft / unilm

Large-scale Self-supervised Pre-training Across Tasks, Languages, and Modalities
https://aka.ms/GeneralAI

When fine-tuning with run_class_finetuning, a KeyError is raised for rel_pos_bias.relative_position_bias_table #1332

Open zhanglaoban-kk opened 9 months ago

zhanglaoban-kk commented 9 months ago

Hello, team.

```
KeyError                                  Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_6004\1581778792.py in <module>
----> 1 model = beit_base_patch16_224(pretrained = False)

~\AppData\Local\Temp\ipykernel_6004\1659408983.py in beit_base_patch16_224(pretrained, **kwargs)
     30             rel_pos_bias = checkpoint_model[key]
     31             src_num_pos, num_attn_heads = rel_pos_bias.size()
---> 32             dst_num_pos, _ = model.state_dict()[key].size()
     33             dst_patch_shape = model.patch_embed.patch_shape
     34             if dst_patch_shape[0] != dst_patch_shape[1]:

KeyError: 'rel_pos_bias.relative_position_bias_table'
```

I did self-supervised learning using run_beit_pretraining and then loaded the pre-training weights as provided in run_class_finetuning, but I don't know what's causing the problem.
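
A quick way to narrow this down is to check which relative-position keys the pretraining checkpoint actually contains; a minimal sketch, using the checkpoint path and the 'model' key from the code later in this thread:

```python
import torch

# Checkpoint written by run_beit_pretraining (path taken from this thread).
checkpoint = torch.load('F:/beit/shuju/output/checkpoint-99.pth', map_location='cpu')
checkpoint_model = checkpoint['model']

# A pretraining run with a shared relative position bias stores a single table
# under 'rel_pos_bias.relative_position_bias_table'.
print([k for k in checkpoint_model if 'relative_position_bias_table' in k])
```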

donglixp commented 8 months ago

Could you give the specific command that produced the above error?

zhanglaoban-kk commented 8 months ago

Since I'm going to perform my own classification task, I first ran self-supervised learning on my own unlabeled dataset with run_beit_pretraining to get the pre-training weights, and then extracted the weight-loading module from run_class_finetuning and put it into the function beit_base_patch16_224:

```python
# Imports assumed from the standard dependencies and the BEiT codebase (unilm/beit);
# they were not part of the original snippet.
from functools import partial

import numpy as np
import torch
import torch.nn as nn
from scipy import interpolate

import utils
from modeling_finetune import VisionTransformer, _cfg

predModelPath = 'F:/beit/shuju/output/checkpoint-99.pth'


def beit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()

    checkpoint = torch.load(predModelPath, map_location='cpu')
    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
        print("Expand the shared relative position embedding to each transformer block. ")
        num_layers = model.get_num_layers()
        rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"]
        for i in range(num_layers):
            checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()

        checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")

    all_keys = list(checkpoint_model.keys())
    for key in all_keys:
        if "relative_position_index" in key:
            checkpoint_model.pop(key)

        if "relative_position_bias_table" in key:
            rel_pos_bias = checkpoint_model[key]
            src_num_pos, num_attn_heads = rel_pos_bias.size()
            dst_num_pos, _ = model.state_dict()[key].size()
            dst_patch_shape = model.patch_embed.patch_shape
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()
            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
            if src_size != dst_size:
                print("Position interpolate for %s from %dx%d to %dx%d" % (
                    key, src_size, src_size, dst_size, dst_size))
                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

                def geometric_progression(a, r, n):
                    return a * (1.0 - r ** n) / (1.0 - r)

                left, right = 1.01, 1.5
                while right - left > 1e-6:
                    q = (left + right) / 2.0
                    gp = geometric_progression(1, q, src_size // 2)
                    if gp > dst_size // 2:
                        right = q
                    else:
                        left = q

                dis = []
                cur = 1
                for i in range(src_size // 2):
                    dis.append(cur)
                    cur += q ** (i + 1)

                r_ids = [-_ for _ in reversed(dis)]

                x = r_ids + [0] + dis
                y = r_ids + [0] + dis

                t = dst_size // 2.0
                dx = np.arange(-t, t + 0.1, 1.0)
                dy = np.arange(-t, t + 0.1, 1.0)

                print("Original positions = %s" % str(x))
                print("Target positions = %s" % str(dx))

                all_rel_pos_bias = []

                for i in range(num_attn_heads):
                    z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
                    f = interpolate.interp2d(x, y, z, kind='cubic')
                    all_rel_pos_bias.append(
                        torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))

                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)

                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
                checkpoint_model[key] = new_rel_pos_bias

    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed

    utils.load_state_dict(model, checkpoint_model)
    return model
```

But when I run model = beit_base_patch16_224(pretrained=False), it reports the above error. I don't know what is causing it, and I didn't change the model structure.
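
One hedged reading of the traceback, given the code above: the expansion branch is guarded by model.use_rel_pos_bias, so if the fine-tuning VisionTransformer is created with its defaults (beit_base_patch16_224 passes no extra kwargs), that branch is skipped, the shared key rel_pos_bias.relative_position_bias_table stays in checkpoint_model, and the later lookup model.state_dict()[key] raises exactly this KeyError. A minimal check, as a sketch that assumes the constructor defaults of modeling_finetune.VisionTransformer:

```python
from functools import partial

import torch.nn as nn

# Assumed import from the BEiT codebase (unilm/beit).
from modeling_finetune import VisionTransformer

# Same constructor call as in beit_base_patch16_224 above, with no extra kwargs.
model = VisionTransformer(
    patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
    norm_layer=partial(nn.LayerNorm, eps=1e-6))

print(getattr(model, 'use_rel_pos_bias', None))   # False would skip the expansion branch
print(any('relative_position_bias_table' in k     # False means the later lookup must fail
          for k in model.state_dict()))
```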

addf400 commented 8 months ago

@zhanglaoban-kk, can you provide more details, such as the values of 'use_rel_pos_bias', 'use_shared_rel_pos_bias', and 'use_abs_pos_emb' used when creating the torch.nn model for pretraining and for fine-tuning?
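
For context on those flags: the rel_pos_bias.relative_position_bias_table key only appears in checkpoints pretrained with a shared relative position bias, and the expansion branch in the loading code only runs when the fine-tuning model itself has per-block relative position bias enabled. A hedged sketch of a constructor call that would satisfy that branch, assuming modeling_finetune.VisionTransformer accepts the keyword arguments named above (the function name beit_base_patch16_224_relpos is hypothetical):

```python
from functools import partial

import torch.nn as nn

# Assumed imports from the BEiT codebase (unilm/beit).
from modeling_finetune import VisionTransformer, _cfg


def beit_base_patch16_224_relpos(pretrained=False, **kwargs):
    # Enable per-block relative position bias so the shared table from the
    # pretraining checkpoint can be expanded into
    # blocks.<i>.attn.relative_position_bias_table entries.
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_rel_pos_bias=True,           # per-block tables in the fine-tuning model
        use_shared_rel_pos_bias=False,   # the shared table lives only in the checkpoint
        **kwargs)
    model.default_cfg = _cfg()
    return model
```

use_abs_pos_emb should likewise be set to match whatever the pretraining run used.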