AUTOMATIC1111 / stable-diffusion-webui-tensorrt


SDXL Support #58

Open CyberTimon opened 1 year ago

CyberTimon commented 1 year ago

Hello

Is SDXL support planned, as SDXL is slow on most computers?

Kind regards, Timon Käch

zz2222222222222 commented 1 year ago

> Hello
>
> Is SDXL support planned, as SDXL is slow on most computers?
>
> Kind regards, Timon Käch

I already tried it and it works, but the code needs to be modified.

Speed went from 6.67 it/s up to 12.10 it/s at w=960, h=1024, 21 steps.


1. Export the UNet to ONNX with this new method:

```python
import os

import torch

from modules import sd_hijack, sd_unet
from modules import shared, devices


def export_current_unet_to_onnx(filename, opset_version=17):
    # Dummy SDXL inputs: context is 77x2048, y is the 2816-dim vector conditioning
    x = torch.randn(1, 4, 16, 16).to(devices.device, devices.dtype)
    timesteps = torch.zeros((1,)).to(devices.device, devices.dtype) + 500
    context = torch.randn(1, 77, 2048).to(devices.device, devices.dtype)
    y = torch.randn(1, 2816).to(devices.device, devices.dtype)

    # Gradient checkpointing must be off for ONNX export
    def disable_checkpoint(self):
        if getattr(self, 'use_checkpoint', False) == True:
            self.use_checkpoint = False
        if getattr(self, 'checkpoint', False) == True:
            self.checkpoint = False

    shared.sd_model.model.diffusion_model.apply(disable_checkpoint)

    sd_unet.apply_unet("None")
    sd_hijack.model_hijack.apply_optimizations('None')

    os.makedirs(os.path.dirname(filename), exist_ok=True)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    shared.sd_model.model.diffusion_model = shared.sd_model.model.diffusion_model.to(device)

    with devices.autocast():
        torch.onnx.export(
            shared.sd_model.model.diffusion_model,
            (x, timesteps, context, y),
            filename,
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=['x', 'timesteps', 'context', 'y'],
            output_names=['output'],
            dynamic_axes={
                'x': {0: 'batch_size', 2: 'height', 3: 'width'},
                'timesteps': {0: 'batch_size'},
                'context': {0: 'batch_size', 1: 'sequence_length'},
                'y': {0: 'batch_size'},
                'output': {0: 'batch_size'},
            },
        )

    sd_hijack.model_hijack.apply_optimizations()
    sd_unet.apply_unet()
```
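For context on the dummy shapes above: context is 1x77x2048 because SDXL concatenates the hidden states of its two text encoders (768-dim CLIP ViT-L plus 1280-dim OpenCLIP ViT-bigG), and y is 1x2816 because the vector conditioning is the 1280-dim pooled text embedding plus six 256-dim size/crop embeddings. A minimal invocation sketch (the output path is my example, not prescribed by the extension):

```python
# Hypothetical call; point the path at your webui models directory.
export_current_unet_to_onnx("models/Unet-onnx/ttt.onnx", opset_version=17)
```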

2. Hijack the sgm UNet forward with UNetModel_forwardy in /modules/sd_hijack.py:

```python
...
    if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
        ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward

    ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

    # Same save-and-patch for the sgm (SDXL) UNet, but with the y-aware forward
    if not hasattr(sgm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
        sgm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = sgm.modules.diffusionmodules.openaimodel.UNetModel.forward

    sgm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forwardy

def undo_hijack(self, m):
    if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
        m.cond_stage_model = m.cond_stage_model.wrapped

    elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
        m.cond_stage_model = m.cond_stage_model.wrapped

        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
            model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
    elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
        m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
        m.cond_stage_model = m.cond_stage_model.wrapped

    undo_optimizations()
    undo_weighted_forward(m)

    self.apply_circular(False)
    self.layers = None
    self.clip = None

    # Restore both saved originals
    ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
    sgm.modules.diffusionmodules.openaimodel.UNetModel.forward = sgm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
...
```
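The hijack above is a plain save-and-restore monkey-patch: the original forward is stashed once on the module as copy_of_UNetModel_forward_for_webui (the hasattr guard keeps a repeated hijack from clobbering the saved original), and undo_hijack puts it back. A self-contained sketch of the pattern, with illustrative names:

```python
class UNetModel:
    def forward(self, x):
        return x

def patched_forward(self, x):
    # Delegate to the saved original; a real patch would add behavior here
    return UNetModel.copy_of_forward(self, x)

# Stash the original only once, so patching twice cannot lose it
if not hasattr(UNetModel, 'copy_of_forward'):
    UNetModel.copy_of_forward = UNetModel.forward
UNetModel.forward = patched_forward

# To undo, restore the saved original
UNetModel.forward = UNetModel.copy_of_forward
```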

3. Add the y-aware forward in modules/sd_unet.py:

```python
...
class SdUnet(torch.nn.Module):
    def forward(self, x, timesteps, context, *args, **kwargs):
        raise NotImplementedError()

    def activate(self):
        pass

    def deactivate(self):
        pass


def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
    if current_unet is not None:
        return current_unet.forward(x, timesteps, context, *args, **kwargs)

    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)


# Same method as above, but threading the SDXL vector conditioning y through
def UNetModel_forwardy(self, x, timesteps=None, context=None, y=None, **kwargs):
    if current_unet is not None:
        return current_unet.forward(x, timesteps, context, y, **kwargs)

    return sgm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, y, **kwargs)
...
```
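The only difference from the stock UNetModel_forward is that the sgm (SDXL) variant passes the vector conditioning y explicitly, since sgm's UNetModel.forward takes it as a named argument; dispatch to the active SdUnet is otherwise unchanged.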

4. Patch the engine forward in extensions/stable-diffusion-webui-tensorrt/scripts/trt.py:

```python
def forward(self, x, timesteps, context, *args, **kwargs):
    a, b, c, d = x.shape

    # print(x.shape, timesteps.shape, context.shape)

    if a == 1:
        self.infer({"x": x, "timesteps": timesteps, "context": context})
        return self.buffers["output"].to(dtype=x.dtype, device=devices.device)
    else:
        # Engine was built with batch size 1, so run the samples one at a time
        images = []
        for i in range(a):
            with contextlib.suppress(Exception):
                s = x[i].unsqueeze(0)
                t = timesteps[i].unsqueeze(0)
                c = context[i].unsqueeze(0)
                if args:
                    # SDXL: the vector conditioning y arrives as the first extra arg
                    y = args[0][i].unsqueeze(0)
                    self.infer({"x": s, "timesteps": t, "context": c, "y": y})
                else:
                    self.infer({"x": s, "timesteps": t, "context": c})

                images.append(self.buffers["output"].to(dtype=x.dtype, device=devices.device))
        return torch.cat(images, dim=0)
```
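Note the per-sample loop: the trtexec command in step 6 pins the batch dimension to 1 in both minShapes and maxShapes, so the engine only accepts one latent at a time and larger batches have to be split here and re-concatenated. Building the engine with a dynamic batch dimension instead would let the a == 1 special case and the loop collapse into a single infer call.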

5. Fix the "tensors on two devices" errors.

You have to track them down one by one and add model.to(devices.device), or the easy way, call model.cuda(). There are maybe 3-4 places that need modifying.
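Each fix is a one-liner along these lines (the exact call sites vary, so treat this as illustrative):

```python
# Wherever a module was left on the CPU while its inputs live on the GPU:
model = model.to(devices.device)  # webui's configured device
# or, the blunt version:
model = model.cuda()
```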

6. Export the ONNX model to a TensorRT engine. My command:

```
"{full_path}/trtexec" --onnx="{full_path}/models/Unet-onnx/ttt.onnx" --saveEngine="{full_path}/models/Unet-trt/ttt.trt" --minShapes=x:1x4x64x64,context:1x77x2048,timesteps:1 --maxShapes=x:1x4x128x120,context:1x77x2048,timesteps:1 --fp16
```

CyberTimon commented 1 year ago

Hey, thank you so much for the fast answer. Will try it out soon. Is 1024x1024 not possible? Only 960x1024?

zz2222222222222 commented 1 year ago

> Hey, thank you so much for the fast answer. Will try it out soon. Is 1024x1024 not possible? Only 960x1024?

I can't say for sure. It can't go over maxShapes=x:1x4x128x120; when I use maxShapes=x:1x4x128x128 (which 1024x1024 would need, since latent dims are pixel dims / 8), trtexec pops up an error.