Open zws-2019 opened 1 month ago
同。 sub_img = img.reshape(1,3,H//336,336,W//336,336).permute(0,2,4,1,3,5).reshape(-1,3,336,336).contiguous()
RuntimeError: shape '[1, 3, 3, 336, 4, 336]' is invalid for input of size 8128512
还有多图shape不一致的时候,需要resize到同一个shape才可以。我是修改了data_mix.py里Sample_dataset里对多图做了统一shape。
但还是会在build_mlp.py里出错。然后我又把 sub_image reshape的第一维改成cnt(单图是1,多图就是图片数量),后来就正常运行起来了。
sub_img = img.reshape(cnt,3,H//336,336,W//336,336).permute(0,2,4,1,3,5).reshape(-1,3,336,336).contiguous()
还有多图形状不一致的时候,需要resize到同一个形状才可以。我是修改了data_mix.py里Sample_dataset里对多图做了统一形状。
但还是会在build_mlp.py里出错。然后我又把sub_image reshape的第一维改成cnt(单图是1,多图就是图片数量),后来就正常运行起来了。
sub_img = img.reshape(cnt,3,H//336,336,W//336,336).permute(0,2,4,1,3,5).reshape(-1,3,336,336).contigious()
这样看起来是可以跑通 4khd模型的处理逻辑看起来不支持多图 比如这里只把第一个image_feature作为glb_img,如果我有多图,逻辑就会有问题
for [h, w] in shapes:
B_ = h*w
glb_img = image_features[:1] ### 1, N, C
glb_img = glb_img.reshape(1,H,H,C).reshape(1,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(1,H//2,H//2,4*C).contiguous()
temp_glb_GN = sub_GN.repeat(1, H//2, 1, 1)
glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1,-1,4*C)
sub_img = image_features[1:1+B_] ### ?, N, C
sub_img = sub_img.reshape(B_,H,H,C).reshape(B_,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(B_,-1,4*C).contiguous()
sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute(0,1,3,2,4,5).reshape(1,h*12,w*12,4*C)
temp_sub_GN = sub_GN.repeat(1, h*12, 1, 1)
sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1,-1,4*C)
output_imgs.append(torch.cat([glb_img, glb_GN, sub_img], dim=1))
temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
assert temp_len == output_imgs[-1].shape[1]
output_len.append(temp_len)
image_features = image_features[1+h*w:]
我输入了两张图像,shape: torch.Size([2, 3, 1680, 1008])
当我执行到: self.vit([image], self.plora_glb_GN, self.plora_sub_GN)
报错: RuntimeError: shape '[1, 3, 5, 336, 3, 336]' is invalid for input of size 10160640
用单张图片是不报错,两张时报错