ssocean / AlphX-Code-For-DAR

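Source code of the Alphx team for the Guangdong-Hong Kong-Macao Greater Bay Area (Huangpu) International Algorithm Case Competition, Ancient Book Document Image Recognition and Analysis track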
36 stars 3 forks

Is the U-Net supervision signal 0/1? #1

Closed: helloword12345678 closed this issue 1 year ago

helloword12345678 commented 1 year ago

Hello, and thanks for sharing. How is the supervision signal for the U-Net ordering model constructed?

ssocean commented 1 year ago

Thanks for your interest. Here is the relevant code:

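import json
import os

import cv2
import imgaug.augmenters as iaa
import numpy as np
import torch
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from torch.utils.data import Dataset


def load_cnts_from_json(json_pth):
    '''
    Load the contours stored in the JSON annotation file at the given path.
    :return: [poly1, poly2, ..., polyN], each an (n, 2) int32 array
    '''
    with open(json_pth, encoding='utf-8') as f:
        result = json.load(f)
    cnts = []
    for region in result['shapes']:
        points = region['points']
        # Truncate the coordinates to integers and shape them as (n, 2)
        poly = np.array(points).astype(np.int32).reshape(-1, 2)
        cnts.append(poly)
    return cnts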

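def resize_contour(cnts, ori_size, rst_shape):
    '''
    Rescale annotation coordinates to match a resized image. Resizing the image
    moves the annotations with it, so the contour points are scaled by the same
    height and width ratios. Returns a new array; the input is not modified.
    :param cnts: contour points, treated as (n, 2) pairs of (x, y)
    :param ori_size: (height, width) of the original image
    :param rst_shape: (height, width) of the resized image
    :return: (n, 2) int32 array of rescaled points
    '''
    o_h, o_w = ori_size
    r_h, r_w = rst_shape
    height_ratio = r_h / o_h
    width_ratio = r_w / o_w  # scale factors for height and width
    ratio_mat = np.array([[width_ratio, 0], [0, height_ratio]])
    # (n, 2) points times a 2x2 diagonal scaling matrix
    pts = np.array(cnts).astype(np.int32).reshape(-1, 2)
    return (pts @ ratio_mat).astype(np.int32)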

def preprocess(img, mask):
    """
    Augment an image and its mask together while loading the dataset.
    Not intended to be called from outside this module.
    :param img: input image
    :param mask: label mask aligned with img
    :return: the augmented image and the augmented mask
    """
    mask = mask.astype(np.uint16)
    segmap = SegmentationMapsOnImage(mask, shape=img.shape)
    seq = iaa.OneOf([
        iaa.Affine(scale={"x": (0.7, 1.3), "y": (0.7, 1.3)}, translate_percent=(0, 0.2), rotate=(-25, 25), cval=0,
                   mode='constant'),  # affine transform
        iaa.ShearX((-15, 15)),
        iaa.CropAndPad(percent=(-0.2, 0.5), pad_mode='constant', pad_cval=0, keep_size=True),  # crop/pad, then resize back
        iaa.PiecewiseAffine(scale=(0, 0.05), nb_rows=4, nb_cols=4, cval=0),  # random deformation driven by control points
    ])
    img_aug, msk_aug = seq(image=np.array(img, dtype=np.uint8), segmentation_maps=segmap)
    return img_aug, msk_aug.get_arr()


class OrderDataset(Dataset):
    '''
    Dataset loader.
    '''

    def __init__(self, base_dir):
        '''
        :param base_dir: folder containing the images and their labelme-format annotations
        '''
        self.base_dir = base_dir
        # get_files_pth is a helper that is not shown in this snippet
        self.imgs_pth = get_files_pth(base_dir, 'jpg') + get_files_pth(base_dir, 'png')
        self.jsons_pth = get_files_pth(base_dir, 'json')
        assert len(self.imgs_pth) == len(self.jsons_pth), 'Num not the same'

    def __len__(self):
        '''
        :return: number of samples in the dataset
        '''
        return len(self.imgs_pth)

    def __getitem__(self, item):
        '''
        Return the image at index item together with its label mask.
        :param item: sample index supplied by the framework; do not change
        :return: dict {'image': FloatTensor, the input image; 'mask': FloatTensor, ground truth with a distinct real value per region; 'M': result of thresholding the mask}
        '''

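        img_pth = self.imgs_pth[item]
        # get_filename_from_pth is a helper that is not shown in this snippet
        img_name = get_filename_from_pth(img_pth, False)
        json_pth = os.path.join(self.base_dir, img_name + '.json')

        polys = load_cnts_from_json(json_pth)
        img = cv2.imread(img_pth, 0)
        # Mind the dtype: a uint8 mask could hold at most 255 distinct region labels.
        # mask_size is not defined in this snippet.
        mask = np.zeros(mask_size, dtype=np.float32)
        # Draw the ground truth (mask)
        for i, poly in enumerate(polys):
            poly = resize_contour(poly, img.shape[:2], mask.shape[:2])
            # Labels start at i + 2 so that thresholding at 1 turns every region into foreground
            mask = cv2.drawContours(mask, [poly.reshape((-1, 2))], -1, i + 2, thickness=-1)

        # We initially reasoned that a smaller pixel value means a higher-priority region.
        # But the background is 0, which would make the background the highest priority,
        # and that makes no sense. So we first tried masking the network output,
        # i.e. pred = pred * M, and computing a smooth L1 loss between the masked pred
        # and the ground-truth mask. In practice this worked poorly; using L1 directly
        # was slightly better.
        # Threshold the float mask itself: casting to uint8 first would wrap labels above 255.
        _, M = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

        # Normalize M to {0, 1}
        M = M / 255

        # Normalize mask to [0, 1]
        max_val = len(polys) + 1
        mask = mask / max_val

        # As written, the returned 'image' is the binary mask M, not the grayscale image read above
        img = M.astype(np.uint8)
        mask = mask.astype(np.float32)

        return {
            'image': torch.from_numpy(img).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask).type(torch.FloatTensor),
            'M': torch.from_numpy(M.astype(np.uint8)).type(torch.FloatTensor)
        }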
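
The comment inside `__getitem__` mentions two loss variants: a smooth L1 loss on the masked prediction `pred * M`, and L1 used directly. The following is only a rough NumPy sketch of that comparison, not the training code from the repository. The helper names `smooth_l1` and `l1`, the `beta` parameter, and the toy values are made up for illustration, and reading "L1 directly" as a plain, unmasked L1 loss is an assumption.

```python
import numpy as np


def smooth_l1(pred, target, beta=1.0):
    """Element-wise smooth L1 loss, averaged over all pixels."""
    diff = np.abs(pred - target)
    loss = np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return float(loss.mean())


def l1(pred, target):
    """Plain L1 loss, averaged over all pixels."""
    return float(np.abs(pred - target).mean())


# Toy 2x3 target: two regions (2/3 and 1.0) next to a zero background column.
mask = np.array([[0.0, 2 / 3, 2 / 3],
                 [0.0, 1.0, 1.0]], dtype=np.float32)
M = (mask > 0).astype(np.float32)        # 1 inside regions, 0 on the background
pred = np.full_like(mask, 0.5)           # hypothetical network output

masked_loss = smooth_l1(pred * M, mask)  # background errors are zeroed out by M
plain_loss = l1(pred, mask)              # background errors still count
print(masked_loss, plain_loss)
```

With the mask applied, the network is never penalized for what it predicts on the background, whereas the plain loss also pushes background pixels toward 0.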
ssocean commented 1 year ago

In short, render the i-th region with pixel value i, then normalize the whole mask to the 0-1 range.
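
To make the label values concrete, here is a small sketch of the same idea. It is a simplification rather than the repository code: it fills axis-aligned boxes with NumPy slicing instead of arbitrary polygons with `cv2.drawContours`, and `build_order_mask` is a made-up name. The `i + 2` offset and the division by `len(polys) + 1` follow the code posted above.

```python
import numpy as np


def build_order_mask(boxes, size):
    """Rasterize boxes in reading order; return (normalized mask, binary mask M).

    boxes: list of (x0, y0, x1, y1) with exclusive upper bounds (a simplification;
           the posted code fills polygons with cv2.drawContours).
    size:  (height, width) of the output mask.
    """
    mask = np.zeros(size, dtype=np.float32)
    for i, (x0, y0, x1, y1) in enumerate(boxes):
        mask[y0:y1, x0:x1] = i + 2     # raw label: 2 for the first region, 3 for the next, ...
    M = (mask > 1).astype(np.float32)  # foreground = any region pixel
    mask = mask / (len(boxes) + 1)     # the last region maps to exactly 1.0
    return mask, M


boxes = [(0, 0, 2, 2), (3, 0, 5, 2), (0, 3, 5, 4)]  # three regions in reading order
mask, M = build_order_mask(boxes, (4, 5))
print(np.unique(mask))  # background 0.0, then 0.5, 0.75, 1.0
```

So the target is not a 0/1 map: every region gets its own value in (0, 1], increasing with its position in the order, and only `M` is binary.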

helloword12345678 commented 1 year ago

How is a region defined here? Is it a single polygonal text box, or a set of text boxes?

ssocean commented 1 year ago

A single polygonal text box.

image_957 image_958 image_968 image_917 image_925 image_937 image_944 image_946
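
In terms of the annotation file, this means one entry in `shapes` per region. A minimal sketch with a made-up annotation (the coordinates and the `label` values are invented, and only the `shapes` and `points` keys are taken from `load_cnts_from_json` above):

```python
import json

# Hypothetical labelme-style annotation with two polygonal text boxes.
annotation = json.loads("""
{
  "shapes": [
    {"label": "text", "points": [[10, 10], [60, 10], [60, 30], [10, 30]]},
    {"label": "text", "points": [[10, 40], [60, 40], [60, 60], [10, 60]]}
  ]
}
""")

# One region per shape; the position in the list is the rank that
# enumerate() turns into the label value in the dataset code.
regions = [shape["points"] for shape in annotation["shapes"]]
for rank, polygon in enumerate(regions):
    print(rank, len(polygon), "vertices")
```

Because the dataset code enumerates the polygons in file order, the order of the entries in `shapes` is what determines each region's target value.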