FangShancheng / ABINet

Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition

How do I change the charset path to train on Chinese, and does any other code need to change? #20

Open aishangmaxiaoming opened 3 years ago

FangShancheng commented 3 years ago

Hello. Training a Chinese model requires some additional adaptation on top of this repository. I haven't yet had time to work out exactly which parts the Chinese version needs to change, but besides the charset, at least the data-processing code must be modified. The changes to dataset.py can be made along the lines below:

def strQ2B(ustring):
    """Convert full-width (quanjiao) characters to their half-width (banjiao) equivalents."""
    rstring = ""
    for uchar in ustring:
        inside_code = ord(uchar)
        if inside_code == 12288:  # full-width space -> ASCII space
            inside_code = 32
        elif 65281 <= inside_code <= 65374:  # full-width ASCII range -> half-width
            inside_code -= 65248
        rstring += chr(inside_code)
    return rstring

class ImageDataset(Dataset):
    "`ImageDataset` reads samples from an LMDB database."

    def __init__(self,
                 path:PathOrStr,
                 is_training:bool=True,
                 img_h:int=32,
                 img_w:int=100,
                 max_length:int=25,
                 check_length:bool=True,
                 case_sensitive:bool=False,
                 charset_path:str='data/charset_36.txt',
                 convert_mode:str='RGB',
                 data_aug:bool=True,
                 deteriorate_ratio:float=0.,
                 multiscales:bool=True,
                 one_hot_y:bool=True,
                 return_idx:bool=False,
                 return_raw:bool=False,
                 **kwargs):
        self.path, self.name = Path(path), Path(path).name
        assert self.path.is_dir() and self.path.exists(), f"{path} is not a valid directory."
        self.convert_mode, self.check_length = convert_mode, check_length
        self.img_h, self.img_w = img_h, img_w
        self.max_length, self.one_hot_y = max_length, one_hot_y
        self.return_idx, self.return_raw = return_idx, return_raw
        self.case_sensitive, self.is_training = case_sensitive, is_training
        self.data_aug, self.multiscales = data_aug, multiscales
        self.charset = CharsetMapper(charset_path, max_length=max_length+1)
        self.c = self.charset.num_classes

        self.env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False)
        assert self.env, f'Cannot open LMDB dataset from {path}.'
        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('num-samples'.encode()))

        if self.is_training and self.data_aug:
            self.augment_tfs = transforms.Compose([
                CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
                CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
                CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25)
            ])
        self.totensor = transforms.ToTensor()

    def __len__(self): return self.length

    def _next_image(self, index):
        # Current sample is unusable; fall back to a random one.
        next_index = random.randint(0, len(self) - 1)
        return self.get(next_index)

    def _check_image(self, x, pixels=6):
        if x.size[0] <= pixels or x.size[1] <= pixels: return False
        else: return True

    def resize_multiscales(self, img, borderType=cv2.BORDER_CONSTANT): 
        def _resize_ratio(img, ratio, fix_h=True):  # ratio = height / width of the source image
            if ratio * self.img_w < self.img_h:
                if fix_h: trg_h = self.img_h
                else: trg_h = int(ratio * self.img_w)
                trg_w = self.img_w
            else: trg_h, trg_w = self.img_h, int(self.img_h / ratio)
            img = cv2.resize(img, (trg_w, trg_h))
            pad_h, pad_w = (self.img_h - trg_h) / 2, (self.img_w - trg_w) / 2
            top, bottom = math.ceil(pad_h), math.floor(pad_h)
            left, right = math.ceil(pad_w), math.floor(pad_w)
            img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType)
            return img

        if self.is_training: 
            if random.random() < 0.5:
                base, maxh, maxw = self.img_h, self.img_h, self.img_w
                h, w = random.randint(base, maxh), random.randint(base, maxw)
                return _resize_ratio(img, h/w)
            else: return _resize_ratio(img, img.shape[0] / img.shape[1])  # keep aspect ratio
        else:  return _resize_ratio(img, img.shape[0] / img.shape[1])  # keep aspect ratio

    def resize(self, img):
        if self.multiscales: return self.resize_multiscales(img, cv2.BORDER_REPLICATE)
        else: return cv2.resize(img, (self.img_w, self.img_h))

    def get(self, idx):
        with self.env.begin(write=False) as txn:
            image_key, label_key = f'image-{idx+1:09d}', f'label-{idx+1:09d}'
            try:
                label = str(txn.get(label_key.encode()), 'utf-8')  # label
                label = label.replace(' ', '').replace('\u3000', '')
                label = strQ2B(label)
                # label = re.sub('[^0-9a-zA-Z]+', '', label)
                if self.check_length and self.max_length > 0:
                    if len(label) > self.max_length or len(label) <= 0:
                        #logging.info(f'Long or short text image is found: {self.name}, {idx}, {label}, {len(label)}')
                        return self._next_image(idx)
                label = label[:self.max_length]

                imgbuf = txn.get(image_key.encode())  # image
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin
                    image = PIL.Image.open(buf).convert(self.convert_mode)
                if self.is_training and not self._check_image(image):
                    #logging.info(f'Invalid image is found: {self.name}, {idx}, {label}, {len(label)}')
                    return self._next_image(idx)
            except Exception:
                import traceback
                traceback.print_exc()
                logging.info(f'Corrupted image is found: {self.name}, {idx}')
                return self._next_image(idx)
            return image, label, idx

    def _process_training(self, image):
        if self.data_aug: image = self.augment_tfs(image)
        image = self.resize(np.array(image))
        return image

    def _process_test(self, image):
        return self.resize(np.array(image)) # TODO:move is_training to here

    def __getitem__(self, idx):
        image, text, idx_new = self.get(idx)
        if not self.is_training: assert idx == idx_new, f'idx {idx} != idx_new {idx_new} during testing.'

        if self.is_training: image = self._process_training(image)
        else: image = self._process_test(image)
        if self.return_raw: return image, text
        image = self.totensor(image)

        length = tensor(len(text) + 1).to(dtype=torch.long)  # one for end token
        strict = False if self.is_training else True  # strict charset lookup only at evaluation time
        label = self.charset.get_labels(text, case_sensitive=self.case_sensitive, strict=strict)
        if label is None:
            logging.warning(f'Text contains chars outside the charset; skipping: {self.name}, {idx}, {text}, {len(text)}')
            next_idx = random.randint(0, len(self) - 1)
            return self[next_idx]

        label = tensor(label).to(dtype=torch.long)
        if self.one_hot_y: label = onehot(label, self.charset.num_classes)

        if self.return_idx: y = [label, length, idx_new]
        else: y = [label, length]
        return image, y

class TextDataset(Dataset):
    def __init__(self,
                 path:PathOrStr, 
                 delimiter:str='\t',
                 max_length:int=25, 
                 charset_path:str='data/charset_36.txt', 
                 case_sensitive=False, 
                 one_hot_x=True,
                 one_hot_y=True,
                 is_training=True,
                 smooth_label=False,
                 smooth_factor=0.2,
                 use_sm=False,
                 **kwargs):
        self.path = Path(path)
        self.max_length = max_length
        self.case_sensitive, self.use_sm = case_sensitive, use_sm
        self.smooth_factor, self.smooth_label = smooth_factor, smooth_label
        self.charset = CharsetMapper(charset_path, max_length=max_length+1)
        self.one_hot_x, self.one_hot_y, self.is_training = one_hot_x, one_hot_y, is_training
        if self.is_training and self.use_sm: self.sm = SpellingMutation(charset=self.charset)

        dtype = {'inp': str, 'gt': str}
        self.df = pd.read_csv(self.path, dtype=dtype, delimiter=delimiter, na_filter=False)
        self.inp_col, self.gt_col = 0, 1

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        text_x = self.df.iloc[idx, self.inp_col]
        text_x = text_x.replace(' ', '')
        text_x = strQ2B(text_x)
        text_x = text_x[:self.max_length]

        # text_x = re.sub('[^0-9a-zA-Z]+', '', text_x)
        # if not self.case_sensitive: text_x = text_x.lower()
        if self.is_training and self.use_sm: text_x = self.sm(text_x)

        length_x = tensor(len(text_x) + 1).to(dtype=torch.long)  # one for end token
        strict = False if self.is_training else True
        label_x = self.charset.get_labels(text_x, case_sensitive=self.case_sensitive, strict=strict)
        if label_x is None:
            # Out-of-charset text: resample a random item.
            next_idx = random.randint(0, len(self) - 1)
            return self[next_idx]

        label_x = tensor(label_x)
        if self.one_hot_x:
            label_x = onehot(label_x, self.charset.num_classes)
            if self.is_training and self.smooth_label: 
                label_x = torch.stack([self.prob_smooth_label(l) for l in label_x])
        x = [label_x, length_x]

        # text_y = self.df.iloc[idx, self.gt_col]
        # text_y = text_y.replace(' ', '')
        # text_y = strQ2B(text_y)
        # text_y = text_y[:self.max_length]

        # # text_y = re.sub('[^0-9a-zA-Z]+', '', text_y)
        # # if not self.case_sensitive: text_y = text_y.lower()
        # length_y = tensor(len(text_y) + 1).to(dtype=torch.long)  # one for end token
        # label_y = self.charset.get_labels(text_y, case_sensitive=self.case_sensitive)
        # label_y = tensor(label_y)
        # if self.one_hot_y: label_y = onehot(label_y, self.charset.num_classes)
        # y = [label_y, length_y]

        return x, x  # the ground-truth branch above is commented out, so the input doubles as the target

    def prob_smooth_label(self, one_hot):
        # Stochastic label smoothing: move a random mass delta
        # (at most smooth_factor) from the one-hot target onto uniform noise.
        one_hot = one_hot.float()
        delta = torch.rand([]) * self.smooth_factor
        num_classes = len(one_hot)
        noise = torch.rand(num_classes)
        noise = noise / noise.sum() * delta  # noise now sums to delta
        one_hot = one_hot * (1 - delta) + noise
        return one_hot
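
For the charset itself, the file format matters as much as the path. A minimal sketch of building a Chinese charset file, assuming the same index<TAB>character layout as data/charset_36.txt (corpus.txt and data/charset_chinese.txt below are hypothetical paths):

# Hedged sketch: write a Chinese charset in the `index<TAB>char` format
# that CharsetMapper expects, reusing strQ2B from the snippet above.
# `corpus.txt` and `data/charset_chinese.txt` are hypothetical paths.
chars = set()
with open('corpus.txt', encoding='utf-8') as f:
    for line in f:
        chars.update(strQ2B(line.strip()))  # same normalization as the loader
chars = {c for c in chars if not c.isspace()}  # the loader strips spaces anyway
with open('data/charset_chinese.txt', 'w', encoding='utf-8') as f:
    for i, c in enumerate(sorted(chars)):
        f.write(f'{i}\t{c}\n')

The new file is then passed as charset_path when constructing ImageDataset / TextDataset (or set in the corresponding config entry), with max_length raised if your labels are longer.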
lyc728 commented 2 years ago

Did you try training on Chinese afterwards? How did it work out?

lyc728 commented 2 years ago

[Quoting FangShancheng's reply and the dataset.py snippet above.]

Hi, the changes you describe look identical to the source code you already published. Could you point out the specific modifications?

Gavin-zsr commented 2 years ago

[Quoting FangShancheng's reply and the dataset.py snippet above.]

What does the strict parameter passed to get_labels mean? The get_labels function doesn't take this parameter; how should the function be modified?

strict = False if self.is_training else True
label = self.charset.get_labels(text, case_sensitive=self.case_sensitive, strict=strict)
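
get_labels in the released utils.py indeed has no strict argument, so the snippet above appears to rely on a locally modified CharsetMapper. A minimal sketch of one way to add it, assuming the char_to_label, null_char and null_label attributes of the repo's CharsetMapper: in strict mode (evaluation) any out-of-charset character makes the call return None so the caller skips the sample, while in non-strict mode (training) unknown characters are mapped to the null label.

# Hedged sketch, not the authors' code: a `strict`-aware get_labels for
# CharsetMapper (assumes char_to_label, null_char, null_label, max_length).
def get_labels(self, text, length=None, padding=True,
               case_sensitive=False, strict=False):
    length = length or self.max_length
    if padding:
        text = text + self.null_char * (length - len(text))
    if not case_sensitive:
        text = text.lower()
    if strict and any(c not in self.char_to_label for c in text):
        return None  # caller treats None as "skip this sample"
    # Non-strict: map unknown characters to the null label.
    return [self.char_to_label.get(c, self.null_label) for c in text]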