Liuqh12 opened 3 weeks ago
I trained my own ControlNet following tutorial_train.py.
After training, I got my_cn.ckpt, which is about 8 GB.
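One way to see what makes up the ~8 GB is to list the checkpoint's top-level keys and add up tensor storage per dtype. This is a minimal sketch that assumes the file is a dict loaded with `torch.load`. PyTorch Lightning checkpoints usually nest the weights under a `state_dict` key next to other entries, so the helper uses that key when present. `summarize_checkpoint` is a hypothetical name, and the toy checkpoint only stands in for the real file.

```python
from collections import defaultdict

import torch


def summarize_checkpoint(ckpt):
    """Hypothetical helper: return (top-level keys, tensor bytes per dtype).

    Assumes `ckpt` is a dict as returned by torch.load(); if it has a
    'state_dict' entry (PyTorch Lightning convention), that is treated as the weights.
    """
    top_level_keys = sorted(ckpt.keys())
    weights = ckpt.get("state_dict", ckpt)
    bytes_per_dtype = defaultdict(int)
    for value in weights.values():
        if torch.is_tensor(value):
            # numel() * element_size() = bytes needed for this tensor's elements
            bytes_per_dtype[value.dtype] += value.numel() * value.element_size()
    return top_level_keys, dict(bytes_per_dtype)


# Real use (path from the issue):
#   ckpt = torch.load('./models/my_cn.ckpt', map_location='cpu')
# Toy stand-in so the example runs on its own:
ckpt = {
    "epoch": 3,
    "state_dict": {"w": torch.zeros(10, 10), "b": torch.zeros(10, dtype=torch.float16)},
}
keys, sizes = summarize_checkpoint(ckpt)
print(keys)   # ['epoch', 'state_dict']
print(sizes)  # {torch.float32: 400, torch.float16: 20}
```

Comparing the summed bytes with the file size shows whether the extra space is in the weights themselves or in other top-level entries.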
my_cn.ckpt loads, runs, and produces the expected results in gradio_scribble2image.py. The only change I made there was:

```python
model.load_state_dict(load_state_dict('./models/my_cn.ckpt', location='cuda'))
```
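When swapping checkpoints, it can help to confirm that every key was actually matched. `torch.nn.Module.load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`, and with `strict=False` it reports name mismatches instead of raising. The sketch below uses a toy `nn.Linear` in place of the ControlNet model; `report_load` is a hypothetical helper, not part of the ControlNet repo.

```python
import torch
from torch import nn


def report_load(model, state_dict):
    """Hypothetical helper: load non-strictly and return (missing, unexpected) key lists."""
    result = model.load_state_dict(state_dict, strict=False)
    return list(result.missing_keys), list(result.unexpected_keys)


# Toy stand-in: in the issue this would be the model from create_model(...)
# and the dict returned by cldm.model.load_state_dict('./models/my_cn.ckpt', ...).
toy = nn.Linear(4, 2)
sd = {"weight": torch.zeros(2, 4), "extra.param": torch.zeros(1)}  # no "bias", one extra key
missing, unexpected = report_load(toy, sd)
print(missing)     # ['bias']
print(unexpected)  # ['extra.param']
```

Both lists coming back empty for the real model would rule out a silent partial load.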
However, at inference time my_cn is several times slower than your model on Hugging Face: roughly 5x in the ControlNet-only test below (1545 ms vs. 317 ms per forward pass).
I printed the state_dict of both my_cn.ckpt and control_sd15_scribble.pth; both are torch.float32.
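Beyond printing dtypes, a key-by-key comparison also catches keys present in only one file and shape differences. This is a minimal sketch assuming both files have been loaded into plain name-to-tensor dicts (for example with `load_state_dict(path, location='cpu')` from `cldm.model`); `diff_state_dicts` is a hypothetical helper and the toy dicts are placeholders.

```python
import torch


def diff_state_dicts(a, b):
    """Hypothetical helper: compare two state dicts key by key.

    Returns keys only in `a`, keys only in `b`, and shared tensor keys whose
    dtype or shape differ between the two.
    """
    only_a = sorted(set(a) - set(b))
    only_b = sorted(set(b) - set(a))
    mismatched = []
    for key in sorted(set(a) & set(b)):
        ta, tb = a[key], b[key]
        if not (torch.is_tensor(ta) and torch.is_tensor(tb)):
            continue  # skip non-tensor entries, if any
        if ta.dtype != tb.dtype or ta.shape != tb.shape:
            mismatched.append((key, (ta.dtype, tuple(ta.shape)), (tb.dtype, tuple(tb.shape))))
    return only_a, only_b, mismatched


# Toy placeholders for the two loaded checkpoints
a = {"w": torch.zeros(2, 2), "b": torch.zeros(2)}
b = {"w": torch.zeros(2, 2, dtype=torch.float16), "extra.w": torch.zeros(2, 2)}
only_a, only_b, mismatched = diff_state_dicts(a, b)
print(only_a)      # ['b']
print(only_b)      # ['extra.w']
print(mismatched)  # [('w', (torch.float32, (2, 2)), (torch.float16, (2, 2)))]
```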
I then tested the ControlNet on its own with the following code:
```python
from share import *
import cv2
import torch
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from tqdm import tqdm
from torchinfo import summary

model = create_model('./models/cldm_v15.yaml').cpu()

# 317 ms
# 1445.12M-params
sketch_ckpt_path = './models/control_sd15_scribble.pth'

# 1545 ms
# 1445.12M-params
# sketch_ckpt_path = './models/my_cn.ckpt'

model.load_state_dict(load_state_dict(sketch_ckpt_path, location='cuda'))
model = model.cuda()
control_net = model.control_model

x = torch.rand((1, 4, 64, 64)).to('cuda')
hint = torch.rand((1, 3, 512, 512)).to('cuda')
timesteps = torch.rand((1)).to('cuda')
context = torch.rand((1, 77, 768)).to('cuda')

# print model information: https://github.com/TylerYep/torchinfo
summary(control_net, input_data=[x, hint, timesteps, context])

epoch = 50
e_sum = 0.00
for i in tqdm(range(0, epoch)):
    begin = cv2.getTickCount()
    control_net(x, hint, timesteps, context)
    end = cv2.getTickCount()
    # to ms
    e_sum += (end - begin) / cv2.getTickFrequency() * 1000.0
print(e_sum / epoch)
print("Done!")
```
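The loop above reads the clock around each call without synchronizing with the GPU, and it runs with autograd enabled. Because CUDA kernels are launched asynchronously, a per-call timer like this may partly measure launch overhead rather than finished work. A more conservative measurement might look like the sketch below, which assumes plain PyTorch: warm-up calls first, gradients disabled, and `torch.cuda.synchronize()` before each clock read. `time_forward` is a hypothetical helper; with the objects from the script it would be called as `time_forward(control_net, x, hint, timesteps, context)`.

```python
import time

import torch


def time_forward(fn, *args, warmup=5, iters=50):
    """Hypothetical helper: mean wall-clock time of fn(*args) in milliseconds.

    Calls torch.cuda.synchronize() (when CUDA is available) before each clock
    read so queued GPU work is finished, and disables autograd since only
    inference is being measured.
    """
    use_cuda = torch.cuda.is_available()
    with torch.no_grad():
        for _ in range(warmup):  # keep one-time startup costs out of the measurement
            fn(*args)
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        if use_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return elapsed / iters * 1000.0


# Toy CPU example so the sketch runs anywhere
layer = torch.nn.Linear(64, 64)
inp = torch.randn(8, 64)
print(f"{time_forward(layer, inp, warmup=2, iters=10):.3f} ms per call")
```

Timing the whole loop between two synchronizations gives the average per-call cost without adding a sync inside every iteration.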
I suspect I have missed some detail. Any suggestions would be appreciated.