huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0
25.03k stars 5.17k forks source link

SDXL Image2Image Controlnet pipeline fails with cpu offload #7180

Closed odusseys closed 6 months ago

odusseys commented 6 months ago

Describe the bug

I am running SDXL-Lightning with a Canny-edge ControlNet on a T4 GPU (16 GB VRAM). I am using `enable_model_cpu_offload()` to reduce memory usage, but I run into the following error:

mat1 and mat2 must have the same dtype

My guess is that something goes wrong when some components are offloaded, leading to mismatched dtypes between weights and inputs. I only pass PIL images as inputs, so I would expect the pipeline to handle any type conversions.

Reproduction

import cv2
import numpy as np
import torch
from diffusers import (
    AutoPipelineForImage2Image,
    ControlNetModel,
    EulerDiscreteScheduler,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file

# Canny-edge ControlNet for SDXL, loaded in half precision like the pipeline.
controlnet_edges_model = "diffusers/controlnet-canny-sdxl-1.0"

# Selects which SDXL-Lightning UNet checkpoint is downloaded (see model_ckpt)
# and is also passed to the pipeline as the number of inference steps.
num_inference_steps = 2

controlnet_edges = ControlNetModel.from_pretrained(
    controlnet_edges_model,
    torch_dtype=torch.float16,
)

base_model = "stabilityai/stable-diffusion-xl-base-1.0"
model_repo = "ByteDance/SDXL-Lightning"
model_ckpt = f"sdxl_lightning_{num_inference_steps}step_unet.safetensors"

# Load model.
# The UNet is built from the base model's config and then filled with the
# Lightning weights. `from_config` is not given a `torch_dtype` here, so the
# UNet has to be cast to float16 explicitly. Without the cast it does not
# match the float16 components loaded below, and the UNet forward pass fails
# with "RuntimeError: mat1 and mat2 must have the same dtype" (this cast is
# the fix given in the maintainer's reply on this issue). Remove the
# `unet.to(dtype=torch.float16)` line to reproduce the original error.
unet = UNet2DConditionModel.from_config(base_model, subfolder="unet")
unet.load_state_dict(load_file(hf_hub_download(model_repo, model_ckpt)))
unet.to(dtype=torch.float16)

img2img_pipeline = AutoPipelineForImage2Image.from_pretrained(
    base_model,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
    controlnet=[controlnet_edges],
)
# Swap in an Euler scheduler that reuses the existing config, but with
# "trailing" timestep spacing.
img2img_pipeline.scheduler = EulerDiscreteScheduler.from_config(
    img2img_pipeline.scheduler.config, timestep_spacing="trailing"
)
img2img_pipeline.unet.to(memory_format=torch.channels_last)
# Keep submodules on the CPU until they are needed, to reduce GPU memory use
# (the report targets a 16 GB T4).
img2img_pipeline.enable_model_cpu_offload()

def compute_edges(image, min_thresh=100, max_thresh=200):
    """Return a three-channel PIL image holding the Canny edges of ``image``.

    ``image`` is converted to a NumPy array and passed to ``cv2.Canny`` with
    ``min_thresh`` and ``max_thresh`` as its two thresholds. The resulting
    single-channel edge map is copied into three identical channels so that it
    can be used as an RGB control image.
    """
    edge_map = cv2.Canny(np.array(image), min_thresh, max_thresh)
    rgb_edges = np.stack((edge_map, edge_map, edge_map), axis=-1)
    return Image.fromarray(rgb_edges)

# Input image: replace the empty path with a real image file.
image = Image.open("")  # PIL Image here, sorry for not being able to provide one for this snippet

IMAGE_SIZE = 1024
# Placeholder prompt: the original snippet used `prompt` without defining it.
prompt = ""  # your prompt here

with torch.no_grad():
    image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
    edges = compute_edges(image).resize((IMAGE_SIZE, IMAGE_SIZE))
    res = img2img_pipeline(
        prompt,
        image=image,
        # One entry per ControlNet, in the same order as
        # `controlnet=[controlnet_edges]` at pipeline construction.
        control_image=[edges],
        num_inference_steps=num_inference_steps,
        strength=0.9,
        guidance_scale=0.0,
        controlnet_conditioning_scale=[0.5],
    ).images[0]

Logs

RuntimeError                              Traceback (most recent call last)
<ipython-input-11-3fb3dc3b2103> in <cell line: 9>()
     20 
     21     # display(edges)
---> 22     denoised = denoise(prompt, rendered, edges, depth, strength=1.0)
     23     # remove background
     24     img = np.array(denoised.resize((RENDERED_IMAGE_SIZE, RENDERED_IMAGE_SIZE)))

9 frames
<ipython-input-5-de2165de7010> in denoise(prompt, image, edges, depth, strength)
     17     edges = edges.resize((IMAGE_SIZE, IMAGE_SIZE))
     18     prompt = "professional photograph of " + prompt + ", " + realistic
---> 19     res = img2img_pipeline(prompt,
     20                       negative_prompt=negative,
     21                       image=image,

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py in __call__(self, prompt, prompt_2, image, control_image, height, width, strength, num_inference_steps, guidance_scale, negative_prompt, negative_prompt_2, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, output_type, return_dict, cross_attention_kwargs, controlnet_conditioning_scale, guess_mode, control_guidance_start, control_guidance_end, original_size, crops_coords_top_left, target_size, negative_original_size, negative_crops_coords_top_left, negative_target_size, aesthetic_score, negative_aesthetic_score, clip_skip, callback_on_step_end, callback_on_step_end_tensor_inputs, **kwargs)
   1530 
   1531                 # predict the noise residual
-> 1532                 noise_pred = self.unet(
   1533                     latent_model_input,
   1534                     t,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
    164                 output = module._old_forward(*args, **kwargs)
    165         else:
--> 166             output = module._old_forward(*args, **kwargs)
    167         return module._hf_hook.post_forward(module, output)
    168 

/usr/local/lib/python3.10/dist-packages/diffusers/models/unets/unet_2d_condition.py in forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, down_intrablock_additional_residuals, encoder_attention_mask, return_dict)
   1150         # 1. time
   1151         t_emb = self.get_time_embed(sample=sample, timestep=timestep)
-> 1152         emb = self.time_embedding(t_emb, timestep_cond)
   1153         aug_emb = None
   1154 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/diffusers/models/embeddings.py in forward(self, sample, condition)
    226         if condition is not None:
    227             sample = sample + self.cond_proj(condition)
--> 228         sample = self.linear_1(sample)
    229 
    230         if self.act is not None:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    112 
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 must have the same dtype

System Info

Who can help?

@yiyixuxu @sayakpaul @DN6

sayakpaul commented 6 months ago

Got it working here: https://colab.research.google.com/gist/sayakpaul/0197712664287f2548844c7e90edf298/scratchpad.ipynb.

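Main fix: cast the UNet to `float16` after loading the SDXL-Lightning weights. The UNet is created with `UNet2DConditionModel.from_config` without a `torch_dtype`, so it does not match the rest of the float16 pipeline: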

unet.load_state_dict(load_file(hf_hub_download(model_repo, model_ckpt)))
+ unet.to(dtype=torch.float16)