Closed keal8180 closed 1 year ago
您好,根据这段代码我们可知,用作风格参考的样本是储存在pickle文件中的。因此,您自己写的字符应该存成pickle文件格式。具体来说,先使用cv2读取png文件,再使用pickle持久化存储即可。贴一段转换格式的示例代码如下:
# Convert a folder of handwritten character images into the pickle layout
# written below: a list of {'img': <grayscale image array>} dicts.
# `your_path` (image folder) and `your_name` (output file stem) are
# placeholders the user must define before running this snippet.
character_list = []
for pattern in ('*.jpg', '*.png'):
    character_list.extend(glob.glob(os.path.join(your_path, pattern)))
characters = []
# read style characters from the specific writer
for img_path in character_list:
    char_img = cv2.imread(img_path, 0)  # flag 0 -> load as single-channel grayscale
    if char_img is None:
        # cv2.imread reports an unreadable file by returning None instead of
        # raising; skip it so no None sample ends up in the pickle.
        print('skip unreadable image: %s' % img_path)
        continue
    characters.append({'img': char_img})
# write cache into pickle
target_path = your_name + '.pkl'
with open(target_path, 'wb') as f:
    pickle.dump(characters, f)
非常感谢您的帮助
请问,在测试自己图片的过程中,除了需要将自己的图片转化为pkl格式、放入test_style_samples文件夹中并修改writer_dict之外,还需要修改什么呢?我还需要对test中的lmdb文件进行修改吗?
可以先运行下看下效果,应该不需要修改lmdb文件了
非常感谢
「可以先运行下看下效果,应该不需要修改lmdb文件了」 ScriptDataset中的 `__getitem__(self, index)` 会从对应的lmdb中去取fname,如果writer_dict中的文件名与lmdb中的不匹配,那么执行 `writer_id = self.writer_dict[fname]` 给writer_id赋值的时候就会报错。
从lmdb中获取fname:
with self.lmdb.begin(write=False) as txn: data = pickle.loads(txn.get(str(index).encode('utf-8'))) tag_char, coords, fname = data['tag_char'], data['coordinates'], data['fname']
从writer_dict中根据key查找对应的数据:
self.writer_dict = self.all_writer['test_writer'] writer_id = self.writer_dict[fname]
例如,当改写writer_dict中的“test_writer:{006.pot:1}”时,data.mdb中会读取到其他的fname。
当反序列化读取Chinese和English的data.mdb数据时,发现两者在构造上有些区别。能提供更多关于mdb的信息吗,谢谢。
如果只是想要生成自己风格的文字,把pkl改好就行啦。lmdb不需要理会的,那个用不到
我是不是要告诉他我写的字是哪个字呀,如何告诉呢
我是不是要告诉他我写的字是哪个字呀,如何告诉呢
用户不需要额外处理了。我们的代码中提供了默认的3755个字符内容,运行代码就会生成不同风格的3755个字符。
改好了pkl放到test_style_sample还是没法生成自己的字体,可以出一个具体的教程吗?谢谢
改好了pkl放到test_style_sample还是没法生成自己的字体,可以出一个具体的教程吗?谢谢
感谢关注~近期我会出一个生成自己字体的教程放在置顶issue上
改好了pkl放到test_style_sample还是没法生成自己的字体,可以出一个具体的教程吗?谢谢
可以直接用这个脚本,确保样式图片是64*64大小的jpg文件即可,输出目录自行改一下。
python predict.py --pretrained_model model_zoo/checkpoint-iter199999.pth --style_sample_size 20
import argparse
import os
from parse_config import cfg, cfg_from_file, assert_and_infer_cfg
import torch
import glob
import random
import pickle
import time
from PIL import Image
import numpy as np
from models.model import SDT_Generator
from utils.util import writeCache, dxdynp_to_list, coords_render
# Pick the compute device once at import time: CUDA when torch reports it as
# available, otherwise CPU. load_model() places the network on this device.
# NOTE(review): the messages say "training" although this script only runs
# inference; the text is left unchanged.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('Using GPU for training')
else:
    device = torch.device('cpu')
    print('Using CPU for training')
def load_model(checkpoint_path=None):
    """Build the SDT generator, load pretrained weights and return it in eval mode.

    :param checkpoint_path: path of the checkpoint file to load. When omitted,
        the ``pretrained_model`` attribute of the module-level ``opt`` is used,
        which is what the zero-argument call has always done.
    :return: the ``SDT_Generator`` instance, moved to ``device`` and set to eval mode
    :raises IOError: if the checkpoint path is empty
    """
    if checkpoint_path is None:
        checkpoint_path = opt.pretrained_model
    # Reject an empty path before spending time constructing the network.
    if not checkpoint_path:
        raise IOError('input the correct checkpoint path')

    model = SDT_Generator(num_encoder_layers=cfg.MODEL.ENCODER_LAYERS,
                          num_head_layers=cfg.MODEL.NUM_HEAD_LAYERS,
                          wri_dec_layers=cfg.MODEL.WRI_DEC_LAYERS,
                          gly_dec_layers=cfg.MODEL.GLY_DEC_LAYERS).to(device)
    # map_location remaps the stored tensors onto the device selected above.
    model_weight = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(model_weight)
    print('load pretrained model from {}'.format(checkpoint_path))
    model.eval()
    return model
def load_styles(sample_dir, pic_format=".jpg", sample_count=15):
    """Randomly pick style reference images from a directory and stack them.

    :param sample_dir: directory containing the style images
    :param pic_format: file-name suffix of the style images
    :param sample_count: number of images to draw (without replacement)
    :return: array of the sampled images converted to grayscale and divided by
        255, with a channel axis inserted at position 1 ([N, C, H, W], C=1,
        provided all images share one size)
    :raises ValueError: if fewer than ``sample_count`` matching images exist
    """
    # TODO: add labels later
    character_list = glob.glob(os.path.join(sample_dir, '*' + pic_format))
    if sample_count > len(character_list):
        # random.sample would raise a ValueError here anyway; raise the same
        # type with a message that tells the user what is actually missing.
        raise ValueError(
            'need %d style images but found only %d matching "*%s" in %s'
            % (sample_count, len(character_list), pic_format, sample_dir))

    # Sample the paths first so only the chosen files are decoded.
    sampled_images = []
    for image_path in random.sample(character_list, sample_count):
        # The context manager closes the file; np.array copies the pixels out.
        with Image.open(image_path) as image:
            image_array = np.array(image.convert("L"))
        sampled_images.append(image_array / 255.)
    sampled_images = np.expand_dims(sampled_images, 1)  # [N, C, H, W], C=1
    return sampled_images
# Content tables already read from disk, keyed by file path, so the pickle is
# parsed once per process instead of once per generated character.
_CONTENT_CACHE = {}


def _load_content(path):
    """Return the unpickled content table stored at *path*, reading it at most once."""
    if path not in _CONTENT_CACHE:
        # NOTE: pickle.load can execute arbitrary code; only use a trusted file.
        with open(path, 'rb') as f:
            _CONTENT_CACHE[path] = pickle.load(f)
    return _CONTENT_CACHE[path]


def infer(model, char, style_images, save_dir=None):
    """Generate one character in the given style and save the rendering as a jpg.

    The image is written to ``<save_dir>/predict/<char>.jpg``; a failed save is
    reported on stdout and does not abort the run.

    :param model: generator exposing ``inference(style_imgs, content_img, n)``
    :param char: the character to generate; used as the key into the content
        table and as the output file name
    :param style_images: style reference images as returned by ``load_styles``
    :param save_dir: output root directory; defaults to ``opt.save_dir`` from
        the module-level ``opt``
    :return: None
    """
    if save_dir is None:
        save_dir = opt.save_dir
    with torch.no_grad():
        # Load the content glyph for this character.
        content = _load_content("./data/CASIA_CHINESE/Chinese_content.pkl")
        char_img = content[char]  # content
        char_img = char_img / 255.

        # Inputs must live on the same device as the model (see load_model).
        style_images_tensor = torch.Tensor(style_images).to(device)
        char_img_tensor = torch.Tensor(char_img).to(device)
        char_img_tensor = char_img_tensor.unsqueeze(0)  # [C, H, W]
        batched_style_images_tensor = torch.unsqueeze(style_images_tensor, dim=0)
        batched_char_img_tensor = torch.unsqueeze(char_img_tensor, dim=0)

        # Predict. NOTE(review): 120 is presumably the maximum number of
        # generated points - confirm against SDT_Generator.inference.
        prediction = model.inference(batched_style_images_tensor, batched_char_img_tensor, 120)

        # Render.
        batch_size = 1
        SOS = torch.tensor(batch_size * [[0, 0, 1, 0, 0]]).unsqueeze(1).to(prediction)
        preds = torch.cat((SOS, prediction), 1)  # add the SOS token like GT
        preds = preds.detach().cpu().numpy()

        predict_dir = os.path.join(save_dir, 'predict')
        # Saving into a missing directory would fail for every character.
        os.makedirs(predict_dir, exist_ok=True)
        save_path = os.path.join(predict_dir, char + '.jpg')
        for pred in preds:
            # intends to blur the boundaries of each sample to fit the actual
            # using situations, as suggested in 'Deep imitator: Handwriting
            # calligraphy imitation via deep attention networks'
            sk_pil = coords_render(pred, split=True, width=64, height=64, thickness=1, board=0)
            try:
                sk_pil.save(save_path)
            except Exception as e:
                # Best effort: report the failure and continue.
                print('error. %s, %s' % (save_path, char))
                print(e)
def main(opt):
    """Load config and model, sample style images, then generate a fixed demo string.

    :param opt: parsed command-line namespace; ``cfg_file`` is read here and
        ``style_sample_size`` is used when present
    :return: None
    """
    cfg_from_file(opt.cfg_file)
    assert_and_infer_cfg()

    model = load_model()
    # Honor --style_sample_size instead of always using load_styles' default.
    # int() is needed because argparse delivers the value as a string unless
    # the argument declares a type.
    sample_count = int(getattr(opt, 'style_sample_size', 15))
    style_images = load_styles("./data/CASIA_CHINESE/styles/manual/", sample_count=sample_count)

    # The timer covers generation only, not model or style loading.
    start = time.time()
    for char in "我有一些文字想保存你想要福报么":
        infer(model, char, style_images)
    print("infer cost time: ", time.time() - start)
# Example:
#   python predict.py --pretrained_model checkpoint_path --style_sample_size 15 --dir Generated/Chinese
if __name__ == '__main__':
    # Parse input arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', dest='cfg_file', default='configs/CHINESE_CASIA.yml',
                        help='Config file for training (and optionally testing)')
    parser.add_argument('--dir', dest='save_dir', default='Generated/Chinese',
                        help='target dir for storing the generated characters')
    # required=True already forces a value, so no default is declared.
    parser.add_argument('--pretrained_model', dest='pretrained_model', required=True,
                        help='path of the pretrained checkpoint to load')
    # type=int so the value arrives as a number rather than a string.
    parser.add_argument('--style_sample_size', dest='style_sample_size', type=int, default=15,
                        required=False,
                        help='when predicting, the real sample size')
    opt = parser.parse_args()
    main(opt)
在尝试使用自己的图片进行测试时,我发现测试过程中的输入文件是pkl格式。请问在使用png图片时,是以什么格式转化为pkl文件的呢?