lllyasviel / ControlNet

Let us control diffusion models!
Apache License 2.0

self-trained ControlNet is slower than standard #683

Open Liuqh12 opened 3 weeks ago


I trained my own ControlNet following tutorial_train.py.

After training, I got my_cn.ckpt, which is about 8 GB.
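One way to see where the extra size comes from is to add up the tensor bytes under each top-level key of the loaded checkpoint. Below is a minimal sketch. `summarize_checkpoint` and `tensor_bytes` are hypothetical helpers, not part of ControlNet. They only assume that each tensor exposes `numel()` and `element_size()`, as `torch.Tensor` does. If my_cn.ckpt is a PyTorch Lightning checkpoint, extra entries such as `optimizer_states` next to `state_dict` would be expected. That is an assumption to check, not something confirmed here.

```python
def tensor_bytes(obj):
    """Recursively sum numel() * element_size() over every tensor-like object
    found inside nested dicts, lists, and tuples."""
    if hasattr(obj, "numel") and hasattr(obj, "element_size"):
        return obj.numel() * obj.element_size()
    if isinstance(obj, dict):
        return sum(tensor_bytes(v) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return sum(tensor_bytes(v) for v in obj)
    return 0  # ints, strings and other metadata hold no tensor data


def summarize_checkpoint(ckpt):
    """Map each top-level key of a loaded checkpoint to the tensor bytes it holds.

    A bare state dict (parameter name -> tensor) is reported as a single
    'state_dict' entry so both file layouts can be compared side by side.
    """
    if all(hasattr(v, "numel") for v in ckpt.values()):
        return {"state_dict": tensor_bytes(ckpt)}
    return {key: tensor_bytes(value) for key, value in ckpt.items()}
```

For example, `summarize_checkpoint(torch.load('./models/my_cn.ckpt', map_location='cpu'))` could be printed next to the same call for control_sd15_scribble.pth. Any large entry outside `state_dict` is data that `model.load_state_dict` never uses at inference time.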

my_cn.ckpt loads, runs, and gives the expected results in gradio_scribble2image.py. The only line I changed is:

model.load_state_dict(load_state_dict('./models/my_cn.ckpt', location='cuda'))

However, during inference my_cn is several times slower than your released checkpoint on Hugging Face.

I printed the state_dict of both my_cn.ckpt and control_sd15_scribble.pth. Both are stored as torch.float32.
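Matching dtypes is only one attribute. Below is a minimal sketch of a fuller structural comparison that checks key sets, shapes and dtypes together. `compare_state_dicts` is a hypothetical helper, not part of ControlNet. It assumes both arguments are plain name-to-tensor mappings, for example the dicts returned by the `load_state_dict` imported from `cldm.model` in the script below. Tensor values are deliberately not compared.

```python
def compare_state_dicts(reference, candidate):
    """Return human-readable differences between two state dicts.

    Checks which keys exist in only one dict, then compares shape and dtype
    for every shared key. Tensor values are not inspected.
    """
    problems = []
    for key in sorted(reference.keys() - candidate.keys()):
        problems.append(f"missing in candidate: {key}")
    for key in sorted(candidate.keys() - reference.keys()):
        problems.append(f"unexpected in candidate: {key}")
    for key in sorted(reference.keys() & candidate.keys()):
        ref, cand = reference[key], candidate[key]
        if tuple(ref.shape) != tuple(cand.shape):
            problems.append(
                f"shape mismatch for {key}: {tuple(ref.shape)} vs {tuple(cand.shape)}"
            )
        if ref.dtype != cand.dtype:
            problems.append(f"dtype mismatch for {key}: {ref.dtype} vs {cand.dtype}")
    return problems
```

An empty list means the two files are structurally interchangeable. In that case a speed difference would have to come from somewhere other than the checkpoint layout.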

I then benchmarked the ControlNet module alone with the following code:

from share import *
import cv2
import torch
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from tqdm import tqdm
from torchinfo import summary

model = create_model('./models/cldm_v15.yaml').cpu()

# Official checkpoint: ~317 ms per forward pass, 1445.12M params (torchinfo)
sketch_ckpt_path = './models/control_sd15_scribble.pth'

# Self-trained checkpoint: ~1545 ms per forward pass, 1445.12M params (torchinfo)
# sketch_ckpt_path = './models/my_cn.ckpt'

model.load_state_dict(load_state_dict(sketch_ckpt_path, location='cuda'))

model = model.cuda()
control_net = model.control_model

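# Random inputs with the shapes the ControlNet forward pass expects
x = torch.rand((1, 4, 64, 64)).to('cuda')        # latent
hint = torch.rand((1, 3, 512, 512)).to('cuda')   # control image (scribble)
timesteps = torch.rand((1)).to('cuda')           # diffusion timestep
context = torch.rand((1, 77, 768)).to('cuda')    # text-encoder embedding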

# print model information: https://github.com/TylerYep/torchinfo
summary(control_net, input_data=[x, hint, timesteps, context])

epoch = 50  # number of timed forward passes
e_sum = 0.00
for i in tqdm(range(0, epoch)):
    begin = cv2.getTickCount()
    control_net(x, hint, timesteps, context)
    end = cv2.getTickCount()
    # convert ticks to milliseconds
    e_sum += (end - begin) / cv2.getTickFrequency() * 1000.0
print(e_sum / epoch)  # average time per forward pass, in ms
print("Done!")
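CUDA kernel launches return to Python before the GPU work finishes. A host-side timer around a single call can therefore shift time between iterations, and the first calls also pay one-off setup costs. Below is a minimal sketch of a timing loop that adds untimed warm-up calls and an optional synchronization hook. `benchmark` and `report` are hypothetical helpers, not part of ControlNet or PyTorch.

```python
import statistics
import time


def benchmark(fn, iters=50, warmup=5, sync=None):
    """Time fn() and return per-call latencies in milliseconds.

    sync, if given, is called once after warm-up and after every timed call,
    so asynchronous device work is finished before the clock is read.
    """
    for _ in range(warmup):  # untimed calls absorb one-off setup costs
        fn()
    if sync is not None:
        sync()
    latencies = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued work to complete before stopping the clock
        latencies.append((time.perf_counter() - start) * 1000.0)
    return latencies


def report(latencies):
    """Summarize latencies as (mean, median) in milliseconds."""
    return statistics.mean(latencies), statistics.median(latencies)
```

With the script above, this could be called as `benchmark(lambda: control_net(x, hint, timesteps, context), sync=torch.cuda.synchronize)` inside a `with torch.no_grad():` block. Both `torch.cuda.synchronize` and `torch.no_grad` are standard PyTorch APIs. Running both checkpoints through the same loop keeps the comparison like-for-like.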

I suspect I have missed some detail. I'm looking forward to your suggestions.