dailenson / SDT

This repository is the official implementation of Disentangling Writer and Character Styles for Handwriting Generation (CVPR 2023)
MIT License

Preprocessing my own images #18

Closed: keal8180 closed this issue 1 year ago

keal8180 commented 1 year ago

While trying to test with my own images, I noticed that the input files used during testing are in pkl format. How are png images converted into pkl files?

dailenson commented 1 year ago

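Hello. As this code shows, the samples used as style references are stored in pickle files, so the characters you write yourself should also be saved in pickle format. Specifically, read the png files with cv2 first, then persist them with pickle. Here is a sample conversion script: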

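import glob
import os
import pickle

import cv2

your_path = 'path/to/your/style_images'  # placeholder: folder holding your handwritten character images
your_name = 'my_style'                    # placeholder: base name of the output pickle file

character_list = glob.glob(os.path.join(your_path, '*.jpg'))  # use '*.png' for png images
characters = []
# read style characters from the specific writer
for img_path in character_list:
    char_img = cv2.imread(img_path, 0)  # flag 0 loads the image as single-channel grayscale
    characters.append({'img': char_img})
# write cache into pickle once, after all images have been read
target_path = your_name + '.pkl'  # saved in the current directory; move it to your style-sample folder
with open(target_path, 'wb') as f:
    pickle.dump(characters, f)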
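Before dropping the pickle into the style folder, a quick sanity check can catch unreadable files (for which `cv2.imread` returns `None`) or wrongly sized images. This is a sketch, not part of SDT: `check_style_pickle` is a hypothetical helper, it assumes the list-of-dicts layout with an `'img'` key used in the snippet above, and the 64×64 expectation is taken from a later comment in this thread.

```python
import pickle

import numpy as np


def check_style_pickle(path, expected_shape=(64, 64)):
    """Hypothetical helper: load a style pickle and list any problems found.

    Assumes the layout used in this thread: a list of dicts, each with an
    'img' key holding a 2-D grayscale array.
    """
    with open(path, 'rb') as f:
        characters = pickle.load(f)
    if not isinstance(characters, list) or not characters:
        return ['expected a non-empty list of dicts']
    problems = []
    for i, entry in enumerate(characters):
        img = entry.get('img') if isinstance(entry, dict) else None
        if img is None:
            # cv2.imread returns None for unreadable files, so bad paths end up here too
            problems.append('entry %d: missing or empty "img"' % i)
        elif not isinstance(img, np.ndarray) or img.ndim != 2:
            problems.append('entry %d: expected a 2-D grayscale array' % i)
        elif img.shape != expected_shape:
            problems.append('entry %d: shape %s, expected %s' % (i, img.shape, expected_shape))
    return problems
```

Calling `check_style_pickle` on your pickle file and getting an empty list back means none of these checks failed.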
keal8180 commented 1 year ago

Thank you very much for your help.

keal8180 commented 1 year ago

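When testing with my own images, besides converting them to pkl format, putting them in the test_style_samples folder, and modifying writer_dict, is there anything else I need to change? Do I also need to modify the lmdb files under test?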

dailenson commented 1 year ago

You can try running it first and see the results. You shouldn't need to modify the lmdb files.

keal8180 commented 1 year ago

Thanks a lot.

wang-shankun commented 10 months ago

"You can try running it first and see the results. You shouldn't need to modify the lmdb files."

In ScriptDataset, def __getitem__(self, index) fetches fname from the corresponding lmdb. If the file names in writer_dict don't match those in the mdb, the assignment writer_id = self.writer_dict[fname] raises an error.

Getting fname from the lmdb:

with self.lmdb.begin(write=False) as txn:
    data = pickle.loads(txn.get(str(index).encode('utf-8')))
    tag_char, coords, fname = data['tag_char'], data['coordinates'], data['fname']

Looking up the entry in writer_dict by key:

self.writer_dict = self.all_writer['test_writer']
writer_id = self.writer_dict[fname]

For example, if I rewrite writer_dict as "test_writer: {006.pot: 1}", data.mdb still yields other fnames, so the lookup fails.
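The failure described here is an ordinary `KeyError`: every `fname` read from the lmdb must be a key of `writer_dict['test_writer']`. A small sketch for listing the names that would fail, using a hypothetical helper `find_missing_fnames`; in practice the fnames would come from iterating the lmdb as in the snippet above, and the values below are illustrative only.

```python
def find_missing_fnames(writer_dict, fnames):
    """Hypothetical helper: return the fnames that writer_dict cannot resolve.

    writer_dict maps a writer file name (e.g. '006.pot') to an integer writer id,
    mirroring writer_id = self.writer_dict[fname] in ScriptDataset.__getitem__.
    Each missing name is reported once, in first-seen order.
    """
    missing = []
    seen = set()
    for fname in fnames:
        if fname not in writer_dict and fname not in seen:
            missing.append(fname)
            seen.add(fname)
    return missing


# writer_dict rewritten to a single writer, as in the example above
test_writer = {'006.pot': 1}
# fnames as they might be read back from data.mdb (illustrative values only)
print(find_missing_fnames(test_writer, ['001.pot', '006.pot', '001.pot', '007.pot']))  # → ['001.pot', '007.pot']
```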

When deserializing the Chinese and English data.mdb files, I noticed their structures differ somewhat. Could you share more information about the mdb format? Thanks.

dailenson commented 9 months ago

If you just want to generate characters in your own style, fixing the pkl is all you need. You can ignore the lmdb; it isn't used for that.

f1cey commented 8 months ago

Do I need to tell it which characters I wrote? If so, how?

dailenson commented 8 months ago

"Do I need to tell it which characters I wrote? If so, how?"

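No extra work is needed on the user's side. Our code ships with a default set of 3755 characters; running it generates those 3755 characters in each style.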
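Because the characters to generate are taken from the content file, any custom target string can only use characters that have an entry there. The `predict.py` script later in this thread looks them up as `content[char]` from `./data/CASIA_CHINESE/Chinese_content.pkl`. The sketch below assumes that pickle is a dict keyed by single characters; `uncovered_chars` and `load_content` are hypothetical helpers that list the characters which would raise a `KeyError`.

```python
import pickle


def uncovered_chars(content, text):
    """Hypothetical helper: characters in `text` that have no content image.

    Assumes `content` is a dict keyed by single characters, as the
    predict.py script in this thread uses it (content[char]).
    Duplicates are reported once, in order of first appearance.
    """
    return [ch for ch in dict.fromkeys(text) if ch not in content]


def load_content(path):
    # e.g. './data/CASIA_CHINESE/Chinese_content.pkl' in the SDT data layout
    with open(path, 'rb') as f:
        return pickle.load(f)


if __name__ == '__main__':
    toy_content = {'我': None, '有': None}  # toy stand-in; the real values are images
    print(uncovered_chars(toy_content, '我有我你'))  # → ['你']
```

With the real file, `len(load_content(path))` shows how many characters the default set contains.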

Cloud9High5 commented 8 months ago

I put the modified pkl into test_style_sample, but I still can't generate my own font. Could you write a step-by-step tutorial? Thanks.

dailenson commented 8 months ago

"I put the modified pkl into test_style_sample, but I still can't generate my own font. Could you write a step-by-step tutorial? Thanks."

Thanks for your interest! I'll post a tutorial on generating your own font as a pinned issue soon.

cactusgame commented 8 months ago

"I put the modified pkl into test_style_sample, but I still can't generate my own font. Could you write a step-by-step tutorial? Thanks."

You can use this script directly. Just make sure the style images are 64×64 jpg files, and change the output directory to suit your setup.
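Images that are not yet 64×64 jpg files can be converted before running the script. A minimal sketch with Pillow, assuming one character per image on a light background: each image is padded to a square with white (so the aspect ratio is kept), resized to 64×64, and saved as a grayscale jpg. The helper names and the padding choice are assumptions, not part of SDT.

```python
import glob
import os

from PIL import Image


def to_64x64_gray(src_path, dst_path, size=64, fill=255):
    """Hypothetical helper: pad to a square, resize to size x size, save as grayscale jpg.

    fill=255 assumes dark ink on a white background; use fill=0 for the opposite.
    """
    img = Image.open(src_path).convert('L')
    w, h = img.size
    side = max(w, h)
    canvas = Image.new('L', (side, side), color=fill)
    # center the original image on the square canvas
    canvas.paste(img, ((side - w) // 2, (side - h) // 2))
    canvas = canvas.resize((size, size), Image.BICUBIC)
    canvas.save(dst_path, format='JPEG')
    return canvas


def convert_folder(src_dir, dst_dir, pattern='*.png'):
    """Convert every matching image in src_dir into a 64x64 jpg in dst_dir."""
    os.makedirs(dst_dir, exist_ok=True)
    for path in glob.glob(os.path.join(src_dir, pattern)):
        name = os.path.splitext(os.path.basename(path))[0] + '.jpg'
        to_64x64_gray(path, os.path.join(dst_dir, name))
```

For example, `convert_folder('my_raw_chars', './data/CASIA_CHINESE/styles/manual/')` would fill the style directory the script reads, where `my_raw_chars` is a placeholder for your own folder.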

python predict.py --pretrained_model model_zoo/checkpoint-iter199999.pth --style_sample_size 20
import argparse
import os
from parse_config import cfg, cfg_from_file, assert_and_infer_cfg
import torch
import glob
import random
import pickle
import time
from PIL import Image
import numpy as np
from models.model import SDT_Generator
from utils.util import writeCache, dxdynp_to_list, coords_render

device = "cpu"
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('Using GPU for training')
else:
    device = torch.device('cpu')
    print('Using CPU for training')

def load_model():
    """
     加载模型
    :return:
    """
    model = SDT_Generator(num_encoder_layers=cfg.MODEL.ENCODER_LAYERS,
                          num_head_layers=cfg.MODEL.NUM_HEAD_LAYERS,
                          wri_dec_layers=cfg.MODEL.WRI_DEC_LAYERS,
                          gly_dec_layers=cfg.MODEL.GLY_DEC_LAYERS).to(device)
    if len(opt.pretrained_model) > 0:
        model_weight = torch.load(opt.pretrained_model, map_location=device)
        model.load_state_dict(model_weight)
        print('load pretrained model from {}'.format(opt.pretrained_model))
    else:
        raise IOError('input the correct checkpoint path')
    model.eval()
    return model

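def load_styles(sample_dir, pic_format=".jpg", sample_count=15):
    """
    :param sample_dir: directory containing the style images
    :param pic_format: file extension of the style images
    :param sample_count: number of style images to sample
    :return: array of shape [N, 1, H, W] with values in [0, 1]
    """
    # TODO: add labels later
    character_list = glob.glob(os.path.join(sample_dir, '*' + pic_format))
    # random.sample below fails if there are fewer images than requested
    if len(character_list) < sample_count:
        raise ValueError('found %d style images in %s, need at least %d'
                         % (len(character_list), sample_dir, sample_count))
    all_character_images = []
    for character in character_list:
        image = Image.open(character).convert("L")  # grayscale
        image_array = np.array(image)
        all_character_images.append(image_array)

    # randomly sample the style reference images
    sampled_images = []
    random_indexs = random.sample(range(len(all_character_images)), sample_count)
    for idx in random_indexs:
        tmp_img = all_character_images[idx]
        tmp_img = tmp_img / 255.
        sampled_images.append(tmp_img)

    sampled_images = np.expand_dims(sampled_images, 1)  # [N, C, H, W], C=1
    return sampled_images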

def infer(model, char, style_images, content):
    """
    Run inference for one character and save the result as a jpg.
    :param model: the loaded SDT_Generator
    :param char: the character to generate
    :param style_images: style reference images, shape [N, 1, H, W], values in [0, 1]
    :param content: dict mapping each character to its content (glyph) image
    :return:
    """
    with torch.no_grad():
        # look up the content (glyph) image of the target character
        char_img = content[char]
        char_img = char_img / 255.

        style_images_tensor = torch.Tensor(style_images)
        char_img_tensor = torch.Tensor(char_img)
        char_img_tensor = char_img_tensor.unsqueeze(0)  # [C, H, W]

        # add the batch dimension and move the inputs to the same device as the model
        batched_style_images_tensor = torch.unsqueeze(style_images_tensor, dim=0).to(device)
        batched_char_img_tensor = torch.unsqueeze(char_img_tensor, dim=0).to(device)

        # predict the pen trajectory
        prediction = model.inference(batched_style_images_tensor, batched_char_img_tensor, 120)

        # render
        batch_size = 1
        SOS = torch.tensor(batch_size * [[0, 0, 1, 0, 0]]).unsqueeze(1).to(prediction)
        preds = torch.cat((SOS, prediction), 1)  # add the SOS token like GT
        preds = preds.detach().cpu().numpy()

        for pred in preds:
            # intends to blur the boundaries of each sample to fit the actual using situations,
            # as suggested in 'Deep imitator: Handwriting calligraphy imitation via deep attention networks'
            sk_pil = coords_render(pred, split=True, width=64, height=64, thickness=1, board=0)
            save_path = os.path.join(opt.save_dir, 'predict', char + '.jpg')
            try:
                sk_pil.save(save_path)
            except Exception as e:
                print('error. %s, %s' % (save_path, char))
                print(e)

def main(opt):
    cfg_from_file(opt.cfg_file)
    assert_and_infer_cfg()

    model = load_model()
    # directory holding your own 64x64 style images
    style_images = load_styles("./data/CASIA_CHINESE/styles/manual/",
                               sample_count=int(opt.style_sample_size))
    # load the content (glyph) images once instead of on every call
    with open("./data/CASIA_CHINESE/Chinese_content.pkl", 'rb') as f:
        content = pickle.load(f)
    # make sure the output directory exists before saving
    os.makedirs(os.path.join(opt.save_dir, 'predict'), exist_ok=True)

    start = time.time()
    # the characters to generate; replace with your own text
    for char in "我有一些文字想保存你想要福报么":
        infer(model, char, style_images, content)
    print("infer cost time: ", time.time() - start)

# python predict.py --pretrained_model model_zoo/checkpoint-iter199999.pth --style_sample_size 20 --dir Generated/Chinese
if __name__ == '__main__':
    # parse input arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', dest='cfg_file', default='configs/CHINESE_CASIA.yml',
                        help='Config file for training (and optionally testing)')
    parser.add_argument('--dir', dest='save_dir', default='Generated/Chinese',
                        help='target dir for storing the generated characters')
    parser.add_argument('--pretrained_model', dest='pretrained_model', default='', required=True,
                        help='path to the pretrained checkpoint')
    parser.add_argument('--style_sample_size', dest='style_sample_size', type=int, default=15, required=False,
                        help='number of style reference images sampled when predicting')
    opt = parser.parse_args()
    main(opt)
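For anyone adapting the script, the array shapes it feeds into `model.inference` can be reproduced with NumPy alone. The sketch below mirrors the steps in `load_styles` and `infer` (scale to [0, 1], add a channel axis, then a batch axis); `prepare_inputs` is a hypothetical name, and the shapes follow from the script's own `expand_dims`/`unsqueeze` calls rather than from SDT documentation.

```python
import numpy as np


def prepare_inputs(style_list, char_img):
    """Hypothetical helper mirroring load_styles() + infer() in the script above.

    style_list: list of N grayscale uint8 arrays of shape (H, W)
    char_img:   grayscale uint8 content image of shape (H, W)
    Returns (style_batch, char_batch) as float32 arrays in [0, 1]
    (float32 because torch.Tensor converts to that dtype).
    """
    styles = np.stack([img / 255. for img in style_list])  # (N, H, W)
    styles = np.expand_dims(styles, 1)   # (N, 1, H, W), as in load_styles
    styles = np.expand_dims(styles, 0)   # (1, N, 1, H, W), like torch.unsqueeze(dim=0)
    char = (char_img / 255.)[np.newaxis]  # (1, H, W), like unsqueeze(0)
    char = char[np.newaxis]               # (1, 1, H, W), like torch.unsqueeze(dim=0)
    return styles.astype(np.float32), char.astype(np.float32)


styles, char = prepare_inputs([np.zeros((64, 64), np.uint8)] * 20, np.full((64, 64), 255, np.uint8))
print(styles.shape, char.shape)  # → (1, 20, 1, 64, 64) (1, 1, 64, 64)
```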