cumulo-autumn / StreamDiffusion

StreamDiffusion: A Pipeline-Level Solution for Real-Time Interactive Generation
Apache License 2.0

realtime-text2image ControlNet #132

Open Liuqh12 opened 6 months ago

Liuqh12 commented 6 months ago

You can add ControlNet to StreamDiffusion as follows:

1. Uninstall streamdiffusion from your environment: pip uninstall streamdiffusion
2. Prepare your ControlNet, for example:

    import torch
    from controlnet_aux import OpenposeDetector
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
    from diffusers.utils import load_image

    openpose_pre_train_path = r"D:\lqh12\a-sd-based-models\sd-controlnet-openpose"
    openpose = OpenposeDetector.from_pretrained(r"D:\lqh12\a-sd-based-models\lllyasviel-ControlNet")
    o_image = load_image(r"D:\lqh12\a-sd-based-models\sd-controlnet-openpose\images\pose.png")

    controlnet = ControlNetModel.from_pretrained(openpose_pre_train_path, torch_dtype=torch.float16)

    pipe_t = StableDiffusionControlNetPipeline.from_pretrained(
        r"D:\lqh12\a-sd-based-models\sdv15",
        controlnet=controlnet,
        safety_checker=None,
        torch_dtype=torch.float16,
    ).to("cuda")

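3. Call the ControlNet forward pass, for example:
```python
def get_a_b_control_net(a, b, timestep=801):
    # a: the latent batch passed to the UNet (x_t_latent_plus_uc)
    # b: the prompt embeddings (self.prompt_embeds)
    # Uses pipe_t, openpose, o_image and controlnet from step 2.
    # Note: this runs the OpenPose detector on every call.
    image = pipe_t.prepare_image(
        image=openpose(o_image),
        width=512,
        height=512,
        batch_size=1 * 1,
        num_images_per_prompt=1,
        device="cuda",
        dtype=controlnet.dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    )

    # 801 is a fixed timestep. Passing the UNet's actual t_list instead
    # may follow the denoising schedule more closely.
    down_block_res_samples, mid_block_res_sample = controlnet(
        a,
        timestep,
        encoder_hidden_states=b,
        controlnet_cond=image,
        conditioning_scale=1.0,
        guess_mode=False,
        return_dict=False,
    )
    return down_block_res_samples, mid_block_res_sample
```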
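4. Use the ControlNet outputs. In pipeline.py, modify StreamDiffusion.unet_step so the UNet call looks like this:

    # Inside StreamDiffusion.unet_step (pipeline.py). get_a_b_control_net and the
    # objects it uses (pipe_t, controlnet, openpose, o_image) must be reachable
    # from this module, e.g. defined in or imported into pipeline.py.
    a, b = get_a_b_control_net(x_t_latent_plus_uc, self.prompt_embeds)
    model_pred = self.unet(
        x_t_latent_plus_uc,
        t_list,
        encoder_hidden_states=self.prompt_embeds,
        down_block_additional_residuals=a,
        mid_block_additional_residual=b,
        return_dict=False,
    )[0]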
5. Reinstall streamdiffusion from your modified code: python setup.py develop easy_install streamdiffusion[tensorrt]
6. Enjoy it.
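As written, get_a_b_control_net runs the OpenPose detector and prepare_image on every UNet step, even though the pose image never changes. Below is a minimal sketch of doing that work only once. It uses functools.lru_cache keyed on something hashable, such as the image path; tensors and PIL images make poor cache keys. fake_openpose and control_image are hypothetical stand-ins. In real code, the cached function would return the pipe_t.prepare_image(...) result.

```python
import functools

calls = {"detector": 0}


def fake_openpose(image_path: str) -> str:
    # Hypothetical stand-in for openpose(load_image(image_path)): slow but deterministic.
    calls["detector"] += 1
    return f"pose-map({image_path})"


@functools.lru_cache(maxsize=1)
def control_image(image_path: str) -> str:
    # Hypothetical stand-in for pipe_t.prepare_image(image=openpose(...), ...).
    # lru_cache needs hashable arguments, so key on the path, not the image object.
    return fake_openpose(image_path)


for _ in range(100):  # e.g. 100 consecutive UNet steps
    cond = control_image("pose.png")

print(calls["detector"])
```

If the pose source changes (for example, a webcam feed), the detector would still need to run once per new frame, but not once per denoising step.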

Thanks, all. Great work on StreamDiffusion.

Example ControlNet input image: [image: pose]

My prompt: chef in the kitchen

realtime-text2image with ControlNet result: [image: chef with ControlNet]

realtime-text2image without ControlNet result: [image: chef without pose ControlNet]

WyattAutomation commented 6 months ago

You, sir, are a baller. Thank you so much for sharing this!

WyattAutomation commented 6 months ago

Hey there @Liuqh12, thank you so much for sharing this.

After giving this an exhaustive try, though, I cannot get it working. There are no errors; it just doesn't do anything, at least the way I tried to use it.

Main question, if you don't want to read all of this: how is the modified unet_step in pipeline.py supposed to reach your get_a_b_control_net function, and can you share your full modifications to the repo?


More details:

Your get_a_b_control_net function needs pipe_t as well as other variables/objects declared when you prepare the ControlNet (step 2). However, your instructions modify the unet_step method directly in pipeline.py so that it calls get_a_b_control_net, and none of those objects are defined there.

As a hack to get it working at all, I added get_a_b_control_net as a method of the StreamDiffusion class in pipeline.py. I then passed everything it needs as arguments down the whole call chain. It goes from the StreamDiffusionWrapper(...) instance created in my main script to the StreamDiffusion(...) instance created in wrapper.py. From there it goes through StreamDiffusion.txt2img(...) into predict_x0_batch(...). That method finally passes everything through its call to unet_step(...), where get_a_b_control_net (now a method of the same class) uses it.
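There is an alternative to threading the extra objects through every signature. This is a design suggestion, not what either poster did: attach one optional callable to the StreamDiffusion instance and have unet_step consult it. The toy class below is a hypothetical plain-Python stand-in that shows the pattern; `control_hook` is an invented attribute name.

```python
from typing import Callable, List, Optional, Tuple

# Hypothetical hook type: (latent, prompt_embeds) -> (down residuals, mid residual).
ControlHook = Callable[[float, float], Tuple[List[float], float]]


class ToyStream:
    """Plain-Python stand-in for StreamDiffusion, reduced to the call that matters."""

    def __init__(self) -> None:
        # None means "run without ControlNet"; set it once from the top-level script.
        self.control_hook: Optional[ControlHook] = None

    def unet(self, sample, timestep, encoder_hidden_states,
             down_block_additional_residuals=None, mid_block_additional_residual=None):
        # Toy UNet: adds whatever residuals it receives.
        out = sample + encoder_hidden_states
        if down_block_additional_residuals is not None:
            out += sum(down_block_additional_residuals)
        if mid_block_additional_residual is not None:
            out += mid_block_additional_residual
        return out

    def unet_step(self, x_t_latent: float, t: int, prompt_embeds: float) -> float:
        # txt2img/predict_x0_batch can keep calling this with their original arguments.
        extra = {}
        if self.control_hook is not None:
            down, mid = self.control_hook(x_t_latent, prompt_embeds)
            extra = {"down_block_additional_residuals": down,
                     "mid_block_additional_residual": mid}
        return self.unet(x_t_latent, t, prompt_embeds, **extra)


stream = ToyStream()
print(stream.unet_step(1.0, 801, 2.0))
stream.control_hook = lambda latent, embeds: ([0.5, 0.25], 1.0)
print(stream.unet_step(1.0, 801, 2.0))
```

With the real classes, the hook would presumably be a closure over pipe_t, controlnet, openpose and o_image (for example, a lambda calling get_a_b_control_net). It would be assigned once to the StreamDiffusion instance that the wrapper holds.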

I am setting everything up in my own Python code (diffusion_generator_controlnet.py) as follows. The models in my script are publicly available ones from Hugging Face. I am not sure which models you used, because diffusers could not find the ones named in your example on Hugging Face. I use lcm_lora and an LCM variant of DreamShaper here so that I can add denoising "strength" and "guidance" args and control those params via the LCM scheduler; StreamDiffusion doesn't have an apparent way to do this otherwise. Here is my class that sets it up:


import os
import sys
import torch

from typing import Literal, Dict, Optional

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from streamdiff_utils.wrapper import StreamDiffusionWrapper

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

from diffusers import LCMScheduler, StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
from controlnet_aux import OpenposeDetector

class DiffusionGenerator:
    def __init__(self, init_prompt="", guidance_scale=1.6, strength=1.7):
        self.model_id_or_path: str = "SimianLuo/LCM_Dreamshaper_v7"
        self.taesd_model: str = "madebyollin/taesd"
        self.lora_dict: Optional[Dict[str, float]] = None
        self.prompt: str = init_prompt
        self.negative_prompt: str = "low quality, bad quality, blurry, low resolution"
        self.width: int = 512
        self.height: int = 512
        self.acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt"
        self.use_denoising_batch: bool = True
        self.guidance_scale = float(guidance_scale)
        self.strength = float(strength)
        self.seed: int = 123456
        self.delta: float = 0.5

        self.openpose_pre_train_path = "lllyasviel/control_v11p_sd15_openpose"
        self.openpose = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
        self.o_image = load_image("opose_test.png")
        self.controlnet = ControlNetModel.from_pretrained(self.openpose_pre_train_path, torch_dtype=torch.float16)
        self.pipe_t = StableDiffusionControlNetPipeline.from_pretrained(
            self.model_id_or_path, controlnet=self.controlnet, safety_checker=None, torch_dtype=torch.float16
        ).to("cuda")

        self.stream = StreamDiffusionWrapper(
            model_id_or_path=self.model_id_or_path,
            pipe_t=self.pipe_t,
            controlnet=self.controlnet, 
            o_image=self.o_image, 
            openpose=self.openpose,
            mode="txt2img",
            use_tiny_vae=self.taesd_model,
            t_index_list=[35, 45],
            frame_buffer_size=1,
            width=self.width,
            height=self.height,
            use_lcm_lora=True,
            output_type="pil",
            warmup=10,
            acceleration=self.acceleration,
            do_add_noise=False,
            use_denoising_batch=self.use_denoising_batch,
            cfg_type='none',
            seed=self.seed,
        )

        self.stream.prepare(
            prompt=self.prompt,
            negative_prompt=self.negative_prompt,
            num_inference_steps=50,
            guidance_scale=self.guidance_scale,
            delta=self.delta,
            strength=self.strength,
        )

I use an instance of that class in my main script like so:

from diffusion_generator_controlnet import DiffusionGenerator
diffusion_generator = DiffusionGenerator(init_prompt=init_prompt, guidance_scale=guidance_scale, strength=strength)

text = "Male anime character wearing aviator sunglasses and a yacht captain hat drinking beer at an outdoor bar, GTA boss, GTA character art, tattoos, vaporwave, retrowave, 90s anime, 80s anime, retro anime, neon green, neon pink, neon orange, Tony Montana, miami vice" 

# diffusion_generator.stream() is actually called constantly in a while loop in my main script
# This is just an example for generating one image but it's used the same way in my real script
imgout_txt2img = diffusion_generator.stream(prompt=text)

My modified unet_step method in pipeline.py looks like this. The variables I passed all the way through to pipeline.py from my top-level scripts are controlnet, o_image, pipe_t and openpose:

    def unet_step(
        self,
        x_t_latent: torch.Tensor,
        t_list: Union[torch.Tensor, list[int]],
        idx: Optional[int] = None,
        controlnet=None, 
        o_image=None, 
        pipe_t=None,        
        openpose=None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
            x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0)
            t_list = torch.concat([t_list[0:1], t_list], dim=0)
        elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
            x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0)
            t_list = torch.concat([t_list, t_list], dim=0)
        else:
            x_t_latent_plus_uc = x_t_latent

        # controlnet implementation
        a, b = self.get_a_b_control_net(
            x_t_latent_plus_uc, 
            self.prompt_embeds, 
            controlnet=controlnet, 
            o_image=o_image, 
            pipe_t=pipe_t, 
            openpose=openpose)

        model_pred = self.unet(
            x_t_latent_plus_uc,
            t_list,
            encoder_hidden_states=self.prompt_embeds,
            down_block_additional_residuals=a,
            mid_block_additional_residual=b,
            return_dict=False,
        )[0]

        # original implementation without controlnet
        # model_pred = self.unet(
        #     x_t_latent_plus_uc,
        #     t_list,
        #     encoder_hidden_states=self.prompt_embeds,
        #     return_dict=False,
        # )[0]

        # the rest of the code below this comment line is unchanged from Stream Diffusion's original code for this method
        if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
            noise_pred_text = model_pred[1:]
            self.stock_noise = torch.concat(
                [model_pred[0:1], self.stock_noise[1:]], dim=0
            )  # comment this out for self out cfg
        elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
            noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
        else:
            noise_pred_text = model_pred
        if self.guidance_scale > 1.0 and (
            self.cfg_type == "self" or self.cfg_type == "initialize"
        ):
            noise_pred_uncond = self.stock_noise * self.delta
        if self.guidance_scale > 1.0 and self.cfg_type != "none":
            model_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        else:
            model_pred = noise_pred_text

        # compute the previous noisy sample x_t -> x_t-1
        if self.use_denoising_batch:
            denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)
            if self.cfg_type == "self" or self.cfg_type == "initialize":
                scaled_noise = self.beta_prod_t_sqrt * self.stock_noise
                delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx)
                alpha_next = torch.concat(
                    [
                        self.alpha_prod_t_sqrt[1:],
                        torch.ones_like(self.alpha_prod_t_sqrt[0:1]),
                    ],
                    dim=0,
                )
                delta_x = alpha_next * delta_x
                beta_next = torch.concat(
                    [
                        self.beta_prod_t_sqrt[1:],
                        torch.ones_like(self.beta_prod_t_sqrt[0:1]),
                    ],
                    dim=0,
                )
                delta_x = delta_x / beta_next
                init_noise = torch.concat(
                    [self.init_noise[1:], self.init_noise[0:1]], dim=0
                )
                self.stock_noise = init_noise + delta_x

        else:
            # denoised_batch = self.scheduler.step(model_pred, t_list[0], x_t_latent).denoised
            denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)

        return denoised_batch, model_pred

And finally, here is my modified get_a_b_control_net. I added it to pipeline.py as a method so that unet_step in the same class can use it:

    def get_a_b_control_net(self, a, b, controlnet=None, o_image=None, pipe_t=None, openpose=None):

        image = pipe_t.prepare_image(    
                image=openpose(o_image),
                width=512,
                height=512,
                batch_size=1 * 1,
                num_images_per_prompt=1,
                device="cuda",
                dtype=controlnet.dtype,
                do_classifier_free_guidance=False,
                guess_mode=False,            
            )

        down_block_res_samples, mid_block_res_sample = controlnet(
            a,
            801,
            encoder_hidden_states=b,
            controlnet_cond=image,
            conditioning_scale=1.7,
            guess_mode=False,
            return_dict=False,
        )
        return down_block_res_samples, mid_block_res_sample

Without the attempted ControlNet modifications, my own class and script work just fine; I created them before your ControlNet example here. I use LCM-LoRA plus LCM DreamShaper for both txt2img and img2img, and I can easily adjust the denoising strength and other settings via LCMScheduler.

The args controlnet, o_image, pipe_t and openpose all default to None in the methods, but they get passed in through the rest of my refactoring. I confirmed that they all reach my modified get_a_b_control_net method intact, and that the values it returns are not None or empty. It looks like it should work, but it just doesn't do anything to the output.

I have done a lot more debugging than just trying my implementation here as-is. I deleted and rebuilt the TensorRT engine every time, disabled LCM-LoRA and used the vanilla SD 1.5 weights, and toggled use_denoising_batch on and off. I tried different models, different batch sizes, higher and lower conditioning_scale values for ControlNet, and higher and lower settings for virtually every other adjustable param and arg in StreamDiffusion. The pipeline does txt2img without any error, but it does not apply any ControlNet conditioning to the output.

I even tried a ControlNet img2img implementation using StableDiffusionControlNetImg2ImgPipeline in place of your StableDiffusionControlNetPipeline. I did a bunch of associated refactoring of StreamDiffusion to integrate it without errors. Same result: img2img works perfectly, but ControlNet does nothing as implemented here.

I am going to try again with a clean venv today, using different SD models. But if you can share your full modifications from your example (including your main script and config for realtime-text2image-with-control-net), I would be eternally grateful.

Either way, thank you so much for sharing this, and I am very hopeful to have it working soon!

WyattAutomation commented 6 months ago

An update: I can get it working with acceleration set to xformers, but not TensorRT. I assume I need to dig into how to build ControlNet into the TensorRT engine compilation. That seems like it may be a good bit more involved than just passing residuals to the unet_step method?

xformers gets me about 9-12 FPS. Maybe I have something configured poorly; I will dig in a bit more. It's almost fast enough: if I can get it closer to 20 FPS I would be happy. TensorRT would be ideal, though, since I previously had it flying at around 28-35 FPS.
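Here is one plausible explanation for "no errors, but no effect" under TensorRT. The compiled UNet engine only has the inputs it was exported with (sample, timestep, encoder_hidden_states). Suppose its Python wrapper collects any other keyword arguments in **kwargs without forwarding them, as the adapted engine class later in this thread does. Then the ControlNet residuals are silently dropped. The toy classes below are hypothetical plain-Python stand-ins (no TensorRT involved) that reproduce the effect.

```python
from typing import Any, List, Optional


class EagerUNet:
    """Stand-in for the PyTorch UNet: it really adds the residuals it is given."""

    def __call__(self, sample: float, timestep: int, encoder_hidden_states: float,
                 down_block_additional_residuals: Optional[List[float]] = None,
                 mid_block_additional_residual: Optional[float] = None) -> float:
        out = sample + encoder_hidden_states
        if down_block_additional_residuals is not None:
            out += sum(down_block_additional_residuals)
        if mid_block_additional_residual is not None:
            out += mid_block_additional_residual
        return out


class EngineWrapper:
    """Stand-in for an engine wrapper whose compiled graph has only three inputs."""

    def __call__(self, sample: float, timestep: int, encoder_hidden_states: float,
                 **kwargs: Any) -> float:
        # The ControlNet residuals end up in kwargs and are never forwarded,
        # so there is no error and no effect on the output.
        return sample + encoder_hidden_states


residuals = {"down_block_additional_residuals": [0.5, 0.25],
             "mid_block_additional_residual": 1.0}
print(EagerUNet()(1.0, 801, 2.0, **residuals))
print(EngineWrapper()(1.0, 801, 2.0, **residuals))
```

If this is the cause, the residuals (or the control image, as in the merged UNet+ControlNet model later in the thread) have to become real inputs of the exported graph.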

Liuqh12 commented 6 months ago

A few things:

1. Make sure you uninstall the older streamdiffusion Python package from your environment (e.g. a conda environment). After updating the code, you must reinstall it from your modified source.
2. My code is the same as yours.
3. My base model: [sd1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), ControlNet model: [openpose](https://huggingface.co/lllyasviel/ControlNet).

To use TensorRT for acceleration, refer to Stable-Diffusion-WebUI-TensorRT for stable-diffusion-webui. I think the two are very similar, and the improvement is obvious.

Best wishes.
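Point 1 matters because Python may keep importing a previously installed copy of streamdiffusion instead of your edited source. Here is a quick, stdlib-only check of which file a package would be loaded from. It is a generic sketch; module_origin is a hypothetical helper name, not part of StreamDiffusion.

```python
import importlib.util
from typing import Optional


def module_origin(name: str) -> Optional[str]:
    """Return the file Python would load for top-level module `name`, or None if not importable."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    # For this thread, call module_origin("streamdiffusion") and check that the
    # path points into your edited checkout, not an older site-packages copy.
    print(module_origin("json"))
```

After python setup.py develop, module_origin("streamdiffusion") should point into your source tree.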

WyattAutomation commented 6 months ago

I think I (maybe) found where and how to add the code needed to get ControlNet working with TensorRT acceleration. I suspect this is why it worked with xformers but not with TensorRT set as the accelerator: the ControlNet path just isn't part of the acceleration/tensorrt code.

I'm going to try implementing it this evening; I think I am close to having it working.

Also, looking at SD WebUI as a reference is a great idea. I forgot they had TensorRT acceleration as a feature. Thanks for the suggestion!

WyattAutomation commented 6 months ago

Hey there, sorry to bug you again, but I have a couple of questions on this before I start another round of coding.

I spent several days on trial and error, using Claude 3, GPT-4 and a couple of local LLMs as an on-the-fly set of documentation. Over those days I learned a great deal about TensorRT, ControlNet/UNet, ONNX GraphSurgeon, ONNX, dynamic inputs for TRT and a lot more. So thank you again for sharing enough breadcrumbs to get me started. I am now creating my own diffusers pipelines, TensorRT models, engines and build/export code, and learning to implement a whole ton of useful utilities in my own code outside of this.

Anyway, I finally figured out how to add everything needed to the dynamic inputs for the UNet model that StreamDiffusion produces. It now builds and receives/sends inputs and outputs. That was a bit hard to figure out, but now I know.

The issue that remains is one I pretty much expected. The output images are either just latent noise or are missing whole blocks of data (partial NaNs). Sometimes they are whole images that convert to PIL properly, but even then it looks like the timesteps or something weren't handled correctly.

There is a really good chance I am wrong about this. My ControlNetModel and StableDiffusionControlNetImg2ImgPipeline instances are instantiated outside the Engine classes that the UNet and VAEs use to call .infer(). Is that why I am having I/O issues? My modified UNet engine builds fine, with fully functional inputs and outputs of the right size and dimensionality.

In summary, I have the additional 12 inputs/outputs for the down and mid residuals done, and the UNet finally accepts them as input. But now I assume I probably need to create a ControlNet model, engine and build/optimize functions/methods, and integrate them too?

Is it reasonable to suspect that I just need to go through all the same steps? That would mean adding a new "CNet" class in models.py, a "CNetEngine" and a "compile_controlnet" method for wrapper.py to call. It would also mean adjusting the init code and everywhere else that needs to instantiate or interact with data going into ControlNet.

The get_input_profile, dynamic axes, etc. for a separate TensorRT ControlNet outside the UNet should be easy to make now. NVIDIA's SD WebUI plugin demonstrates this, albeit without any comments or an easy way to run it outside SD WebUI.
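This is for reference when declaring those extra engine inputs. With the standard SD1.5 UNet layout, a ControlNet returns 12 down-block residuals plus one mid-block residual. The hypothetical helper below lists the shapes those inputs need in a profile or shape dict; it assumes SD1.5 channel widths and image sizes divisible by 64. The tensor names are up to you.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int, int]

# (channels, downscale factor relative to the latent) for each down-block residual,
# in the order the UNet consumes them: conv_in, then two resnets and one downsampler
# per down block; the last down block has no downsampler.
SD15_DOWN_RESIDUALS = [
    (320, 1), (320, 1), (320, 1), (320, 2),
    (640, 2), (640, 2), (640, 4),
    (1280, 4), (1280, 4), (1280, 8),
    (1280, 8), (1280, 8),
]
SD15_MID_RESIDUAL = (1280, 8)


def controlnet_residual_shapes(batch: int, image_height: int,
                               image_width: int) -> Tuple[List[Shape], Shape]:
    """Shapes of the down and mid residuals for an SD1.5-style ControlNet."""
    latent_h, latent_w = image_height // 8, image_width // 8
    down = [(batch, c, latent_h // f, latent_w // f) for c, f in SD15_DOWN_RESIDUALS]
    c, f = SD15_MID_RESIDUAL
    return down, (batch, c, latent_h // f, latent_w // f)


down, mid = controlnet_residual_shapes(batch=2, image_height=512, image_width=512)
print(len(down), down[0], down[-1], mid)
```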

So my three questions are:

1) Is the right approach to make a new ControlNet engine in the StreamDiffusion code, separate from the UNet? It would have its own TRT I/O and optimization, with a ".forward" method called from its associated TRT ".infer()".

2) Is the right approach to build the ControlNet model directly into the UNet somehow?

Or

3) Should it technically be possible to pass the mid and down tensors from a normally loaded, unoptimized ControlNet outside TRT to the TRT-optimized UNet loaded from ONNX, and I am just overlooking something silly/easy?

Also, one other item: I assume the ControlNet pipeline, as well as the ControlNetModel I am using from diffusers, simply needs to be instantiated directly inside whatever Engine class I designate as ControlNetEngine?

Thank you again!

WyattAutomation commented 6 months ago

If I have to add all of the ControlNet code, at least the UNet effort is still usable. Even if I adapt a new set of ControlNet classes/utilities, I assume I still need the UNet I/O working too, and that part is done now.

menguzat commented 5 months ago

Hey there, any developments on this?

By the way, I'd really love to talk a bit about TensorRT and all it entails, @WyattAutomation.

Did you try to build one for SDXL? I've been trying, but I haven't gone down the TensorRT rabbit hole like you did yet, so maybe you have some pointers?

I managed to get SDXL working with StreamDiffusion (see issue https://github.com/cumulo-autumn/StreamDiffusion/issues/114), and someone created an SDXL fork (https://github.com/hkn-g/StreamDiffusion/tree/sdxl), but so far no luck with TensorRT SDXL.

WyattAutomation commented 5 months ago

I added all of the ControlNet code for TensorRT and it finally compiles. The sizing for the dynamic inputs is all correct, and I did all of the painstaking work to get it integrated.

But for some reason the data refuses to flow through it correctly. I get big rectangular blocks of NaNs, and the parts of the image that are not NaNs are latent noise.

I have tried adjusting all parameters, and at this point I can only guess that something is broken with the image loading. The exact same settings work fine for ControlNet without TensorRT, so it really makes no sense to me.

I can try to create a fork and a test script sometime this weekend so you can see where I am at. It is quite frustrating as it took about a week just getting it to compile. I really wish TensorRT had better debugging built-in.
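Eager ControlNet works, and only the fp16 TensorRT path produces NaN blocks. One hypothesis worth ruling out is overflow, though this is an assumption rather than a diagnosis. float16 cannot represent magnitudes above 65504, and an inf produced inside the engine turns into NaN on later arithmetic such as inf - inf. The stdlib-only helper below (fits_fp16 is a hypothetical name) checks individual values. One example input would be the abs().max().item() of each residual tensor taken from an fp32 eager run.

```python
import math
import struct


def fits_fp16(value: float) -> bool:
    """True if `value` is finite and stays finite after rounding to IEEE half precision."""
    if math.isnan(value) or math.isinf(value):
        return False
    try:
        packed = struct.pack("<e", value)  # "e" = 16-bit half-precision float
    except OverflowError:  # magnitude too large for float16
        return False
    return not math.isinf(struct.unpack("<e", packed)[0])


# The largest finite float16 is 65504.
for v in (1.0, 65504.0, 70000.0, float("nan")):
    print(v, fits_fp16(v))
```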

zjysteven commented 5 months ago

@WyattAutomation Hey, I'm interested in integrating ControlNet into StreamDiffusion too. Did you finally manage to get anything working? If so, would you mind sharing your code in your fork?

Liuqh12 commented 5 months ago

+1

zjysteven commented 5 months ago

I managed to wrap ControlNet with TensorRT, based on this thread and TensorRT's official implementation. The idea is to combine the UNet and ControlNet into a single custom model:

import torch


class UNet2DConditionControlNetModel(torch.nn.Module):
    def __init__(self, unet, controlnet) -> None:
        super().__init__()
        self.unet = unet
        self.controlnet = controlnet

    def forward(self, sample, timestep, encoder_hidden_states, image):
        # hard-coded since it is not clear how to integrate this argument into tensorrt
        conditioning_scale = 1.0

        down_samples, mid_sample = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=image,
            guess_mode=False,
            return_dict=False,
        )

        down_block_res_samples = [
            down_sample * conditioning_scale
            for down_sample in down_samples
        ]
        mid_block_res_sample = conditioning_scale * mid_sample

        noise_pred = self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
            return_dict=False,
        )[0]  # the UNet returns a (sample,) tuple when return_dict=False
        return noise_pred

Then we add a new model class and engine for this custom model to the TensorRT code. Specifically, src/streamdiffusion/acceleration/tensorrt/models.py gets:

class UNetControlNet(BaseModel):
    def __init__(
        self,
        fp16=False,
        device="cuda",
        max_batch_size=16,
        min_batch_size=1,
        embedding_dim=768,
        text_maxlen=77,
        unet_dim=4,
    ):
        super(UNetControlNet, self).__init__(
            fp16=fp16,
            device=device,
            max_batch_size=max_batch_size,
            min_batch_size=min_batch_size,
            embedding_dim=embedding_dim,
            text_maxlen=text_maxlen,
        )
        self.unet_dim = unet_dim
        self.name = "UNetControlNet"

    def get_input_names(self):
        return ["sample", "timestep", "encoder_hidden_states", "image"]

    def get_output_names(self):
        return ["latent"]

    def get_dynamic_axes(self):
        return {
            "sample": {0: "2B", 2: "H", 3: "W"},
            "timestep": {0: "2B"},
            "encoder_hidden_states": {0: "2B"},
            "image": {0: "2B", 2: '8H', 3: '8W'},
            "latent": {0: "2B", 2: "H", 3: "W"},
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        (
            min_batch,
            max_batch,
            min_image_height, 
            max_image_height, 
            min_image_width, 
            max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        return {
            "sample": [
                (min_batch, self.unet_dim, min_latent_height, min_latent_width),
                (batch_size, self.unet_dim, latent_height, latent_width),
                (max_batch, self.unet_dim, max_latent_height, max_latent_width),
            ],
            "timestep": [(min_batch,), (batch_size,), (max_batch,)],
            "encoder_hidden_states": [
                (min_batch, self.text_maxlen, self.embedding_dim),
                (batch_size, self.text_maxlen, self.embedding_dim),
                (max_batch, self.text_maxlen, self.embedding_dim),
            ],
            'image': [(min_batch, 3, min_image_height, min_image_width),
                       (batch_size, 3, image_height, image_width),
                       (max_batch, 3, max_image_height, max_image_width)]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
            "timestep": (2 * batch_size,),
            "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
            "latent": (2 * batch_size, 4, latent_height, latent_width),
            "image": (2 * batch_size, 3, image_height, image_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        dtype = torch.float16 if self.fp16 else torch.float32
        return (
            torch.randn(
                2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
            ),
            torch.ones((2 * batch_size,), dtype=torch.float32, device=self.device),
            torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
            torch.randn(2 * batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device),
        )
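The axis names in get_dynamic_axes tie the inputs together. Every "2B" axis must carry the same batch size, and the image's "8H"/"8W" axes are 8x the latent H/W. One consequence is worth checking (a hedged observation, not tested here). A control image prepared with batch_size=1, as in the first post, would presumably have to be repeated to the sample's batch before calling the engine, whereas eager PyTorch appears to simply broadcast it. Below is a small hypothetical helper to validate a shape dict before calling the engine.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def check_unet_controlnet_shapes(shapes: Dict[str, Shape]) -> None:
    """Raise ValueError if shapes break the axis relations declared in get_dynamic_axes."""
    batch, _, h, w = shapes["sample"]
    if tuple(shapes["timestep"]) != (batch,):
        raise ValueError(f"timestep {shapes['timestep']} should be ({batch},)")
    if shapes["encoder_hidden_states"][0] != batch:
        raise ValueError("encoder_hidden_states batch does not match sample")
    expected_image = (batch, 3, 8 * h, 8 * w)
    if tuple(shapes["image"]) != expected_image:
        raise ValueError(f"image {shapes['image']} should be {expected_image}")
    if tuple(shapes["latent"]) != tuple(shapes["sample"]):
        raise ValueError("latent output should have the same shape as sample")


# Shapes as produced by get_shape_dict(batch_size=1, image_height=512, image_width=512).
ok = {
    "sample": (2, 4, 64, 64),
    "timestep": (2,),
    "encoder_hidden_states": (2, 77, 768),
    "latent": (2, 4, 64, 64),
    "image": (2, 3, 512, 512),
}
check_unet_controlnet_shapes(ok)  # passes silently
```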

And src/streamdiffusion/acceleration/tensorrt/engine.py gets:

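# Assumes the names already imported in engine.py for the existing
# UNet2DConditionModelEngine (torch, Any, cuda, Engine, UNet2DConditionOutput).
class UNet2DConditionControlNetModelEngine: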
    def __init__(
        self, 
        filepath: str, 
        stream: cuda.Stream, 
        use_cuda_graph: bool = False
    ):
        self.engine = Engine(filepath)
        self.stream = stream
        self.use_cuda_graph = use_cuda_graph

        self.engine.load()
        self.engine.activate()

    def __call__(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        image,
        **kwargs,
    ) -> Any:
        if timestep.dtype != torch.float32:
            timestep = timestep.float()

        self.engine.allocate_buffers(
            shape_dict={
                "sample": latent_model_input.shape,
                "timestep": timestep.shape,
                "encoder_hidden_states": encoder_hidden_states.shape,
                "image": image.shape,
                "latent": latent_model_input.shape,
            },
            device=latent_model_input.device,
        )

        noise_pred = self.engine.infer(
            {
                "sample": latent_model_input,
                "timestep": timestep,
                "encoder_hidden_states": encoder_hidden_states,
                "image": image,
            },
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["latent"]
        return UNet2DConditionOutput(sample=noise_pred)

    def to(self, *args, **kwargs):
        pass

    def forward(self, *args, **kwargs):
        pass

Both are adapted from the existing UNet and UNet2DConditionModelEngine classes.

WyattAutomation commented 5 months ago

I am going to try this out soon. I ended up going with onediff's inference compiler over TRT, but I have been wanting to test TRT's performance nonetheless.

Thanks for sharing!

WyattAutomation commented 5 months ago

What kind of performance do you get with this? I was able to get a onediff-compiled pipeline running at ~18 FPS with LCM LoRA, Tiny VAE, 4 steps (DreamShaper7_LCM, an SD 1.5 model), and a custom prompt encoder set up to only encode the prompt when it changes. That is using StableDiffusionControlNetImg2ImgPipeline, compiled with onediff's inference compiler. I get about 14 FPS doing the same but with img2img ControlNet.
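The "only encode the prompt when it changes" trick can be sketched as a small cache around whatever encoder you use. This is a minimal sketch, not WyattAutomation's actual code. `encode_fn` is a hypothetical stand-in for your pipeline's prompt-encoding call.

```python
from typing import Callable, Generic, Optional, TypeVar

E = TypeVar("E")  # whatever the encoder returns, e.g. a tensor of prompt embeddings


class CachedPromptEncoder(Generic[E]):
    """Single-entry cache: re-runs encode_fn only when the prompt text changes."""

    def __init__(self, encode_fn: Callable[[str], E]) -> None:
        self._encode_fn = encode_fn  # hypothetical stand-in for the real text-encoder call
        self._last_prompt: Optional[str] = None
        self._last_embeds: Optional[E] = None

    def __call__(self, prompt: str) -> E:
        if prompt != self._last_prompt:
            # Encode first, then record the prompt, so a failed encode is retried next frame.
            self._last_embeds = self._encode_fn(prompt)
            self._last_prompt = prompt
        return self._last_embeds  # type: ignore[return-value]
```

Call `encoder(prompt)` every frame. The text encoder only runs on frames where the prompt actually changed. Keeping a single entry, rather than using something like `functools.lru_cache`, means only one set of embeddings stays alive on the GPU.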

zjysteven commented 5 months ago

What GPU are you using? My setup is SD 1.5, 4-step LCM sampling, controlnet_tile img2img, 512x512 generation, fp16, and Tiny VAE. On an A100 I get ~22 FPS with TensorRT.
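To compare the figures quoted so far, it can help to convert frame rate into per-frame latency. These numbers come from different setups, so this is not a like-for-like comparison:

```latex
t_{\text{frame}} = \frac{1000\,\text{ms}}{\text{FPS}}
\qquad
\frac{1000}{22} \approx 45.5\,\text{ms},\quad
\frac{1000}{18} \approx 55.6\,\text{ms},\quad
\frac{1000}{14} \approx 71.4\,\text{ms}
```

The gap between ~22 FPS and ~18 FPS is roughly 10 ms per frame.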

WyattAutomation commented 5 months ago

An A100 is expected to outperform, so I think 22 FPS is too close to say TRT is faster. I can benchmark on my machine soon and see if it holds up, though.

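For a like-for-like comparison, both pipelines could be timed with the same harness. The sketch below is minimal and makes some assumptions. `step` is a hypothetical zero-argument callable that produces one frame. Because CUDA work runs asynchronously, an optional `sync` callback (for example `torch.cuda.synchronize`) is called before the clock is read.

```python
import time
from typing import Callable, Optional


def measure_fps(
    step: Callable[[], object],
    n_frames: int = 100,
    warmup: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Average frames per second of `step` over n_frames timed calls."""
    if n_frames <= 0:
        raise ValueError("n_frames must be positive")
    for _ in range(warmup):  # untimed: engine warm-up, allocator growth, etc.
        step()
    if sync is not None:
        sync()  # make sure warm-up work has finished before starting the clock
    start = time.perf_counter()
    for _ in range(n_frames):
        step()
    if sync is not None:
        sync()  # wait for queued GPU work before stopping the clock
    elapsed = time.perf_counter() - start
    return n_frames / max(elapsed, 1e-9)
```

Usage would look like `measure_fps(lambda: stream(input_image), sync=torch.cuda.synchronize)`, where `stream` and `input_image` are placeholders for your own pipeline and input.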
olegchomp commented 3 months ago

@zjysteven Hi, I found your ControlNet fork and tried to test it, but profile_controlnet.py with the config from profile.sh always fails with

Failed to parse ONNX model. Does the model file exist and contain a valid ONNX model?

Also, I'm not sure it should look like that.

[image not included]

Can you suggest how to fix it, based on your repo?
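The error message names two things to check: that the file exists and that it is a valid ONNX model. Before rebuilding the engine, it may be worth checking the exported file directly. This is a minimal sketch, not part of the fork. It assumes the `onnx` package is installed and skips the structural check if it is not. The path is whatever ONNX file the profiling script exported, which is not shown here.

```python
import os
from typing import List


def diagnose_onnx(path: str) -> List[str]:
    """Return problems (or skipped checks) for an exported ONNX file; empty list = none found."""
    if not os.path.isfile(path):
        return [f"{path} does not exist"]
    if os.path.getsize(path) == 0:
        return [f"{path} is empty (the export probably failed)"]
    try:
        import onnx  # optional dependency
    except ImportError:
        return ["onnx is not installed, so the model structure was not checked"]
    try:
        # Pass the path rather than a loaded ModelProto: per the onnx docs,
        # models larger than 2 GB must be checked by path.
        onnx.checker.check_model(path)
    except Exception as exc:  # ValidationError, protobuf decode errors, ...
        return [f"onnx.checker rejected the model: {exc}"]
    return []
```

For example, `print(diagnose_onnx("path/to/exported_model.onnx"))` prints an empty list if the file exists and passes the checker. The path here is a placeholder.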

menguzat commented 3 months ago

I think @dotsimulate had success with controlnets in the latest streamdiffusiontd tox

olegchomp commented 3 months ago

@menguzat No, he didn't use an accelerated ControlNet.

zjysteven commented 3 months ago

@olegchomp Not sure if I hit the same bug, but definitely try the setup section in my README for installing dependencies. In my experience, it was super sensitive to dependency versions.
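Since the setup seems sensitive to dependency versions, a quick first step is to print what is installed and compare it against the versions listed in the fork's README. The package list below is an assumption about what matters for the TensorRT path. Adjust it to match the README.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumption: the distributions most likely to matter for the TensorRT path.
DEFAULT_PACKAGES = (
    "torch", "diffusers", "transformers", "tensorrt",
    "onnx", "polygraphy", "onnx-graphsurgeon",
)


def installed_versions(packages: Iterable[str] = DEFAULT_PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name:<20} {version or 'not installed'}")
```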

dotsimulate commented 3 months ago

I haven't gotten ControlNets running with TensorRT. I hope to get that working next week, if I have time.

zjysteven commented 3 months ago

@dotsimulate I have a version working with TensorRT in my fork. Feel free to check it out and let me know if it makes sense (you seem experienced with this, so I would appreciate any comments).

olegchomp commented 3 months ago

@zjysteven Did you test other ControlNets? Something like canny or depth looks broken. It ignores the prompt with any weight provided, and almost every time the image looks like unguided img2img. I tried to reverse engineer your code and can't find what causes these issues.
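"Weight" here may refer to the ControlNet conditioning scale. If so, note that the UNet+ControlNet wrapper posted earlier in this thread hard-codes `conditioning_scale = 1.0` inside `forward` (the fork may differ). A weight set at runtime would therefore not reach the compiled engine. Below is a minimal NumPy sketch of one alternative: treat the scale as a per-sample input of shape (B,) and broadcast it over the residuals. In the torch wrapper, this would mean adding a hypothetical `conditioning_scale` engine input, plus matching entries in `get_input_names`, the dynamic axes, and the engine's input dict. That change is an untested assumption, not a confirmed fix.

```python
from typing import List, Tuple

import numpy as np


def scale_controlnet_residuals(
    down_samples: List[np.ndarray],
    mid_sample: np.ndarray,
    conditioning_scale: np.ndarray,
) -> Tuple[List[np.ndarray], np.ndarray]:
    """Scale ControlNet residuals by a per-sample weight of shape (B,) or (1,)."""
    # (B,) -> (B, 1, 1, 1) so it broadcasts over (B, C, H, W) residuals
    scale = np.asarray(conditioning_scale, dtype=mid_sample.dtype).reshape(-1, 1, 1, 1)
    scaled_down = [down * scale for down in down_samples]
    scaled_mid = mid_sample * scale
    return scaled_down, scaled_mid
```

In torch, the equivalent reshape would be `conditioning_scale.view(-1, 1, 1, 1)`. Passing a scale of 0 should then reproduce plain img2img, which is a quick way to check that the weight is reaching the engine.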

zjysteven commented 3 months ago

@olegchomp No, I didn't, since my use case is exclusively ControlNet tile. It's weird, though, as it should work for any ControlNet.

ThomasLengeling commented 2 months ago

@zjysteven Thank you for the repo! I tested your version of the ControlNet, and I'm also getting weird results using the video/URL example: if I change the prompt or other parameters, the resulting image does not change, and there is a very weird pixel glitch.

zjysteven commented 2 months ago

@ThomasLengeling Hi, like I mentioned above, although conceptually I don't see any reason why the code wouldn't work across different ControlNets, I have only tested/used it with ControlNet tile and our specific use cases.

JetSimon commented 2 months ago

@olegchomp You may have been dumb like me and thought that you needed to make the canny image yourself. I think StableDiffusionControlNetImg2ImgPipeline actually does that part for you, unlike StableDiffusionControlNetPipeline, which is what the canny ControlNet docs use.
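You may prepare the canny map yourself; the diffusers canny ControlNet docs use `cv2.Canny(image, 100, 200)` and stack the result into three channels. If so, the remaining step for the UNetControlNet engine's `image` input is converting the map into a `(batch, 3, H, W)` float array at pixel resolution. This is a minimal NumPy sketch. It assumes the control image should be scaled to [0, 1] rather than [-1, 1], which, as far as I know, matches how diffusers' ControlNet pipelines preprocess the control image.

```python
import numpy as np


def edge_map_to_control_image(edges: np.ndarray, batch_size: int = 1) -> np.ndarray:
    """(H, W) uint8 edge map -> (batch_size, 3, H, W) float32 array in [0, 1]."""
    if edges.ndim != 2:
        raise ValueError("expected a single-channel (H, W) edge map")
    # Replicate the single channel into 3 channels, channel-first, and scale to [0, 1].
    rgb = np.repeat(edges[None, :, :], 3, axis=0).astype(np.float32) / 255.0
    # Add a batch dimension and repeat it batch_size times.
    return np.repeat(rgb[None], batch_size, axis=0)
```

For example, you could start from `edges = cv2.Canny(np.array(pil_image), 100, 200)`. Then move the result to the GPU with `torch.from_numpy(x).to("cuda", dtype=torch.float16)`.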

olegchomp commented 2 months ago

I don't know what's happening. These are the results from 8 steps (4 steps gives the same results) with LCM and canny.

[Images not included: test, 5, 6, 7]