whai362 / pan_pp.pytorch

Official implementations of PSENet, PAN and PAN++.
Apache License 2.0

About test.py #95

Open · huahuabai opened this issue 2 years ago

huahuabai commented 2 years ago

Hello, author. When I run test.py with the command you provided, I get the error shown in the attached image. How can I fix it?

lyx-0213 commented 2 years ago

Hi, were you able to solve this issue?

Ycxyue commented 2 years ago

The img_meta fields are written in the prepare_test_data function of dataset/pan_ctw.py. You only need to add the img_name and img_path fields there:

def prepare_test_data(self, index):
    # Method of the CTW dataset class in dataset/pan_ctw.py; it relies on that
    # module's existing imports/helpers (np, Image, transforms, get_img, scale_aligned_short).
    img_path = self.img_paths[index]
    # print('img_path:', img_path)
    img_name = img_path.split('/')[-1]  # changed: extract the image file name
    img = get_img(img_path, self.read_type)
    img_meta = dict(org_img_size=np.array(img.shape[:2]))

    img = scale_aligned_short(img, self.short_size)
    img_meta.update(dict(img_size=np.array(img.shape[:2])))
    img_meta.update(dict(img_name=img_name))  # added
    img_meta.update(dict(img_path=img_path))  # added

    img = Image.fromarray(img)
    img = img.convert('RGB')
    img = transforms.ToTensor()(img)
    img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])(img)

    data = dict(imgs=img, img_metas=img_meta)

    return data

After this change, test.py runs through successfully.
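To illustrate why the two extra fields matter, here is a minimal, framework-free sketch. It assumes that the test loop looks up `img_meta['img_name']`, which this fix suggests but the thread does not show (test.py itself and the original error screenshot are not included here). `build_img_meta`, `result_stem`, and the example path are hypothetical names made up for illustration, not part of pan_pp.pytorch, and plain tuples stand in for `np.array(img.shape[:2])`.

```python
import os


def build_img_meta(img_path, org_img_size, img_size):
    """Hypothetical stand-in for the metadata part of the patched
    prepare_test_data (tuples replace np.array so no numpy/PIL/torch is needed)."""
    img_meta = dict(org_img_size=org_img_size)  # size before scale_aligned_short
    img_meta.update(dict(img_size=img_size))    # size after scale_aligned_short
    # The two fields added by the fix in this thread.
    img_meta.update(dict(img_name=img_path.split('/')[-1]))
    img_meta.update(dict(img_path=img_path))
    return img_meta


def result_stem(img_meta):
    """Hypothetical consumer playing the role of test.py: ASSUMES the test loop
    reads img_meta['img_name'] (KeyError if absent) and drops the extension."""
    return os.path.splitext(img_meta['img_name'])[0]


meta = build_img_meta('data/ctw1500/test/text_image/1001.jpg', (512, 640), (640, 800))
print(result_stem(meta))  # prints "1001"

try:
    # Metadata as built before the fix: no img_name / img_path keys.
    result_stem(dict(org_img_size=(512, 640), img_size=(640, 800)))
except KeyError as err:
    print('missing key:', err)  # prints "missing key: 'img_name'"
```

If test.py reads a key that prepare_test_data never writes, the lookup fails before any result is saved; adding the keys at the point where img_meta is built fixes it for every image in the loader.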

lyx-0213 commented 2 years ago

It worked! Thank! You! So! Much!

huahuabai commented 2 years ago


Much appreciated!