CleanDiffuserTeam / CleanDiffuser

CleanDiffuser: An Easy-to-use Modularized Library for Diffusion Models in Decision Making
Apache License 2.0

cosine_noise_schedule modifies the last dimension of t_diffusion #31

Closed: Wu-Chenyang closed this issue 1 month ago

Wu-Chenyang commented 1 month ago
def cosine_noise_schedule(t_diffusion: torch.Tensor, s: float = 0.008):
    t_diffusion[-1] = 0.9946
    alpha = (np.pi / 2.0 * (t_diffusion + s) / (1 + s)).cos() / np.cos(
        np.pi / 2.0 * s / (1 + s))
    sigma = (1.0 - alpha**2).sqrt()
    return alpha, sigma

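I don't know why the first line overwrites t_diffusion[-1], but whatever the reason, it changes the caller's tensor in place as a side effect. If the last time step needs a special value, that should be handled somewhere else, not inside the schedule function.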
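To make the side effect concrete, here is a minimal sketch that calls the function exactly as quoted above on a time grid owned by the caller. The grid `t`, its size, and the variable names outside the function are illustrative assumptions, not taken from the repository.

```python
import numpy as np
import torch


def cosine_noise_schedule(t_diffusion: torch.Tensor, s: float = 0.008):
    # Function as quoted in this issue, including the in-place write.
    t_diffusion[-1] = 0.9946
    alpha = (np.pi / 2.0 * (t_diffusion + s) / (1 + s)).cos() / np.cos(
        np.pi / 2.0 * s / (1 + s))
    sigma = (1.0 - alpha**2).sqrt()
    return alpha, sigma


# Hypothetical time grid owned by the caller (assumption for illustration).
t = torch.linspace(0.0, 1.0, 5)
before = t.clone()

alpha, sigma = cosine_noise_schedule(t)

# The caller's tensor was modified: its last entry is no longer 1.0.
last_entry_changed = bool(t[-1] != before[-1])
print(before[-1].item(), t[-1].item(), last_entry_changed)
```

Any code that reuses `t` after this call sees the altered last entry, which is why the write is surprising when it lives inside the schedule function.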

ZibinDong commented 1 month ago

Thanks so much for your thorough checking!

I also noticed and fixed this bug in the lightning branch.

I've been doing most of my updates on the lightning branch lately. I've rebuilt the entire codebase using PyTorch Lightning in this branch. It allows us to train models and use advanced deep learning techniques like parallel training and mixed precision in a much simpler way.

If you're working on developing new algorithms, I'd strongly recommend trying the lightning branch. I've also added some notebook tutorials to this branch to help you get up to speed with the PyTorch Lightning version of CleanDiffuser more quickly.

Also, I'll be fixing this bug in the main branch as soon as possible.

ZibinDong commented 1 month ago
def cosine_noise_schedule(t_diffusion: torch.Tensor, s: float = 0.008):
    # clip returns a new tensor, so t_diffusion itself is no longer modified.
    t = t_diffusion.clip(0., 0.9946)
    alpha = (np.pi / 2.0 * (t + s) / (1 + s)).cos() / np.cos(
        np.pi / 2.0 * s / (1 + s))
    sigma = (1.0 - alpha**2).sqrt()
    return alpha, sigma

Let me explain where the 0.9946 comes from. It's from the last line on page 22 of the DPM-Solver paper (arXiv:2206.00927). For cosine noise schedules, capping the maximum value of t at 0.9946 helps improve numerical stability. Instead of assigning t_diffusion[-1] = 0.9946, I now clip t_diffusion to the range [0, 0.9946] when computing alpha. Since clip returns a new tensor, the input is no longer modified, which should fix the issue while minimizing any potential side effects.
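As a quick check of this behavior, the sketch below uses a clipped version of the function and verifies two things: the caller's tensor is left untouched, and alpha stays strictly positive at the end of the grid, so quantities such as log(alpha) remain finite. The time grid and the variable names outside the function are illustrative assumptions.

```python
import numpy as np
import torch


def cosine_noise_schedule(t_diffusion: torch.Tensor, s: float = 0.008):
    # clip returns a new tensor, so t_diffusion itself is not modified.
    t = t_diffusion.clip(0., 0.9946)
    alpha = (np.pi / 2.0 * (t + s) / (1 + s)).cos() / np.cos(
        np.pi / 2.0 * s / (1 + s))
    sigma = (1.0 - alpha**2).sqrt()
    return alpha, sigma


# Hypothetical time grid that runs all the way to t = 1.
t = torch.linspace(0.0, 1.0, 5)
before = t.clone()

alpha, sigma = cosine_noise_schedule(t)

# The input is unchanged, and alpha at the capped end point is small but positive.
input_unchanged = torch.equal(t, before)
alpha_end = float(alpha[-1])
log_alpha_end_is_finite = bool(torch.isfinite(alpha[-1].log()))
print(input_unchanged, alpha_end, log_alpha_end_is_finite)
```

Note that clipping affects every entry above 0.9946, not only the last one, so it also covers grids whose final entries are not exactly 1.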

I'm sorry for not checking the issues as quickly as I should have lately; things have been pretty hectic. If you need more hands-on help with the code, feel free to add me on WeChat: dongzibin1112. You can reach out anytime and I'll do my best to help you sort things out.

Wu-Chenyang commented 1 month ago

I see. Thanks for your quick response! :)