Kinyugo / consistency_models

A mini-library for training consistency models.
https://arxiv.org/abs/2303.01469
MIT License

Rescale of sigmas #7

wubowen416 opened this issue 7 months ago

wubowen416 commented 7 months ago

Hi, nice repo, really appreciate it.

One thing is that in the implementation of Song's consistency models, before inputting the sigmas into the network, there is a rescale: `rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)`. You can check it here: https://github.com/openai/consistency_models/blob/e32b69ee436d518377db86fb2127a3972d0d8716/cm/karras_diffusion.py#L346C58-L346C58

Similarly, in EDM's implementation, there is also a rescale before inputting sigma to the network: `c_noise = sigma.log() / 4`. The link: https://github.com/NVlabs/edm/blob/62072d2612c7da05165d6233d13d17d71f213fee/training/networks.py#L663C9-L663C34
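For concreteness, a minimal sketch of the two rescalings side by side (the example sigma values are just for illustration):

```python
import torch

sigmas = torch.tensor([0.002, 1.0, 80.0])

# Song et al.'s consistency-models rescale (see link above):
rescaled_cm = 1000 * 0.25 * torch.log(sigmas + 1e-44)
print(rescaled_cm)  # roughly [-1553.7, 0.0, 1095.5]

# EDM's rescale (see link above):
rescaled_edm = sigmas.log() / 4
print(rescaled_edm)  # roughly [-1.554, 0.0, 1.096]
```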

But I did not find this rescaling in your implementation.

I am aware that the code for the improved consistency models has not been released yet, so we really do not know if there is such an operation. What do you think?

Kinyugo commented 7 months ago

Hello,

Thanks for the nice findings and for your interest in my work. I think the goal is to rescale the values to a range that works well for the chosen timestep embedding. For consistency models, they do something quite similar to EDM but then scale by 1000, because they use sinusoidal embeddings for the timestep. In the case of EDM, since they use Fourier embeddings, the rescaling outputs values close to [-1, 1].
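For reference, a Gaussian Fourier timestep embedding looks roughly like this (a generic sketch in the spirit of EDM's FourierEmbedding, not the exact code in either repo); with inputs near [-1, 1] the angles stay moderate:

```python
import math
import torch
from torch import nn

class FourierEmbedding(nn.Module):
    """Gaussian Fourier features for a scalar noise level (generic sketch)."""

    def __init__(self, num_channels: int = 128, scale: float = 16.0):
        super().__init__()
        # Fixed random frequencies; `scale` controls their spread.
        self.register_buffer("freqs", torch.randn(num_channels // 2) * scale)

    def forward(self, c_noise: torch.Tensor) -> torch.Tensor:
        # c_noise: (batch,) rescaled noise levels, ideally of order 1.
        angles = 2 * math.pi * c_noise[:, None] * self.freqs[None, :]
        return torch.cat([angles.cos(), angles.sin()], dim=-1)
```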

It's common practice in deep learning to rescale values so that they fall within a small range, but in our case we use the raw values in the range [0.02, 80.0]. This is not something I have experimented with, and I don't know how it would impact the performance of the model. If you do manage to experiment with it, kindly share your findings.

(Attached plots: rescaled_sigmas_cm, rescaled_sigmas_edm)
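If you want to try it, a minimal hypothetical change would be to rescale the sigmas just before they reach the network; `model` and its call signature here are assumptions, not this library's actual API:

```python
import torch

def forward_with_rescaled_sigmas(model, noisy_x: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    # EDM-style rescale: maps sigmas in [0.02, 80.0] to roughly [-0.98, 1.10].
    c_noise = sigmas.log() / 4
    return model(noisy_x, c_noise)
```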

wubowen416 commented 7 months ago

Thank you for your reply and your clear explanation.

I personally found that gradients will sometimes explode, causing the network to output NaNs, if the rescaling is not properly matched to the embedding (e.g., Song's rescale + Fourier embedding, or no rescale + Fourier embedding). This is especially severe when using more timesteps, i.e., dividing the trajectory more finely. Based on your explanation, this is expected, since the value range may be too large.
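A rough numerical illustration of the mismatch (the frequency scale is an assumption, borrowed from typical Fourier embeddings): Song's rescale produces values on the order of 1e3, so the embedding sees angles of roughly 1e5 radians; these wrap around the unit circle thousands of times, making the conditioning signal extremely sensitive to sigma, which is consistent with the instability above:

```python
import torch

sigmas = torch.tensor([0.002, 80.0])
cm_scaled = 1000 * 0.25 * torch.log(sigmas + 1e-44)  # ~[-1553.7, 1095.5]
freqs = torch.randn(64) * 16.0                       # assumed Fourier frequency scale
angles = 2 * torch.pi * cm_scaled[:, None] * freqs
print(angles.abs().max())  # angles on the order of 1e5 radians
```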

Maybe this is related to your experimental findings in the notebook, where you say that 10 timesteps yielded better results. I think it is worth trying rescaling + larger timestep counts.

Anyway, thanks again for your kind response.

Kinyugo commented 7 months ago

Thank you for taking the time to run experiments and for sharing your findings.

I'll open a PR for this.