haofanwang / ControlNet-for-Diffusers

Transfer the ControlNet with any base model in diffusers 🔥
MIT License
800 stars · 47 forks

Question Regarding Multi Control #18

Open ghpkishore opened 1 year ago

ghpkishore commented 1 year ago

@haofanwang There are two pipelines, one for depth and one for openpose; let's call them pipe_control_depth and pipe_control_openpose.

Which pipeline should be used for generating the output? What does the input for the pipeline look like? Should the pipeline load both models? I'm not sure how it works.

ghpkishore commented 1 year ago

Should I do this? Create a new folder, say control_sd15_depth_scribble_inpaint, where I first copy control_sd15_depth, then replace the unet folder with the inpainting one, then add a folder called controlnet1 (as per your README) which will be the scribble ControlNet. Then change the pipeline file to load controlnet1 as a UNet2DConditionModel (basically the same as the normal ControlNet), initialize it from controlnet1, and do the remaining steps? Is that how it works?

haofanwang commented 1 year ago

It doesn't matter which one you use as the base model. If you take pipe_control_depth as the base, just take the control module out of pipe_control_openpose and add it to pipe_control_depth as a new module. Note that you need to modify the forward function in the pipeline: just as in the single-control pipeline, add the control weights and the second control hint as extra input params.
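
A minimal sketch of that idea, assuming both single-control pipelines are already loaded and expose their control module as .controlnet (the hint/weight parameter names below are the ones adopted later in this thread and are otherwise hypothetical):

import torch

# Assumption: pipe_control_depth and pipe_control_openpose are the two
# single-control pipelines mentioned above, each exposing `.controlnet`.

# Take the control module out of the openpose pipeline and attach it to the
# depth pipeline as a second module. A plain attribute works for a quick test;
# registering it in the pipeline's __init__ (as in the modified pipeline later
# in this thread) is cleaner, since it then survives save_pretrained()/.to().
pipe_control_depth.controlnet2 = pipe_control_openpose.controlnet
pipe_control_depth.controlnet2.to("cuda", dtype=torch.float16)

# The pipeline's __call__/forward then takes a second hint and per-control
# weights as extra parameters, e.g.:
# pipe_control_depth(prompt, controlnet_hint1=depth_map, controlnet_hint2=pose_map,
#                    control1_weight=1.0, control2_weight=0.5, ...)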

ghpkishore commented 1 year ago

I am facing an issue in loading the other controlnet in the pipeline. Can you tell me how to do it?

ghpkishore commented 1 year ago

This worked.

base_model_id = "models/control_sd15_scribble" 
controlnet2 = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="controlnet",torch_dtype=torch.float16).to("cuda")

pipe_control = StableDiffusionControlNetInpaintPipeline.from_pretrained("models/control_sd15_depth_scribble_inpaint",controlnet2=controlnet2,torch_dtype=torch.float16).to('cuda')
ghpkishore commented 1 year ago

So I was able to get multi-ControlNet working on my side. Here is how I did it.

First, I copied the unet from the inpainting model, used it to replace the unet of the control_sd15_depth model, and called the new folder control_sd15_depth_inpaint.
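
Roughly, that folder surgery looks like the following sketch (the source paths here are assumptions; only the models/control_sd15_depth_inpaint destination matches the snippet below):

import shutil

# Assumed local layout (hypothetical paths except the destination).
src_control = "models/control_sd15_depth"        # depth ControlNet checkpoint in diffusers layout
src_inpaint_unet = "models/sd15_inpaint/unet"    # unet subfolder of an SD 1.5 inpainting checkpoint
dst = "models/control_sd15_depth_inpaint"        # combined model loaded by the pipeline below

# 1. Start from the depth ControlNet model, keeping its `controlnet` subfolder.
shutil.copytree(src_control, dst)

# 2. Swap in the inpainting unet (it expects latent + mask + masked-image channels).
shutil.rmtree(f"{dst}/unet")
shutil.copytree(src_inpaint_unet, f"{dst}/unet")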

Then I updated the file "pipeline_stable_diffusion_controlnet_inpaint.py" to take two control inputs and their weights.

After that, I added controlnet2 to pipe_control and set the weights for the two controls. It is now working.

controlnet2_path= "models/control_sd15_scribble"  # 
controlnet2 = UNet2DConditionModel.from_pretrained(controlnet2_path, subfolder="controlnet").to("cuda")
pipe_control = StableDiffusionControlNetInpaintPipeline.from_pretrained("models/control_sd15_depth_inpaint",controlnet2=controlnet2,torch_dtype=torch.float16).to('cuda')
pipe_control.unet.in_channels = 4
pipe_control.enable_attention_slicing()
output_image  = pipe_control(prompt=prompt, 
                                negative_prompt="human, hands, fingers, legs, body parts",
                                image=image,
                                mask_image=mask,
                                controlnet_hint1=control_image_1, 
                                controlnet_hint2=control_image_2, 
                                control1_weight=1,
                                control2_weight=0.2,
                                height=height,
                                width=width,
                                generator=generator,
                                num_inference_steps=100).images[0]
ghpkishore commented 1 year ago

@haofanwang you can test it out, and if it works, you can add it to the repo. Sorry I didn't make a PR; I don't know how, and I felt it was better to share first so that others can give it a go as well.

haofanwang commented 1 year ago

Nice job!

ghpkishore commented 1 year ago

@haofanwang I sent you an email regarding the Composer repo from DAMO/ModelScope, where the files already seem to exist. Given that I do not know Chinese, I couldn't understand anything beyond the fact that the code does exist there.

Can you check it out and see whether it works and all the files are there? If so, we can make an unofficial repo of the paper and work on integrating it with diffusers.

I think you can check out : https://www.modelscope.cn/models/damo/cv_composer_multi-modal-image-synthesis/files

ghpkishore commented 1 year ago

@haofanwang The code above for multi-ControlNet was incorrect. I realized that I never called the second ControlNet, so I deleted it.

ghpkishore commented 1 year ago

This, on the other hand, works:

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import StableDiffusionControlNetPipeline
        >>> from diffusers.utils import load_image

        >>> # Canny edged image for control
        >>> canny_edged_image = load_image(
        ...     "https://huggingface.co/takuma104/controlnet_dev/resolve/main/vermeer_canny_edged.png"
        ... )
        >>> pipe = StableDiffusionControlNetPipeline.from_pretrained("takuma104/control_sd15_canny").to("cuda")
        >>> image = pipe(prompt="best quality, extremely detailed", controlnet_hint=canny_edged_image).images[0]
        ```
"""

def prepare_mask_and_masked_image(image, mask):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to torch.Tensor with shapes batch x channels x height x width, where channels is 3 for the image and 1 for
the mask. The image will be converted to torch.float32 and normalized to be in [-1, 1]. The mask will be binarized
(mask > 0.5) and cast to torch.float32 too.

Args:
    image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. It can be a PIL.Image, or a
        height x width x 3 np.array, or a channels x height x width torch.Tensor, or a
        batch x channels x height x width torch.Tensor.
    mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. the regions to inpaint.
        It can be a PIL.Image, or a height x width np.array, or a 1 x height x width torch.Tensor, or a
        batch x 1 x height x width torch.Tensor.

Raises:
    ValueError: torch.Tensor images should be in the [-1, 1] range.
    ValueError: torch.Tensor masks should be in the [0, 1] range.
    ValueError: mask and image should have the same spatial dimensions.
    TypeError: mask is a torch.Tensor but image is not (or the other way around).

Returns:
    tuple[torch.Tensor]: The pair (mask, masked_image) as torch.Tensor with 4 dimensions:
        batch x channels x height x width.
"""
if isinstance(image, torch.Tensor):
    if not isinstance(mask, torch.Tensor):
        raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")

    # Batch single image
    if image.ndim == 3:
        assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
        image = image.unsqueeze(0)

    # Batch and add channel dim for single mask
    if mask.ndim == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)

    # Batch single mask or add channel dim
    if mask.ndim == 3:
        # Single batched mask, no channel dim or single mask not batched but channel dim
        if mask.shape[0] == 1:
            mask = mask.unsqueeze(0)

        # Batched masks no channel dim
        else:
            mask = mask.unsqueeze(1)

    assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
    assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
    assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

    # Check image is in [-1, 1]
    if image.min() < -1 or image.max() > 1:
        raise ValueError("Image should be in [-1, 1] range")

    # Check mask is in [0, 1]
    if mask.min() < 0 or mask.max() > 1:
        raise ValueError("Mask should be in [0, 1] range")

    # Binarize mask
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1

    # Image as float32
    image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
    raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
    # preprocess image
    if isinstance(image, (PIL.Image.Image, np.ndarray)):
        image = [image]

    if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
        image = [np.array(i.convert("RGB"))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
    elif isinstance(image, list) and isinstance(image[0], np.ndarray):
        image = np.concatenate([i[None, :] for i in image], axis=0)

    image = image.transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    # preprocess mask
    if isinstance(mask, (PIL.Image.Image, np.ndarray)):
        mask = [mask]

    if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
        mask = mask.astype(np.float32) / 255.0
    elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
        mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

masked_image = image * (mask < 0.5)

return mask, masked_image

class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.

This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

Args:
    vae ([`AutoencoderKL`]):
        Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
    text_encoder ([`CLIPTextModel`]):
        Frozen text-encoder. Stable Diffusion uses the text portion of
        [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
        the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
    tokenizer (`CLIPTokenizer`):
        Tokenizer of class
        [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
    unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
    controlnet ([`UNet2DConditionModel`]):
        [ControlNet](https://arxiv.org/abs/2302.05543) architecture to generate guidance.
    scheduler ([`SchedulerMixin`]):
        A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
        [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    safety_checker ([`StableDiffusionSafetyChecker`]):
        Classification module that estimates whether generated images could be considered offensive or harmful.
        Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
    feature_extractor ([`CLIPFeatureExtractor`]):
        Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""

def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: UNet2DConditionModel,
    controlnet2: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPFeatureExtractor,
    requires_safety_checker: bool = False,
):
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        controlnet2=controlnet2,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_vae_slicing(self):
    r"""
    Enable sliced VAE decoding.

    When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
    steps. This is useful to save some memory and allow larger batch sizes.
    """
    self.vae.enable_slicing()

def disable_vae_slicing(self):
    r"""
    Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
    computing decoding in one step.
    """
    self.vae.disable_slicing()

def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
    text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
    `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
    """
    if is_accelerate_available():
        from accelerate import cpu_offload
    else:
        raise ImportError("Please install accelerate via `pip install accelerate`")

    device = torch.device(f"cuda:{gpu_id}")

    for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
        cpu_offload(cpu_offloaded_model, device)

    if self.safety_checker is not None:
        cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

@property
def _execution_device(self):
    r"""
    Returns the device on which the pipeline's models will be executed. After calling
    `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
    hooks.
    """
    if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
        return self.device
    for module in self.unet.modules():
        if (
            hasattr(module, "_hf_hook")
            and hasattr(module._hf_hook, "execution_device")
            and module._hf_hook.execution_device is not None
        ):
            return torch.device(module._hf_hook.execution_device)
    return self.device

def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
         prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
    """
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds

def run_safety_checker(self, image, device, dtype):
    if self.safety_checker is not None:
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    else:
        has_nsfw_concept = None
    return image, has_nsfw_concept

def decode_latents(self, latents):
    latents = 1 / self.vae.config.scaling_factor * latents
    image = self.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image

def prepare_extra_step_kwargs(self, generator, eta):
    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]

    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # check if the scheduler accepts generator
    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
    if accepts_generator:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs


def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if (callback_steps is None) or (
        callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
    ):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents

def controlnet_hint_conversion(self, controlnet_hint, height, width, num_images_per_prompt):
    channels = 3
    if isinstance(controlnet_hint, torch.Tensor):
        # torch.Tensor: acceptable shapes are chw, bchw (b==1) or bchw (b==num_images_per_prompt)
        shape_chw = (channels, height, width)
        shape_bchw = (1, channels, height, width)
        shape_nchw = (num_images_per_prompt, channels, height, width)
        if controlnet_hint.shape in [shape_chw, shape_bchw, shape_nchw]:
            controlnet_hint = controlnet_hint.to(dtype=self.controlnet.dtype, device=self.controlnet.device)
            if controlnet_hint.shape != shape_nchw:
                controlnet_hint = controlnet_hint.repeat(num_images_per_prompt, 1, 1, 1)
            return controlnet_hint
        else:
            raise ValueError(
                f"Acceptble shape of `controlnet_hint` are any of ({channels}, {height}, {width}),"
                + f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
                + f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
            )
    elif isinstance(controlnet_hint, np.ndarray):
        # np.ndarray: acceptable shapes are hw, hwc, bhwc (b==1) or bhwc (b==num_images_per_prompt)
        # hwc is opencv compatible image format. Color channel must be BGR Format.
        if controlnet_hint.shape == (height, width):
            controlnet_hint = np.repeat(controlnet_hint[:, :, np.newaxis], channels, axis=2)  # hw -> hwc(c==3)
        shape_hwc = (height, width, channels)
        shape_bhwc = (1, height, width, channels)
        shape_nhwc = (num_images_per_prompt, height, width, channels)
        if controlnet_hint.shape in [shape_hwc, shape_bhwc, shape_nhwc]:
            controlnet_hint = torch.from_numpy(controlnet_hint.copy())
            controlnet_hint = controlnet_hint.to(dtype=self.controlnet.dtype, device=self.controlnet.device)
            controlnet_hint /= 255.0
            if controlnet_hint.shape != shape_nhwc:
                controlnet_hint = controlnet_hint.repeat(num_images_per_prompt, 1, 1, 1)
            controlnet_hint = controlnet_hint.permute(0, 3, 1, 2)  # b h w c -> b c h w
            return controlnet_hint
        else:
            raise ValueError(
                f"Acceptble shape of `controlnet_hint` are any of ({width}, {channels}), "
                + f"({height}, {width}, {channels}), "
                + f"(1, {height}, {width}, {channels}) or "
                + f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
            )
    elif isinstance(controlnet_hint, PIL.Image.Image):
        if controlnet_hint.size == (width, height):
            controlnet_hint = controlnet_hint.convert("RGB")  # make sure 3 channel RGB format
            controlnet_hint = np.array(controlnet_hint)  # to numpy
            controlnet_hint = controlnet_hint[:, :, ::-1]  # RGB -> BGR
            return self.controlnet_hint_conversion(controlnet_hint, height, width, num_images_per_prompt)
        else:
            raise ValueError(
                f"Acceptable image size of `controlnet_hint` is ({width}, {height}) but is {controlnet_hint.size}"
            )
    else:
        raise ValueError(
            f"Acceptable type of `controlnet_hint` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
        )
def controlnet_hint_conversion2(self, controlnet_hint2, height, width, num_images_per_prompt):
    channels = 3
    if isinstance(controlnet_hint2, torch.Tensor):
        # torch.Tensor: acceptable shapes are chw, bchw (b==1) or bchw (b==num_images_per_prompt)
        shape_chw = (channels, height, width)
        shape_bchw = (1, channels, height, width)
        shape_nchw = (num_images_per_prompt, channels, height, width)
        if controlnet_hint2.shape in [shape_chw, shape_bchw, shape_nchw]:
            controlnet_hint2 = controlnet_hint2.to(dtype=self.controlnet2.dtype, device=self.controlnet2.device)
            if controlnet_hint2.shape != shape_nchw:
                controlnet_hint2 = controlnet_hint2.repeat(num_images_per_prompt, 1, 1, 1)
            return controlnet_hint2
        else:
            raise ValueError(
                f"Acceptble shape of `controlnet_hint` are any of ({channels}, {height}, {width}),"
                + f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
                + f"{channels}, {height}, {width}) but is {controlnet_hint2.shape}"
            )
    elif isinstance(controlnet_hint2, np.ndarray):
        # np.ndarray: acceptable shapes are hw, hwc, bhwc (b==1) or bhwc (b==num_images_per_prompt)
        # hwc is opencv compatible image format. Color channel must be BGR Format.
        if controlnet_hint2.shape == (height, width):
            controlnet_hint2 = np.repeat(controlnet_hint2[:, :, np.newaxis], channels, axis=2)  # hw -> hwc(c==3)
        shape_hwc = (height, width, channels)
        shape_bhwc = (1, height, width, channels)
        shape_nhwc = (num_images_per_prompt, height, width, channels)
        if controlnet_hint2.shape in [shape_hwc, shape_bhwc, shape_nhwc]:
            controlnet_hint2 = torch.from_numpy(controlnet_hint2.copy())
            controlnet_hint2 = controlnet_hint2.to(dtype=self.controlnet2.dtype, device=self.controlnet2.device)
            controlnet_hint2 /= 255.0
            if controlnet_hint2.shape != shape_nhwc:
                controlnet_hint2 = controlnet_hint2.repeat(num_images_per_prompt, 1, 1, 1)
            controlnet_hint2 = controlnet_hint2.permute(0, 3, 1, 2)  # b h w c -> b c h w
            return controlnet_hint2
        else:
            raise ValueError(
                f"Acceptble shape of `controlnet_hint` are any of ({width}, {channels}), "
                + f"({height}, {width}, {channels}), "
                + f"(1, {height}, {width}, {channels}) or "
                + f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint2.shape}"
            )
    elif isinstance(controlnet_hint2, PIL.Image.Image):
        if controlnet_hint2.size == (width, height):
            controlnet_hint2 = controlnet_hint2.convert("RGB")  # make sure 3 channel RGB format
            controlnet_hint2 = np.array(controlnet_hint2)  # to numpy
            controlnet_hint2 = controlnet_hint2[:, :, ::-1]  # RGB -> BGR
            return self.controlnet_hint_conversion2(controlnet_hint2, height, width, num_images_per_prompt)
        else:
            raise ValueError(
                f"Acceptable image size of `controlnet_hint` is ({width}, {height}) but is {controlnet_hint2.size}"
            )
    else:
        raise ValueError(
            f"Acceptable type of `controlnet_hint` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint2)}"
        )

def prepare_mask_latents(
    self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
    # resize the mask to latents shape as we concatenate the mask to the latents
    # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
    # and half precision
    mask = torch.nn.functional.interpolate(
        mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
    )
    mask = mask.to(device=device, dtype=dtype)

    masked_image = masked_image.to(device=device, dtype=dtype)

    # encode the mask image into latents space so we can concatenate it to the latents
    if isinstance(generator, list):
        masked_image_latents = [
            self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
            for i in range(batch_size)
        ]
        masked_image_latents = torch.cat(masked_image_latents, dim=0)
    else:
        masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
    masked_image_latents = self.vae.config.scaling_factor * masked_image_latents

    # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
    if mask.shape[0] < batch_size:
        if not batch_size % mask.shape[0] == 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
    if masked_image_latents.shape[0] < batch_size:
        if not batch_size % masked_image_latents.shape[0] == 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                " Make sure the number of images that you pass is divisible by the total requested batch size."
            )
        masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

    mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
    masked_image_latents = (
        torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
    )

    # aligning device to prevent device errors when concating it with the latent model input
    masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
    return mask, masked_image_latents

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_hint1: Optional[Union[torch.FloatTensor, np.ndarray, PIL.Image.Image]] = None,
    controlnet_hint2: Optional[Union[torch.FloatTensor, np.ndarray, PIL.Image.Image]] = None,
    image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    control1_weight: Optional[float] = 1.0,
    control2_weight: Optional[float] = 1.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will ge generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        controlnet_hint1 (`torch.FloatTensor`, `np.ndarray` or `PIL.Image.Image`, *optional*):
            Input hint for the first ControlNet (`controlnet`). A `torch.FloatTensor` is passed to the ControlNet as
            is; an `np.ndarray` is assumed to be an OpenCV-compatible (BGR) image; a `PIL.Image.Image` is also
            accepted. The size of all these types must correspond to the output image size.
        controlnet_hint2 (`torch.FloatTensor`, `np.ndarray` or `PIL.Image.Image`, *optional*):
            Input hint for the second ControlNet (`controlnet2`), in the same formats as `controlnet_hint1`.
        control1_weight (`float`, *optional*, defaults to 1.0) / control2_weight (`float`, *optional*, defaults to 1.0):
            Weights applied to the residuals of the first and second ControlNet before they are summed.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Control Embedding check & conversion
    if controlnet_hint1 is not None:
        controlnet_hint1 = self.controlnet_hint_conversion(controlnet_hint1, height, width, num_images_per_prompt)
    if controlnet_hint2 is not None:
        controlnet_hint2 = self.controlnet_hint_conversion2(controlnet_hint2, height, width, num_images_per_prompt)

    # 2. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # 3. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 4. Encode input prompt
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    mask, masked_image = prepare_mask_and_masked_image(image, mask_image)

    # 5. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 6. Prepare latent variables
    num_channels_latents = self.unet.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    mask, masked_image_latents = self.prepare_mask_latents(
        mask,
        masked_image,
        batch_size * num_images_per_prompt,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        do_classifier_free_guidance,
    )

    num_channels_mask = mask.shape[1]
    num_channels_masked_image = masked_image_latents.shape[1]

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            if controlnet_hint1 is not None or controlnet_hint2 is not None:

                merged_control = []

                if controlnet_hint1 is not None:
                    # ControlNet predict the noise residual
                    control1 = self.controlnet(
                        latent_model_input, t, encoder_hidden_states=prompt_embeds, controlnet_hint=controlnet_hint1
                    )

                if controlnet_hint2 is not None:    
                    control2 = self.controlnet2(
                        latent_model_input, t, encoder_hidden_states=prompt_embeds, controlnet_hint=controlnet_hint2
                    )

                if controlnet_hint1 is not None and controlnet_hint2 is not None:
                    # Merge the two sets of control residuals as a weighted sum.
                    # Use `j` so the outer timestep index `i` is not clobbered.
                    for j in range(len(control1)):
                        merged_control.append(control1_weight * control1[j] + control2_weight * control2[j])
                    control = merged_control

                elif controlnet_hint1 is not None and controlnet_hint2 is None:
                    control = control1

                elif controlnet_hint1 is None and controlnet_hint2 is not None:
                    control = control2

                control = [item for item in control]
                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    control=control,
                ).sample
            else:
                # predict the noise residual
                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if output_type == "latent":
        image = latents
        has_nsfw_concept = None
    elif output_type == "pil":
        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 10. Convert to PIL
        image = self.numpy_to_pil(image)
    else:
        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
ghpkishore commented 1 year ago

The way I checked it was: I kept all the inputs constant and tried with controlnet_hint1 while keeping controlnet_hint2 empty, and vice versa. I got different images both times, so my understanding is that multi-control is working correctly.
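
For reference, that check is roughly the following sketch (it reuses pipe_control, prompt, image, mask, and the two hint images from the snippet earlier in this thread, with a fixed seed so that only the hints differ):

import torch

def run(hint1=None, hint2=None, seed=0):
    # Identical inputs and seed; only the supplied control hint changes.
    g = torch.Generator(device="cuda").manual_seed(seed)
    return pipe_control(prompt=prompt, image=image, mask_image=mask,
                        controlnet_hint1=hint1, controlnet_hint2=hint2,
                        control1_weight=1.0, control2_weight=1.0,
                        height=height, width=width,
                        generator=g, num_inference_steps=50).images[0]

only_depth = run(hint1=control_image_1)      # second control disabled
only_scribble = run(hint2=control_image_2)   # first control disabled
# If the two outputs differ with everything else fixed, both control
# branches are actually being exercised.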

jiachen0212 commented 1 month ago

The way I checked it was: I kept all the inputs constant and tried with controlnet_hint1 while keeping controlnet_hint2 empty, and vice versa. I got different images both times, so my understanding is that multi-control is working correctly.

Did you develop this based on ControlNet (https://github.com/lllyasviel/ControlNet.git)? Is it OK to use two controls together?

jiachen0212 commented 1 month ago

The way I checked it was: I kept all the inputs constant and tried with controlnet_hint1 while keeping controlnet_hint2 empty, and vice versa. I got different images both times, so my understanding is that multi-control is working correctly.

Hello, I would like to ask about using this repo for multi-control. Do I directly use two control_xx_net.pth files and merge their results during inference, with no training needed? Also, when I run this repo, there seems to be an environment problem... Would you consider making a repo to show how to do multi-control training?