zhangguiwei610 / CAMEL

Official implementation for the CVPR 2024 paper CAMEL
Apache License 2.0

No code found for DDIM inversion in the validation_pipeline (validation step) #1

Open Lethobenthos20 opened 2 months ago

Lethobenthos20 commented 2 months ago

I couldn't find the code for DDIM inversion. The pipeline seems to run DDIM sampling directly from noise. Why is that?

@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    motion_prompt: None,
    video_length: Optional[int],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "tensor",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    **kwargs,
):
    # Default height and width to unet

    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # Check inputs. Raise error if not correct
    self.check_inputs(prompt, height, width, callback_steps)

    # Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # Encode input prompt
    text_embeddings = self._encode_prompt(
        prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # Prepare latent variables
    num_channels_latents = self.unet.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        video_length,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )
    latents_dtype = latents.dtype

    # Prepare extra step kwargs.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual

            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=[text_embeddings,motion_prompt]).sample.to(dtype=latents_dtype)
            # noise_pred = self.unet(latent_model_input, t,
            #                        encoder_hidden_states=[text_embeddings, text_embeddings[2]]).sample.to(
            #     dtype=latents_dtype)

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # Post-processing
    video = self.decode_latents(latents)

    # Convert to tensor
    if output_type == "tensor":
        video = torch.from_numpy(video)

    if not return_dict:
        return video

    return TuneAVideoPipelineOutput(videos=video)
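The two update rules in the denoising loop above can be written out as follows. This sketch makes several assumptions. The scheduler is DDIM with `eta = 0.0`, which is the default in the signature, and it uses epsilon prediction. $w$ is `guidance_scale`. $\varnothing$ is the negative/empty-prompt embedding. $c$ is the prompt conditioning, which in this pipeline means `text_embeddings` together with `motion_prompt`. $t'$ is the next (lower) timestep in `self.scheduler.timesteps`.

```latex
\hat{\epsilon} = \epsilon_\theta(z_t, t, \varnothing)
  + w \, \bigl( \epsilon_\theta(z_t, t, c) - \epsilon_\theta(z_t, t, \varnothing) \bigr)

\hat{z}_0 = \frac{z_t - \sqrt{1 - \bar{\alpha}_t} \, \hat{\epsilon}}{\sqrt{\bar{\alpha}_t}},
\qquad
z_{t'} = \sqrt{\bar{\alpha}_{t'}} \, \hat{z}_0 + \sqrt{1 - \bar{\alpha}_{t'}} \, \hat{\epsilon}
```

With $w = 1$, the first line reduces to $\epsilon_\theta(z_t, t, c)$. This is why the code comment says `guidance_scale = 1` means no classifier-free guidance.

The second line is what `self.scheduler.step(...).prev_sample` computes in that setting, ignoring optional sample clipping. Nothing in it depends on where the starting latent came from. The loop is therefore identical whether it starts from random noise or from an inverted latent.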

zhangguiwei610 commented 2 months ago

Thanks for your interest! Please refer to https://github.com/zhangguiwei610/CAMEL/blob/c8fe6f9c7240870ee7af2f41f667ef769026a25a/train_camel.py#L438
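The thread does not quote the code at the linked line. What follows is only a hypothetical, framework-free sketch of deterministic (eta = 0) DDIM inversion as it is commonly implemented, for example in Tune-A-Video, which this pipeline is derived from. It is not CAMEL's actual code.

The helper names, the `scaled_linear` beta schedule, the 50-step timestep spacing and the toy noise predictor `toy_eps` are all assumptions. Inversion applies the DDIM update toward higher noise levels; sampling applies the same update back down.

```python
import math
from typing import Callable, List

Vector = List[float]
# Hypothetical noise predictor signature: eps_fn(latent, timestep) -> predicted noise.
NoiseFn = Callable[[Vector, int], Vector]


def scaled_linear_alphas_cumprod(n: int = 1000, beta_start: float = 0.00085,
                                 beta_end: float = 0.012) -> Vector:
    """alpha_bar_t for a Stable-Diffusion-style "scaled_linear" beta schedule (assumed)."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    acp, prod = [], 1.0
    for i in range(n):
        beta = (lo + i * (hi - lo) / (n - 1)) ** 2
        prod *= 1.0 - beta
        acp.append(prod)
    return acp


def ddim_move(x: Vector, eps: Vector, a_from: float, a_to: float) -> Vector:
    """Deterministic (eta = 0) DDIM update from noise level a_from to a_to."""
    x0_pred = [(xi - math.sqrt(1.0 - a_from) * ei) / math.sqrt(a_from) for xi, ei in zip(x, eps)]
    return [math.sqrt(a_to) * x0 + math.sqrt(1.0 - a_to) * ei for x0, ei in zip(x0_pred, eps)]


def ddim_inversion(x0: Vector, eps_fn: NoiseFn, timesteps: List[int], acp: Vector) -> Vector:
    """Clean latent -> noisy latent, walking the timesteps in ascending order."""
    x, a_prev = list(x0), 1.0  # the clean latent is treated as alpha_bar = 1
    for t in sorted(timesteps):
        x = ddim_move(x, eps_fn(x, t), a_prev, acp[t])
        a_prev = acp[t]
    return x


def ddim_sampling(x_t: Vector, eps_fn: NoiseFn, timesteps: List[int], acp: Vector) -> Vector:
    """Noisy latent -> clean latent, walking the timesteps in descending order."""
    desc = sorted(timesteps, reverse=True)
    x = list(x_t)
    for i, t in enumerate(desc):
        a_next = acp[desc[i + 1]] if i + 1 < len(desc) else 1.0
        x = ddim_move(x, eps_fn(x, t), acp[t], a_next)
    return x


def toy_eps(x: Vector, t: int) -> Vector:
    """Toy predictor that ignores its inputs (a real model would be the UNet)."""
    return [0.3, -0.2, 0.5]


if __name__ == "__main__":
    acp = scaled_linear_alphas_cumprod()
    timesteps = list(range(1, 1000, 20))  # 50 DDIM steps: 1, 21, ..., 981 (assumed spacing)
    x0 = [0.5, -1.0, 2.0]
    inv_latent = ddim_inversion(x0, toy_eps, timesteps, acp)  # plays the role of ddim_inv_latent
    recon = ddim_sampling(inv_latent, toy_eps, timesteps, acp)
    print(inv_latent)
    print(recon)
```

`toy_eps` ignores its inputs, so inversion followed by sampling reproduces `x0` up to floating-point error. With a real UNet, the predicted noise differs slightly between the two directions, so the reconstruction is only approximate.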

Lethobenthos20 commented 2 months ago

OK, thank you, I see it now. Could you also point out the code where "ddim_inv_latent" is used?

Lethobenthos20 commented 1 month ago

Why does sampling in the inference phase start directly from noise rather than from the inverted latent? I can't see "ddim_inv_latent" being used during inference. Or am I misunderstanding something?

zhangguiwei610 commented 1 month ago

In https://github.com/zhangguiwei610/CAMEL/blob/c8fe6f9c7240870ee7af2f41f667ef769026a25a/train_camel.py#L444, you can see that "ddim_inv_latent" is passed into the "validation_pipeline" call as its latents argument. Then, in https://github.com/zhangguiwei610/CAMEL/blob/c8fe6f9c7240870ee7af2f41f667ef769026a25a/tuneavideo/pipelines/pipeline_tuneavideo.py#L298, since the latent is not None, sampling starts directly from ddim_inv_latent instead of from random noise.
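To make the maintainer's point concrete, here is a simplified, framework-free stand-in for the latent initialization step. It is modeled loosely on the diffusers-style `prepare_latents` used by Tune-A-Video pipelines. The plain-list latents, the `seed` argument and the hard-coded `ddim_inv_latent` values are hypothetical.

When `latents` is `None`, the function draws Gaussian noise. When a latent such as `ddim_inv_latent` is passed in, it becomes the starting point. In that case it is only scaled by the scheduler's `init_noise_sigma`, which is 1.0 for `DDIMScheduler`.

```python
import random
from typing import List, Optional

Vector = List[float]


def prepare_latents(size: int, latents: Optional[Vector] = None,
                    init_noise_sigma: float = 1.0, seed: int = 0) -> Vector:
    """Hypothetical, simplified version of a pipeline's latent initialization."""
    if latents is None:
        # No starting latent given: sample fresh Gaussian noise (plain generation).
        rng = random.Random(seed)
        latents = [rng.gauss(0.0, 1.0) for _ in range(size)]
    elif len(latents) != size:
        raise ValueError(f"expected {size} latent values, got {len(latents)}")
    # Either way, scale by the scheduler's initial noise sigma (1.0 for DDIM).
    return [init_noise_sigma * v for v in latents]


if __name__ == "__main__":
    # Stand-in for the latent produced by DDIM inversion in the training script.
    ddim_inv_latent = [0.12, -0.48, 0.93, 0.05]
    print(prepare_latents(4))                           # random starting point
    print(prepare_latents(4, latents=ddim_inv_latent))  # starts from the inverted latent
```

The denoising loop never checks where its starting latent came from. This is why no inversion code appears inside `__call__`. According to the maintainer, the inversion happens beforehand in `train_camel.py`, and its result enters the pipeline through the `latents` argument.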

Lethobenthos20 commented 1 month ago

OK, thank you very much for your answer. I understand now. Thanks again.