alibaba / AliceMind

ALIbaba's Collection of Encoder-decoders from MinD (Machine IntelligeNce of Damo) Lab
Apache License 2.0

Two code errors in mPLUG #91

Open Norman-Ou opened 1 year ago

Norman-Ou commented 1 year ago

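In AliceMind/mPLUG/caption_mplug.py, lines 234-239, the pos_embed that is created when resizing the positional embedding to the input image resolution always has dim=1 equal to 768, regardless of whether the visual encoder is ViT-B-16 or ViT-L-14.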

        # reshape positional embedding to accomodate for image resolution change
        if config["clip_name"] == "ViT-B-16":
            num_patches = int(config["image_res"] * config["image_res"]/(16*16))
        elif config["clip_name"] == "ViT-L-14":
            num_patches = int(config["image_res"] * config["image_res"]/(14*14))
        pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768).float())

        pos_embed = resize_pos_embed(state_dict['visual_encoder.visual.positional_embedding'].unsqueeze(0),
                                                   pos_embed.unsqueeze(0))
        state_dict['visual_encoder.visual.positional_embedding'] = pos_embed

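However, when CLIP is initialized while the MPLUG model is being constructed, i.e. when the CLIP initialization code in AliceMind/models/visual_transformers.py is called, dim=1 of the positional embedding depends on whether the visual encoder is ViT-B-16 or ViT-L-14: it is 768 for B-16 and 1024 for L-14.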

def initialize_clip(config, num_patches=240):
    from models.clip import clip
    if config["clip_name"] == "ViT-B-16":
        clip_model, preprocess = clip.load("ViT-B-16.tar", jit=False)
        num_patches = int(config['image_res']*config['image_res']/(16*16))
        pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768).float())
    elif config["clip_name"] == "ViT-L-14":
        clip_model, preprocess = clip.load("ViT-L-14.tar", jit=False)
        num_patches = int(config['image_res']*config['image_res']/(14*14))
        pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 1024).float())    
    pos_embed.weight = resize_pos_embed(clip_model.visual.positional_embedding.unsqueeze(0), pos_embed.unsqueeze(0))
    clip_model.visual.positional_embedding = pos_embed
    return clip_model, preprocess
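
To make the discrepancy concrete, the shapes produced by the two code paths can be computed without loading any weights. This is a minimal sketch in plain Python: the helper names `caption_script_shape` and `initialize_clip_shape` are hypothetical and only mirror the arithmetic in the two snippets above, and the `image_res` value of 336 is an arbitrary example.

```python
def caption_script_shape(config):
    """Shape of pos_embed as built in caption_mplug.py (width fixed at 768)."""
    patch = {"ViT-B-16": 16, "ViT-L-14": 14}[config["clip_name"]]
    num_patches = int(config["image_res"] * config["image_res"] / (patch * patch))
    return (num_patches + 1, 768)


def initialize_clip_shape(config):
    """Shape of pos_embed as built in initialize_clip (width depends on the model)."""
    patch, width = {"ViT-B-16": (16, 768), "ViT-L-14": (14, 1024)}[config["clip_name"]]
    num_patches = int(config["image_res"] * config["image_res"] / (patch * patch))
    return (num_patches + 1, width)


# Compare both code paths for each encoder at an example resolution.
for name in ("ViT-B-16", "ViT-L-14"):
    cfg = {"clip_name": name, "image_res": 336}
    print(name, caption_script_shape(cfg), initialize_clip_shape(cfg))
```

For ViT-B-16 the two shapes agree, while for ViT-L-14 they differ only in the second dimension (768 versus 1024).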

The difference between these two pieces of code causes the following error: RuntimeError: Error(s) in loading state_dict for MPLUG: size mismatch for visual_encoder.visual.positional_embedding: copying a param with shape torch.Size([x, 768]) from checkpoint, the shape in current model is torch.Size([x, 1024]).
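
One possible way to remove the mismatch, offered as a sketch rather than as the maintainers' fix, is to derive the embedding width from `clip_name` in the caption script as well. The table `CLIP_VISUAL_SPECS` and the function `target_pos_embed_shape` are hypothetical names; the patch sizes and widths are the ones that appear in `initialize_clip`.

```python
# Hypothetical lookup: (patch size, embedding width) per CLIP visual encoder,
# using the values that appear in initialize_clip.
CLIP_VISUAL_SPECS = {
    "ViT-B-16": (16, 768),
    "ViT-L-14": (14, 1024),
}


def target_pos_embed_shape(config):
    """Return (num_patches + 1, width) for the configured encoder and resolution."""
    name = config["clip_name"]
    if name not in CLIP_VISUAL_SPECS:
        raise ValueError(f"unsupported clip_name: {name!r}")
    patch, width = CLIP_VISUAL_SPECS[name]
    # Same patch-count arithmetic as the original snippets.
    num_patches = int(config["image_res"] * config["image_res"] / (patch * patch))
    return num_patches + 1, width
```

In caption_mplug.py, the resulting shape would presumably replace the hard-coded `(num_patches + 1, 768)` passed to `torch.zeros`, so that the resized checkpoint tensor has the same width as the parameter created by `initialize_clip`.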

Please correct me if I have made a mistake.