InternLM / InternLM-XComposer

InternLM-XComposer2 is a groundbreaking vision-language large model (VLLM) excelling in free-form text-image composition and comprehension.

4khd-7b: error during multi-image SFT #311

Open zws-2019 opened 1 month ago

zws-2019 commented 1 month ago

I passed in two images, shape: torch.Size([2, 3, 1680, 1008])

When execution reaches: self.vit([image], self.plora_glb_GN, self.plora_sub_GN)

it fails with: RuntimeError: shape '[1, 3, 5, 336, 3, 336]' is invalid for input of size 10160640

A single image works fine; the error only appears with two images.
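For reference, a minimal sketch of the shape arithmetic behind this error (variable names here are illustrative, not the actual ones in build_mlp.py):

import torch

imgs = torch.zeros(2, 3, 1680, 1008)   # two stacked 4KHD images, as above
_, _, H, W = imgs.shape                # H = 5*336, W = 3*336

# The crop step hard-codes a leading dimension of 1 (one image), so the target
# shape holds 1*3*5*336*3*336 = 5,080,320 elements, while the input holds
# 2*3*1680*1008 = 10,160,640 -- exactly twice as many, hence the RuntimeError.
try:
    imgs.reshape(1, 3, H // 336, 336, W // 336, 336)
except RuntimeError as e:
    print(e)   # shape '[1, 3, 5, 336, 3, 336]' is invalid for input of size 10160640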

plmsmile commented 1 month ago

Same here. sub_img = img.reshape(1,3,H//336,336,W//336,336).permute(0,2,4,1,3,5).reshape(-1,3,336,336).contiguous()

RuntimeError: shape '[1, 3, 3, 336, 4, 336]' is invalid for input of size 8128512

plmsmile commented 1 month ago

Also, when the images in a multi-image sample have different shapes, they have to be resized to a common shape first. I modified Sample_dataset in data_mix.py to force all images in a sample to one shape.

Even then it still failed in build_mlp.py, so I also changed the first dimension of the sub_image reshape to cnt (1 for a single image, the number of images otherwise), after which training ran normally (rough sketch below):

sub_img = img.reshape(cnt,3,H//336,336,W//336,336).permute(0,2,4,1,3,5).reshape(-1,3,336,336).contiguous()
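A rough sketch of that change in context, assuming img arrives as a stacked tensor of shape (cnt, 3, H, W) after the resize above (illustrative only, not the exact code in build_mlp.py):

import torch

img = torch.zeros(2, 3, 1680, 1008)    # two images resized to a common shape
cnt, _, H, W = img.shape               # cnt = 1 for a single image, N for multi-image
assert H % 336 == 0 and W % 336 == 0   # images are already padded to multiples of 336

# Cut every image in the batch into 336x336 crops instead of assuming one image.
sub_img = (img.reshape(cnt, 3, H // 336, 336, W // 336, 336)
              .permute(0, 2, 4, 1, 3, 5)
              .reshape(-1, 3, 336, 336)
              .contiguous())
print(sub_img.shape)                   # torch.Size([30, 3, 336, 336]): 2 images * 5 * 3 crops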

zws-2019 commented 1 month ago

With the fix quoted above it does seem to run, but the 4KHD model's processing logic doesn't look like it supports multiple images. For example, here only the first image_feature is used as glb_img, so with multiple images the logic goes wrong:

        for [h, w] in shapes:
            B_ = h*w
            glb_img = image_features[:1] ### 1, N, C
            glb_img = glb_img.reshape(1,H,H,C).reshape(1,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(1,H//2,H//2,4*C).contiguous()
            temp_glb_GN = sub_GN.repeat(1, H//2, 1, 1)
            glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1,-1,4*C)

            sub_img = image_features[1:1+B_] ### ?, N, C
            sub_img = sub_img.reshape(B_,H,H,C).reshape(B_,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(B_,-1,4*C).contiguous()
            sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute(0,1,3,2,4,5).reshape(1,h*12,w*12,4*C)
            temp_sub_GN = sub_GN.repeat(1, h*12, 1, 1)
            sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1,-1,4*C)

            output_imgs.append(torch.cat([glb_img, glb_GN, sub_img], dim=1))
            temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
            assert temp_len == output_imgs[-1].shape[1]
            output_len.append(temp_len)

            image_features = image_features[1+h*w:]
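To make the bookkeeping concrete, a small worked example (grid sizes taken from the 1680x1008 input above; the constants are read off the loop itself, not from other parts of build_mlp.py):

h, w = 1680 // 336, 1008 // 336                    # 5 x 3 crop grid for one image
features_consumed = 1 + h * w                      # 1 global feature + 15 sub-crop features = 16
temp_len = (h * w + 1) * 144 + 1 + (h + 1) * 12    # 16*144 + 1 + 72 = 2377 output tokens
print(features_consumed, temp_len)

# Each iteration takes image_features[:1] as the global feature and then drops
# 1 + h*w entries, so multiple images only line up if the ViT output is ordered
# per image as [glb_0, sub_0 ..., glb_1, sub_1 ...], which the single-image
# path above does not appear to guarantee.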