huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Stable video diffusion inpainting pipeline implementation issues #5990

Closed caoandong closed 10 months ago

caoandong commented 11 months ago

Model/Pipeline/Scheduler description

Hi, thank you for integrating the Stable Video Diffusion pipeline. I tried to implement a simple inpainting pipeline, inspired by the legacy inpainting pipeline, but ran into an issue when adding noise to the inpainting latents.

In particular, the scheduler.add_noise function seems to add incorrect noise to the latents. I suspect the issue has something to do with the special noise schedule proposed in the SVD paper: in Section D.2, the authors modify the preconditioning functions and the distribution over the training noise levels. Do we need to modify the EulerDiscreteScheduler to accommodate this modification?
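For context, here is a minimal sketch of sigma-space noising as used by Euler-style samplers: the noisy sample is `x_0 + sigma * noise`, with no rescaling of the clean sample. The schedule is the Karras et al. rho-schedule with illustrative values; `sigma_min`, `sigma_max`, and `rho` are assumptions rather than a verified SVD configuration, and `karras_sigmas` and `add_noise_sigma` are hypothetical helpers, not the diffusers API.

```python
import torch


def karras_sigmas(n, sigma_min=0.002, sigma_max=700.0, rho=7.0):
    """Karras et al. noise schedule (descending) with a final 0 appended.

    The default values are illustrative assumptions.
    """
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat([sigmas, torch.zeros(1)])


def add_noise_sigma(x0, noise, sigma):
    """Variance-exploding noising: the clean sample is not rescaled."""
    return x0 + sigma * noise


torch.manual_seed(0)
sigmas = karras_sigmas(25)           # 25 noise levels plus the trailing 0
x0 = torch.randn(1, 4, 8, 8)         # stand-in for clean latents
noise = torch.randn_like(x0)
x_noisy = add_noise_sigma(x0, noise, sigmas[0])  # noised to the highest level
```

The point of the sketch is that the noise level attached to a sample is a single number `sigma`, so any latents pasted back into the sampling loop have to be noised to the same `sigma` the sampler currently assumes.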

The following is my rough implementation of the inpainting pipeline:


from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import PIL_INTERPOLATION, BaseOutput, logging
from ...utils.torch_utils import randn_tensor
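from ..pipeline_utils import DiffusionPipeline
# NOTE: `_resize_with_antialiasing` (used in `_encode_image`) is not defined in this snippet.
# In diffusers it lives in pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py.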

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def preprocess_image(image, batch_size):
    w, h = image.size
    w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0

def preprocess_mask(mask, batch_size, num_frames, scale_factor=8):
    if not isinstance(mask, torch.FloatTensor):
        mask = mask.convert("L")
        w, h = mask.size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
        mask = mask.resize((w // scale_factor, h // scale_factor),
                           resample=PIL_INTERPOLATION["lanczos"])
        mask = np.array(mask).astype(np.float32) / 255.0
        mask = np.tile(mask, (4, 1, 1))
        mask = np.vstack([mask[None]] * num_frames)
        mask = np.vstack([mask[None]] * batch_size)

        # mask = 1 - mask  # repaint white, keep black
        mask = torch.from_numpy(mask)

        target_shape = (batch_size, num_frames, 4, h //
                        scale_factor, w // scale_factor)

        if mask.shape != target_shape:
            raise ValueError(f"Mask shape {mask.shape} != {target_shape}")

        return mask

    else:
        valid_mask_channel_sizes = [1, 3]
        # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
        if mask.shape[3] in valid_mask_channel_sizes:
            mask = mask.permute(0, 3, 1, 2)
        elif mask.shape[1] not in valid_mask_channel_sizes:
            raise ValueError(
                f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
                f" but received mask of shape {tuple(mask.shape)}"
            )
        # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
        mask = mask.mean(dim=1, keepdim=True)
        h, w = mask.shape[-2:]
        h, w = (x - x % 8 for x in (h, w))  # resize to integer multiple of 8
        mask = torch.nn.functional.interpolate(
            mask, (h // scale_factor, w // scale_factor))
        return mask

def _append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]

def tensor2vid(video: torch.Tensor, processor, output_type="np"):
    # Based on:
    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78

    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)

        outputs.append(batch_output)

    return outputs

@dataclass
class StableVideoDiffusionPipelineOutput(BaseOutput):
    r"""
    Output class for the Stable Video Diffusion pipeline.

    Args:
        frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
    """

    frames: Union[List[PIL.Image.Image], np.ndarray]

class StableVideoDiffusionInpaintingPipeline(DiffusionPipeline):
    r"""
    Pipeline to generate video from an input image using Stable Video Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
        unet ([`UNetSpatioTemporalConditionModel`]):
            A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
        scheduler ([`EulerDiscreteScheduler`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images.
    """

    model_cpu_offload_seq = "image_encoder->unet->vae"

    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (
            len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor)

    def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.image_processor.pil_to_numpy(image)
            image = self.image_processor.numpy_to_pt(image)

            # We normalize the image before resizing to match with the original implementation.
            # Then we unnormalize it after resizing.
            image = image * 2.0 - 1.0
            image = _resize_with_antialiasing(image, (224, 224))
            image = (image + 1.0) / 2.0

            # Normalize the image for CLIP input
            image = self.feature_extractor(
                images=image,
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
        image_embeddings = image_embeddings.view(
            bs_embed * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat(
                [negative_image_embeddings, image_embeddings])

        return image_embeddings

    def _encode_vae_image(
        self,
        image: torch.Tensor,
        device,
        num_videos_per_prompt,
        do_classifier_free_guidance,
    ):
        image = image.to(device=device)
        image_latents = self.vae.encode(image).latent_dist.mode()

        if do_classifier_free_guidance:
            negative_image_latents = torch.zeros_like(image_latents)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_latents = torch.cat([negative_image_latents, image_latents])

        # duplicate image_latents for each generation per prompt, using mps friendly method
        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)

        return image_latents

    def _get_add_time_ids(
        self,
        fps,
        motion_bucket_id,
        noise_aug_strength,
        dtype,
        batch_size,
        num_videos_per_prompt,
        do_classifier_free_guidance,
    ):
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]

        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * \
            len(add_time_ids)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(
            batch_size * num_videos_per_prompt, 1)

        if do_classifier_free_guidance:
            add_time_ids = torch.cat([add_time_ids, add_time_ids])

        return add_time_ids

    def decode_latents(self, latents, num_frames, decode_chunk_size=14):
        # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
        latents = latents.flatten(0, 1)

        latents = 1 / self.vae.config.scaling_factor * latents

        # decode decode_chunk_size frames at a time to avoid OOM
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i: i + decode_chunk_size].shape[0]
            frame = self.vae.decode(
                latents[i: i + decode_chunk_size], num_frames_in).sample
            frames.append(frame)
        frames = torch.cat(frames, dim=0)

        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
        frames = frames.reshape(-1, num_frames, *
                                frames.shape[1:]).permute(0, 2, 1, 3, 4)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        frames = frames.float()
        return frames

    def check_inputs(self, image, height, width):
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                f" {type(image)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    def _prepare_latents_with_init_image(
        self,
        image,
        timestep,
        batch_size,
        num_frames,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        assert len(image.shape) == 4

        image = image.to(device=device, dtype=dtype)

        init_latent_dist = self.vae.encode(image).latent_dist
        init_latents = init_latent_dist.sample(generator=generator)
        init_latents = self.vae.config.scaling_factor * init_latents

        assert len(init_latents.shape) == 4

        init_latents = init_latents.unsqueeze(0)
        init_latents = torch.cat([init_latents] * num_frames, dim=1)
        init_latents = torch.cat([init_latents] * batch_size, dim=0)
        init_latents_orig = init_latents

        shape = (
            batch_size,
            num_frames,
            num_channels_latents // 2,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )

        if init_latents.shape != shape:
            raise ValueError(
                f"Init latent shape {init_latents.shape} != {shape}"
            )

        # add noise to latents using the timesteps
        noise = randn_tensor(init_latents.shape,
                             generator=generator, device=device, dtype=dtype)
        latents = self.scheduler.add_noise(init_latents, noise, timestep)

        return latents, init_latents_orig, noise

    def prepare_latents(
        self,
        batch_size,
        num_frames,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_frames,
            num_channels_latents // 2,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        overlay_init_image=None,
        overlay_mask_image=None,
        overlay_start=0.0,
        overlay_end=0.5,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
                Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_frames (`int`, *optional*):
                The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            min_guidance_scale (`float`, *optional*, defaults to 1.0):
                The minimum guidance scale. Used for the classifier free guidance with first frame.
            max_guidance_scale (`float`, *optional*, defaults to 3.0):
                The maximum guidance scale. Used for the classifier free guidance with last frame.
            fps (`int`, *optional*, defaults to 7):
                Frames per second. The rate at which the generated images shall be exported to a video after generation.
                Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
            motion_bucket_id (`int`, *optional*, defaults to 127):
                The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
            noise_aug_strength (`float`, *optional*, defaults to 0.02):
                The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
            decode_chunk_size (`int`, *optional*):
                The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
                between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
                for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list of list with the generated frames.

        Examples:

        ```py
        import torch

        from diffusers import StableVideoDiffusionInpaintingPipeline
        from diffusers.utils import load_image, export_to_video

        pipe = StableVideoDiffusionInpaintingPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
        pipe.to("cuda")

        image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
        image = image.resize((1024, 576))

        frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
        export_to_video(frames, "generated.mp4", fps=7)
        ```
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        use_overlay = overlay_init_image is not None and overlay_mask_image is not None

        num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(image, height, width)

        # 2. Define call parameters
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = max_guidance_scale > 1.0

        # 3. Encode input image
        image_embeddings = self._encode_image(
            image, device, num_videos_per_prompt, do_classifier_free_guidance)

        # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
        # is why it is reduced here.
        # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
        fps = fps - 1

        # 4. Encode input image using VAE
        image = self.image_processor.preprocess(
            image, height=height, width=width)
        image = image + noise_aug_strength * torch.randn_like(image)

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        image_latents = self._encode_vae_image(
            image, device, num_videos_per_prompt, do_classifier_free_guidance)
        image_latents = image_latents.to(image_embeddings.dtype)

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # Repeat the image latents for each frame so we can concatenate them with the noise
        # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
        image_latents = image_latents.unsqueeze(
            1).repeat(1, num_frames, 1, 1, 1)

        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            do_classifier_free_guidance,
        )
        added_time_ids = added_time_ids.to(device)

        # 6. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        latent_timestep = timesteps[:1].repeat(
            batch_size * num_videos_per_prompt
        )

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        init_latents_orig = None
        noise = None
        overlay_mask = None
        if use_overlay:
            overlay_init_image = preprocess_image(
                image=overlay_init_image,
                batch_size=batch_size,
            ) if not isinstance(
                overlay_init_image, torch.FloatTensor
            ) else overlay_init_image

            overlay_mask_image = preprocess_mask(
                overlay_mask_image,
                batch_size=batch_size,
                num_frames=num_frames,
                scale_factor=self.vae_scale_factor,
            )

            overlay_mask = overlay_mask_image.to(
                device=device, dtype=latents.dtype)
            overlay_mask = torch.cat([overlay_mask] * num_videos_per_prompt)

            _, init_latents_orig, noise = self._prepare_latents_with_init_image(
                image=overlay_init_image,
                timestep=latent_timestep,
                num_frames=num_frames,
                num_channels_latents=num_channels_latents,
                batch_size=batch_size * num_videos_per_prompt,
                dtype=image_embeddings.dtype,
                device=device,
                width=width,
                height=height,
                generator=generator,
            )

            if init_latents_orig.shape != latents.shape:
                raise ValueError(
                    f"Init latents shape {init_latents_orig.shape} != {latents.shape}"
                )

        # 8. Prepare guidance scale
        guidance_scale = torch.linspace(
            min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
        guidance_scale = guidance_scale.to(device, latents.dtype)
        guidance_scale = guidance_scale.repeat(
            batch_size * num_videos_per_prompt, 1)
        guidance_scale = _append_dims(guidance_scale, latents.ndim)

        overlay_start_index = int(overlay_start * num_inference_steps)
        overlay_end_index = int(overlay_end * num_inference_steps)

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - \
            num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                # Concatenate image_latents over channels dimension
                latent_model_input = torch.cat(
                    [latent_model_input, image_latents], dim=2)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * \
                        (noise_pred_cond - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents).prev_sample

                if use_overlay and i >= overlay_start_index and i < overlay_end_index:

                    init_latents_proper = self.scheduler.add_noise(
                        init_latents_orig, noise, torch.tensor([t])
                    )
                    latents = (init_latents_proper * overlay_mask) + \
                        (latents * (1 - overlay_mask))

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(
                latents, num_frames, decode_chunk_size)
            frames = tensor2vid(frames, self.image_processor,
                                output_type=output_type)
        else:
            frames = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)

Running this pipeline on the following test image gives an incorrect result. In particular, the model does not seem to denoise the overlaid latents at all.

Test image:
![test-image](https://github.com/huggingface/diffusers/assets/16655650/40b31611-3e89-443a-9c98-606d981c07e4)
Test mask image:
![test-mask](https://github.com/huggingface/diffusers/assets/16655650/287342cf-f6ad-4fa2-8d32-ada2c3385c45)

Problematic result:
![error](https://github.com/huggingface/diffusers/assets/16655650/06d3e7fb-78f4-463b-9539-e2219a1a72a0)

Test script:

```python
import torch

# `image` and `mask_image` are the PIL test image and test mask shown above.
svd_pipeline = StableVideoDiffusionInpaintingPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

output = svd_pipeline(
    image=image,
    motion_bucket_id=30,
    overlay_init_image=image,
    overlay_mask_image=mask_image,
    overlay_end=0.5,
    generator=torch.Generator().manual_seed(42),
)
```

Btw, even with overlay_end set to 1.0 the output still looks a bit noisy, which probably means that the latent input to the unet is incorrect.

Thank you again for your help! This new model is very exciting to play with!

sayakpaul commented 11 months ago

Cc: @patrickvonplaten @patil-suraj @DN6

patrickvonplaten commented 11 months ago

Think I ran into the same issues here: https://github.com/huggingface/diffusers/pull/6003 :sweat_smile:

caoandong commented 11 months ago

Do you think the issue has something to do with the noise scheduler? If so, then training the SVD model using the Euler discrete scheduler won't work either. I will dig into the issue more when I'm free.

P.S. The SVD paper description about the modification to the scheduler:

image

caoandong commented 11 months ago

@patrickvonplaten Oh actually just saw this implementation.

caoandong commented 11 months ago

> @patrickvonplaten Oh actually just saw this implementation.

nvm, that pipeline suffers from the same issue. The example worked because the number of inference steps is set to 100 and the inpainting is run the whole way through, but inspecting the intermediate steps shows that the model fails to denoise the latent. Something is clearly wrong with the added noise.

patrickvonplaten commented 11 months ago

Yes I'm also quite sure that the added noise is incorrect here - @patil-suraj @DN6 can you check?

jfischoff commented 11 months ago

@caoandong what do you mean by "inspecting the intermediate steps shows that the model fails to denoise the latent"?

CiaraStrawberry commented 11 months ago

Replace

```python
if use_overlay and i >= overlay_start_index and i < overlay_end_index:
    init_latents_proper = self.scheduler.add_noise(
        init_latents_orig, noise, torch.tensor([t])
    )
    latents = (init_latents_proper * overlay_mask) + \
        (latents * (1 - overlay_mask))
```

with

```python
if use_overlay and i >= overlay_start_index and i < overlay_end_index:
    noise_timestep = timesteps[i + 1]
    init_latents_proper = self.scheduler.add_noise(
        init_latents_orig, noise, torch.tensor([noise_timestep])
    )
    latents = (init_latents_proper * (overlay_mask)) + latents * (1 - overlay_mask)
```

https://github.com/huggingface/diffusers/assets/13116982/c025deaf-a187-42b2-b6cb-ebfbd0aef03d
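A toy sketch of why the next timestep is the one to use, assuming an idealized denoiser that returns the clean latents exactly and a plain Euler step in sigma space. The schedule values and the `euler_step` helper are made up for illustration and are not the diffusers implementation. After step `i` the sample sits at noise level `sigmas[i + 1]`, so known content noised with `sigmas[i]` is noisier than the rest of the latents.

```python
import torch


def euler_step(x, denoised, sigma, sigma_next):
    """One Euler step in sigma space (hypothetical helper for illustration)."""
    d = (x - denoised) / sigma
    return x + (sigma_next - sigma) * d


torch.manual_seed(0)
sigmas = torch.tensor([10.0, 5.0, 2.0, 0.5, 0.0])  # toy descending schedule
x0 = torch.zeros(1, 4, 16, 16)                     # "known" clean latents
noise = torch.randn_like(x0)

i = 1
x = x0 + sigmas[i] * noise                             # sample at level sigmas[i]
x_next = euler_step(x, x0, sigmas[i], sigmas[i + 1])   # ideal denoiser returns x0

paste_same = x0 + sigmas[i] * noise      # known region noised with the current level
paste_next = x0 + sigmas[i + 1] * noise  # known region noised with the next level

print("std after step:      ", x_next.std().item())
print("std of paste (t_i):  ", paste_same.std().item())
print("std of paste (t_i+1):", paste_next.std().item())
```

In this idealized setting `paste_next` matches the stepped sample, while `paste_same` carries more noise than the sampler expects on the following step. The gap between adjacent levels shrinks as the number of steps grows, which may be why the 100-step example looked acceptable, though that is a guess rather than something verified here.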

jfischoff commented 11 months ago

@CiaraStrawberry won't this index exceed the length of the timesteps on the last step?

noise_timestep = timesteps[i + 1]

I just avoided the inpainting on the last step and it does seem to work!

Without:

With @CiaraStrawberry's suggested change

CiaraStrawberry commented 11 months ago

Mine was run with the overlay window ending earlier, so it never hit the last timestep. Thumbs up for just skipping it when it does; that's what the SD inpainting pipelines do.
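A minimal sketch of that guard, written against a plain list of noise levels rather than the scheduler. `blend_known_region` is a hypothetical helper, and `sigmas` is assumed to hold one level per inference step with no trailing zero, so `sigmas[i + 1]` does not exist on the last step.

```python
import torch


def blend_known_region(latents, init_latents, noise, mask, sigmas, i):
    """Overwrite the masked region with known latents noised to the next level.

    On the last step there is no next level, so the latents are returned
    unchanged instead of indexing past the end of `sigmas`.
    """
    if i >= len(sigmas) - 1:
        return latents  # last step: skip the blend
    noised_init = init_latents + sigmas[i + 1] * noise
    return noised_init * mask + latents * (1 - mask)


torch.manual_seed(0)
sigmas = torch.tensor([10.0, 5.0, 2.0, 0.5])  # toy schedule, one level per step
init_latents = torch.ones(1, 4, 8, 8)
noise = torch.randn_like(init_latents)
latents = torch.randn_like(init_latents)
mask = torch.zeros_like(init_latents)
mask[..., :4] = 1.0  # 1 = keep the known content, 0 = let the model generate

mid = blend_known_region(latents, init_latents, noise, mask, sigmas, i=1)
last = blend_known_region(latents, init_latents, noise, mask, sigmas, i=3)
```

The mask convention follows the pipeline above: where the mask is 1 the noised initial latents are kept, and where it is 0 the sampler's latents pass through.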

github-actions[bot] commented 10 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.