whai362 / PSENet

Official Pytorch implementations of PSENet.
Apache License 2.0
1.17k stars 344 forks source link

测试没有图像输出 #175

Open xianyu-123 opened 3 years ago

xianyu-123 commented 3 years ago

测试后只生成了位置文本,没有生成结果图像,这个应该怎么修改呢?

lrfighting commented 3 years ago

你好,请问他的测试数据集放哪

xianyu-123 commented 3 years ago

用的他的命令行操作的。我最后跑通的是TensorFlow版本的那个

lrfighting commented 3 years ago

用他的命令行操作的。我最后跑通是TensorFlow版本的那个人

他这个跑不通吗,我测试就失败了 用他的命令行报错没有那个文件,我也找不到他那个文件在哪定义了,

xianyu-123 commented 3 years ago

用他的命令行操作的。我最后跑通是TensorFlow版本的那个人

他这个跑不通吗,我测试就失败了 用他的命令行报错没有那个文件,我也找不到他那个文件在哪定义了,

在dataset文件夹下psenet文件夹里面,我测试的是IC15,然后打开这个文件,在里面修改就可以了。

lrfighting commented 3 years ago

用他的命令行操作的。我最后跑通是TensorFlow版本的那个人

他跑不通吗,我测试就失败了,用他的这个命令行报错没有那个文件,我也找不到他那个文件在哪定义了,

在dataset文件夹下psenet文件夹里面,我测试的是IC15,然后打开这个文件,在里面修改就可以了。 好的,谢谢 ,

haode

lrfighting commented 3 years ago

用他的命令行操作的。我最后跑通是TensorFlow版本的那个人

他这个跑不通吗,我测试就失败了 用他的命令行报错没有那个文件,我也找不到他那个文件在哪定义了,

在dataset文件夹下psenet文件夹里面,我测试的是IC15,然后打开这个文件,在里面修改就可以了。

大神你好,请问你用他的训练命令了吗,训练有么有报错呢,我的出现一个C就终止了,这是什么问题呢

HCMY commented 2 years ago

测试后只生成了位置文本,没有生成结果图像,这个应该怎么修改呢?

需要你自己把模型的预测结果,也就是bbox处理,放到图像上


import numpy as np
from PIL import Image

import torchvision.transforms as transforms
import torch
from mmcv import Config
from .models import build_model
from .models.utils import fuse_module
from .dataset.psenet import psenet_ctw

class OcrTextDetector(object):
    def __init__(self, ckpt_path, config_path, device='cpu'):
        self.ckpt_path = ckpt_path
        self.cfg_path = config_path
        self.device = device

        self.model = None
        self.cfg = None

    def build_model(self):
        cfg = Config.fromfile(self.cfg_path)

        for d in [cfg, cfg.data.test]:
            d.update(dict(
                report_speed=False
            ))

        self.cfg = cfg

        model = build_model(self.cfg.model)
        model.to(self.device)

        checkpoint = torch.load(self.ckpt_path, map_location=self.device)
        d = dict()
        for key,value in checkpoint['state_dict'].items():
            tmp = key[7:]
            d[tmp] = value
        model.load_state_dict(d)
        model = fuse_module(model)

        model.eval()

        self.model = model

        return self

    def preprocess_img(self, img_path):
        img = psenet_ctw.get_img(img_path=img_path, read_type='pil')

        img_meta = dict(
            org_img_size=[np.array(img.shape[:2])]
        )

        img = psenet_ctw.scale_aligned_short(img)
        img_meta.update(dict(
            img_size=[np.array(img.shape[:2])]
        ))

        img = Image.fromarray(img)
        img = img.convert('RGB')
        img = transforms.ToTensor()(img)
        img = transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])(img)

        return img, img_meta

    def predict(self, img_path):
        img, img_meta = self.preprocess_img(img_path)
        img = torch.unsqueeze(img, 0)
        outputs = self.model(img, img_metas=img_meta, cfg=self.cfg)

        return outputs

if __name__ == "__main__":
    ckpt_path = './checkpoints/psenet_r50_ctw_finetune/checkpoint.pth'
    cfg = './config/psenet/psenet_r50_ctw_finetune.py'
    img_path = '../../dataset/ctw1500/train_images/0002.jpg'
    ocr = OcrTextDetector(ckpt_path, cfg).build_model().predict(img_path)
def draw_bbox(bboxs,img):
    bboxs_res = []
    for bbox in bboxs:
        bbox = np.reshape(bbox,(4,2))
        cv2.drawContours(img, [bbox],-1, (0, 255, 0), 2)
        bboxs_res.append(bbox)
    return bboxs_res, img

detector = OcrTextDetector(ckpt_path, cfg).build_model().predict(img_path)
box = detector.predict(img_path)
 img = cv2.imread(img_path)
 bboxs_res, box_img = draw_bbox(box['bboxes'], img)
 plt.imshow(box_img )
xianyu-123 commented 2 years ago

您好,您的来信我已经收到,感谢您的来信!谢谢

Xiuxiu21 commented 1 year ago

请问可以提供预训练模型吗