wubowen416 opened 7 months ago
Hello,
Thanks for the nice findings and for your interest in my work.
I think the goal is to rescale the values to a range that works well for the chosen timestep embedding. For consistency models, they do something quite similar to EDM but then scale by 1000, because they use sinusoidal embeddings for the timestep. In the case of EDM, since they use Fourier embeddings, the rescaling outputs values close to [-1, 1].
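As a quick illustrative check (my own sketch, not code from either repo), EDM's `c_noise = sigma.log() / 4` does land close to [-1, 1] over a sigma range like [0.02, 80.0]:

```python
import math

# Sketch, not code from either repo: EDM conditions the network on
# c_noise = log(sigma) / 4 (see training/networks.py in NVlabs/edm).
def edm_rescale(sigma):
    return math.log(sigma) / 4

# Over the sigma range [0.02, 80.0], the rescaled values stay roughly
# within [-1, 1], which suits a Fourier embedding.
lo, hi = edm_rescale(0.02), edm_rescale(80.0)
print(lo, hi)  # ≈ -0.98 and ≈ 1.10
```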
It's common practice in deep learning to rescale values so that they fall within a small range, but in our case we use the raw values in the range [0.02, 80.0]. This is not something I have experimented with, and I don't know how it would impact the performance of the model. If you do manage to experiment with it, kindly share your findings.
Thank you for your reply and your clear explanation.
I personally found that gradients will sometimes explode, causing the network to output NaN, if rescaling is not applied properly (e.g., Song's rescale + Fourier embedding, or no rescale + Fourier embedding). This is especially severe when using more timesteps, i.e., dividing the trajectory more finely. Based on your explanation, this is expected, since the value range may be too large.
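For intuition on the magnitudes involved (a sketch of my own, assuming the usual Karras sigma range [0.002, 80.0]): Song's rescale maps sigmas to values on the order of ±1000, far outside the near-unit range that Fourier embeddings are typically paired with, which is consistent with the instability above.

```python
import math

# Song's consistency-models rescaling, applied to the endpoints of the
# Karras sigma range [0.002, 80.0] (range assumed here for illustration).
def song_rescale(sigma):
    return 1000 * 0.25 * math.log(sigma + 1e-44)

print(song_rescale(0.002), song_rescale(80.0))
# roughly -1554 and 1096: three orders of magnitude larger than the
# near-[-1, 1] inputs a Fourier embedding expects.
```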
Maybe this is related to your experimental findings in the notebook, where you say that 10 timesteps yielded better results. I think it is worth trying rescaling + a larger number of timesteps.
Anyway, thanks again for your kind response.
Thank you for taking the time to run experiments and for sharing your findings.
I'll open a PR for this.
Hi, nice repo, really appreciate it.
One thing: in the implementation of Song's consistency models, before inputting sigmas into the network, there is a rescaling:
rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
You can check it here: https://github.com/openai/consistency_models/blob/e32b69ee436d518377db86fb2127a3972d0d8716/cm/karras_diffusion.py#L346C58-L346C58
Similarly, in EDM's implementation, there is also a rescaling before inputting sigma into the network:
c_noise = sigma.log() / 4
The link: https://github.com/NVlabs/edm/blob/62072d2612c7da05165d6233d13d17d71f213fee/training/networks.py#L663C9-L663C34
But I did not find this rescaling in your implementation.
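A small sketch of my own, confirming that the two quoted lines are the same log-rescaling up to a factor of 1000 (the `1e-44` term is only a guard against `log(0)`):

```python
import math

def song_rescale(sigma):
    # From openai/consistency_models (cm/karras_diffusion.py):
    # rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
    return 1000 * 0.25 * math.log(sigma + 1e-44)

def edm_rescale(sigma):
    # From NVlabs/edm (training/networks.py): c_noise = sigma.log() / 4
    return math.log(sigma) / 4

# Apart from the 1e-44 guard, song_rescale(sigma) equals
# 1000 * edm_rescale(sigma) for any positive sigma.
for sigma in (0.02, 1.0, 80.0):
    assert abs(song_rescale(sigma) - 1000 * edm_rescale(sigma)) < 1e-6
```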
I am aware that the code for the improved consistency model has not been released yet, so we really do not know whether there is such an operation. What do you think?