huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Code With Strange Logic in CogVideoX's Dynamic CFG #9641

Open immortalCO opened 1 month ago

immortalCO commented 1 month ago

Describe the bug

As shown at pipeline_cogvideox_image2video.py L778, pipeline_cogvideox_video2video.py L778, and pipeline_cogvideox.py L697, the dynamic CFG is calculated in this way:

self._guidance_scale = 1 + guidance_scale * (
    (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)

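However, `t` is the scheduler timestep, which is on the scale of `self.scheduler.num_train_timesteps` (1000), while `num_inference_steps` is only 50 in the reproduction below. The ratio `(num_inference_steps - t.item()) / num_inference_steps` therefore falls far outside [0, 1] for most steps, and the resulting scale jumps around instead of following a smooth schedule (see Logs).

I wonder: is this really the desired behavior of the CogVideoX pipeline? Shouldn't it be one of the following: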

# Option 1: replace t.item() with i, which runs from 0 to num_inference_steps - 1
self._guidance_scale = 1 + guidance_scale * (
    (1 - math.cos(math.pi * ((num_inference_steps - i) / num_inference_steps) ** 5.0)) / 2
)

# Option 2: replace num_inference_steps with self.scheduler.num_train_timesteps, which is 1000
self._guidance_scale = 1 + guidance_scale * (
    (1 - math.cos(math.pi * ((self.scheduler.num_train_timesteps - t.item()) / self.scheduler.num_train_timesteps) ** 5.0)) / 2
)

Both implementations will make the dynamic CFG like a cosine annealing.
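
To make the two options easier to compare, here is a minimal standalone sketch that evaluates both expressions outside the pipeline. The helper names (`cosine_factor`, `cfg_by_step_index`, `cfg_by_timestep`) and the descending timestep list are assumptions made for illustration; the real timesteps come from the scheduler and may be spaced differently.

```python
import math


def cosine_factor(progress: float, power: float = 5.0) -> float:
    """(1 - cos(pi * progress**power)) / 2, which maps [0, 1] onto [0, 1]."""
    return (1 - math.cos(math.pi * progress**power)) / 2


def cfg_by_step_index(guidance_scale: float, i: int, num_inference_steps: int) -> float:
    # Option 1: progress is derived from the loop index i.
    progress = (num_inference_steps - i) / num_inference_steps
    return 1 + guidance_scale * cosine_factor(progress)


def cfg_by_timestep(guidance_scale: float, t: int, num_train_timesteps: int = 1000) -> float:
    # Option 2: progress is derived from the scheduler timestep t.
    progress = (num_train_timesteps - t) / num_train_timesteps
    return 1 + guidance_scale * cosine_factor(progress)


num_inference_steps = 50
# Hypothetical descending schedule 999, 979, ..., 19 (an assumption, not read from the scheduler).
timesteps = list(range(999, -1, -20))

by_index = [cfg_by_step_index(6.0, i, num_inference_steps) for i in range(num_inference_steps)]
by_timestep = [cfg_by_timestep(6.0, t) for t in timesteps]

# Print the first, middle, and last denoising steps for both variants.
for i in (0, 24, 49):
    print(f"step {i + 1:2d}: by index = {by_index[i]:.4f}, by timestep = {by_timestep[i]:.4f}")
```

Under these assumptions both variants are smooth and stay within 1 to 1 + guidance_scale, but they run in opposite directions: the index-based one starts at 1 + guidance_scale and decays toward 1, while the timestep-based one starts near 1 and rises. Which direction is intended would need to be confirmed by the maintainers.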

Also, I think `1 + guidance_scale * (...)` should be `1 + (guidance_scale - 1) * (...)`; otherwise the value ranges from 1 to 1 + guidance_scale instead of from 1 to guidance_scale.
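
The difference in range can be checked with a small sketch. `dynamic_cfg` and its `subtract_one` flag are hypothetical names used only for this illustration, and `progress` stands for whichever normalized ratio in [0, 1] ends up feeding the cosine term.

```python
import math


def cosine_factor(progress: float, power: float = 5.0) -> float:
    """(1 - cos(pi * progress**power)) / 2, which maps [0, 1] onto [0, 1]."""
    return (1 - math.cos(math.pi * progress**power)) / 2


def dynamic_cfg(guidance_scale: float, progress: float, subtract_one: bool) -> float:
    # subtract_one=False mirrors the current expression; True is the suggested rescaling.
    span = guidance_scale - 1 if subtract_one else guidance_scale
    return 1 + span * cosine_factor(progress)


for progress in (0.0, 0.9, 1.0):
    current = dynamic_cfg(6.0, progress, subtract_one=False)
    rescaled = dynamic_cfg(6.0, progress, subtract_one=True)
    print(f"progress={progress:.1f}: current={current:.3f}, rescaled={rescaled:.3f}")
```

With `guidance_scale=6`, the current form peaks at 7 when the factor reaches 1, while the rescaled form peaks at 6, which is the value the user actually passed.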

Please check this and fix it if it is indeed a bug. Thank you very much.

Reproduction

import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

prompt = "A little girl is riding a bicycle at high speed. Focused, detailed, realistic."
image = load_image(image="input.jpg")
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V",
    torch_dtype=torch.bfloat16
)

pipe.enable_sequential_cpu_offload()
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

video = pipe(
    prompt=prompt,
    image=image,
    num_videos_per_prompt=1,
    num_inference_steps=50,
    num_frames=49,
    guidance_scale=6,
    generator=torch.Generator(device="cuda").manual_seed(42),
    use_dynamic_cfg=True, # Then you can print out the self._guidance_scale to see what happens.
).frames[0]

export_to_video(video, "output.mp4", fps=8)

Logs

# In the following setting, `guidance_scale=4` is passed.
10/11/2024 01:36:51 - INFO - root - Denoising 1/50: cfg = 1.645743587275726 
10/11/2024 01:36:55 - INFO - root - Denoising 2/50: cfg = 1.7717514333823159 
10/11/2024 01:36:59 - INFO - root - Denoising 3/50: cfg = 3.9871759160877414 
10/11/2024 01:37:04 - INFO - root - Denoising 4/50: cfg = 3.7101792115724193 
10/11/2024 01:37:08 - INFO - root - Denoising 5/50: cfg = 1.8940487645793973 
10/11/2024 01:37:13 - INFO - root - Denoising 6/50: cfg = 2.635970965321337 
10/11/2024 01:37:17 - INFO - root - Denoising 7/50: cfg = 1.0187988588703782 
10/11/2024 01:37:22 - INFO - root - Denoising 8/50: cfg = 2.5852000899340863 
10/11/2024 01:37:26 - INFO - root - Denoising 9/50: cfg = 1.3089873683577653 
10/11/2024 01:37:31 - INFO - root - Denoising 10/50: cfg = 3.9915635934173324 
10/11/2024 01:37:35 - INFO - root - Denoising 11/50: cfg = 1.0023944806862168 
10/11/2024 01:37:40 - INFO - root - Denoising 12/50: cfg = 1.935990650841663 
10/11/2024 01:37:44 - INFO - root - Denoising 13/50: cfg = 3.9884025377098555
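
The erratic pattern can be explored without loading the model. The sketch below evaluates the current expression on a hypothetical descending timestep list (999, 979, ..., 19); this list and the `current_dynamic_cfg` helper are assumptions, so the numbers are not expected to match the log line by line.

```python
import math


def current_dynamic_cfg(guidance_scale: float, t: int, num_inference_steps: int) -> float:
    # Same expression as in the pipelines, with t.item() replaced by a plain int.
    ratio = (num_inference_steps - t) / num_inference_steps
    return 1 + guidance_scale * ((1 - math.cos(math.pi * ratio**5.0)) / 2)


num_inference_steps = 50
# Hypothetical timesteps (an assumption); the real ones come from the scheduler.
timesteps = list(range(999, -1, -20))

values = [current_dynamic_cfg(4.0, t, num_inference_steps) for t in timesteps]

# Show the first few steps, where the ratio is far outside [0, 1].
for step, (t, cfg) in enumerate(zip(timesteps[:5], values[:5]), start=1):
    ratio = (num_inference_steps - t) / num_inference_steps
    print(f"Denoising {step}/{num_inference_steps}: t = {t}, ratio = {ratio:.2f}, cfg = {cfg:.4f}")
```

Because the ratio is around -19 at the first step and is then raised to the fifth power, the cosine argument spans a huge number of periods, so consecutive steps land on effectively unrelated points of the cosine.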

System Info

This bug is in the pipeline code itself and does not depend on the system.

Who can help?

@DN6 @a-r-r-o-w @zRzRzRzRzRzRzR

a-r-r-o-w commented 1 month ago

Hey, thanks for reporting! We've come across this issue as well. This comes from maintaining 1:1 implementations with the original CogVideo code base.

See this and this. I think @yiyixuxu was looking into this.

I think what you mention is correct and creates the intended cosine guidance schedule. cc @zRzRzRzRzRzRzR as well to verify this.

github-actions[bot] commented 2 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

a-r-r-o-w commented 1 week ago

Gentle ping to @yiyixuxu. I think we should fix this issue in our pipelines, even if it is incompatible with the original implementation.