jychoi118 / P2-weighting

CVPR 2022
MIT License

The shape of weights curve #15

Open markub3327 opened 1 year ago

markub3327 commented 1 year ago

Hello @jychoi118,

I have read your paper and code, but I cannot find the relationship between the shape of the weight curves in Section D and the expression in your code: `1 / (self.p2_k + self.snr)**self.p2_gamma`.

For example, I'm using the cosine schedule with 1000 timesteps. The SNR is calculated as `snr = 1.0 / (1 - self.alphas_cumprod) - 1`, and the weights are then calculated as `1 / (1.0 + self.snr)**1.0` (i.e., k = 1 and γ = 1).
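The weights described in the paragraph above have a simple closed form. Since SNR(t) = ᾱ_t / (1 − ᾱ_t), the factor with k = 1 and γ = 1 is 1 / (1 + SNR(t)) = 1 − ᾱ_t, so the weight-versus-step curve is just `1 - alphas_cumprod`. A minimal numerical check, assuming the same unclipped cosine schedule as the code in this issue; the helper names `cosine_alphas_cumprod` and `p2_relative_weight` are hypothetical:

```python
import numpy as np

def cosine_alphas_cumprod(timesteps, s=0.008):
    # Same cosine schedule as cosine_beta_schedule in this issue (betas not clipped),
    # returned as the cumulative products alpha_bar_t for t = 1..T.
    x = np.linspace(0.0, timesteps, timesteps + 1)
    f = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    f = f / f[0]
    betas = 1 - f[1:] / f[:-1]
    return np.cumprod(1.0 - betas)

def p2_relative_weight(snr, k=1.0, gamma=1.0):
    # Hypothetical helper: the P2 factor 1 / (k + SNR)^gamma.
    return 1.0 / (k + snr) ** gamma

alphas_cumprod = cosine_alphas_cumprod(1000)
snr = alphas_cumprod / (1.0 - alphas_cumprod)  # algebraically equal to 1/(1 - alpha_bar) - 1
w = p2_relative_weight(snr, k=1.0, gamma=1.0)

# With k = 1 and gamma = 1 the factor collapses to 1 - alpha_bar_t.
print(np.allclose(w, 1.0 - alphas_cumprod))  # True
```

This is why the plotted curve rises monotonically from near 0 to 1 over the diffusion steps: it simply tracks the noise variance 1 − ᾱ_t.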

This is a chart of SNR as a function of the diffusion step: [image: SNR]

This is a chart of unnormalized weights as a function of the diffusion step: [image: P2_T]

This is a chart of unnormalized weights as a function of the signal-to-noise ratio (SNR): [image: P2_SNR]

This is the original chart of unnormalized weights from Figure A:

[image: Screenshot 2023-09-05 at 10 53 47]

The full code:

```python
import numpy as np
import matplotlib.pyplot as plt

def cosine_beta_schedule(timesteps, s=0.008):
    # alpha_bar(t) follows a squared cosine, normalized so that alpha_bar(0) = 1;
    # each beta_t is recovered from the ratio of consecutive alpha_bar values.
    steps = timesteps + 1
    x = np.linspace(0.0, timesteps, steps)
    alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return betas

class NoiseScheduler:
    def __init__(self, timesteps):
        self.timesteps = timesteps

        betas = cosine_beta_schedule(timesteps)

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t), written here as 1 / (1 - alpha_bar_t) - 1
        self.snr = 1.0 / (1 - self.alphas_cumprod) - 1
        # P2 factor 1 / (k + SNR)^gamma with k = 1 and gamma = 1
        self.P2_weights = 1 / (1.0 + self.snr) ** 1.0

noise_scheduler = NoiseScheduler(timesteps=1000)
steps = np.arange(noise_scheduler.timesteps)

plt.plot(steps, noise_scheduler.snr)
plt.xlabel("diffusion step")
plt.ylabel("SNR")
plt.show()

plt.plot(steps, noise_scheduler.P2_weights)
plt.xlabel("diffusion step")
plt.ylabel("unnormalized weight")
plt.show()

plt.plot(noise_scheduler.snr, noise_scheduler.P2_weights)
plt.xlabel("SNR")
plt.ylabel("unnormalized weight")
plt.show()
```

Thanks.

jychoi118 commented 1 year ago

The weights drawn in our plots are the weights of the VLB loss terms. For example, the blue lines in Figure A are the λ in Equation (7) of our paper. What you have drawn is the relative weight λ'/λ between λ' and λ (Equation (8)). Sorry for the late reply.
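To make the distinction concrete, here is a minimal sketch. It assumes, based on a reading of the paper rather than the released code, that Eq. (7) is λ_t = (1 − β_t)(1 − ᾱ_t) / β_t, the per-step factor relating each VLB term L_t to the simple ε-prediction loss, and that Eq. (8) is λ'_t = λ_t / (k + SNR(t))^γ. Under that assumption, the `P2_weights` array in the question is only the ratio λ'_t / λ_t, whereas Figure A-style curves show λ_t and λ'_t themselves. The function names `cosine_betas` and `vlb_weights` are hypothetical, and the schedule is the unclipped cosine schedule from the question.

```python
import numpy as np

def cosine_betas(timesteps, s=0.008):
    # Cosine schedule from the question (betas not clipped).
    x = np.linspace(0.0, timesteps, timesteps + 1)
    f = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    f = f / f[0]
    return 1 - f[1:] / f[:-1]

def vlb_weights(betas, k=1.0, gamma=1.0):
    # ASSUMPTION: lambda_t = (1 - beta_t)(1 - alpha_bar_t) / beta_t is the weight of
    # each VLB term inside L_simple, and lambda'_t = lambda_t / (k + SNR_t)^gamma.
    alphas_cumprod = np.cumprod(1.0 - betas)
    snr = alphas_cumprod / (1.0 - alphas_cumprod)
    lam = (1.0 - betas) * (1.0 - alphas_cumprod) / betas
    lam_p2 = lam / (k + snr) ** gamma
    return lam, lam_p2, snr

betas = cosine_betas(1000)
lam, lam_p2, snr = vlb_weights(betas, k=1.0, gamma=1.0)
```

Plotting `lam` and `lam_p2` against the step index, instead of `1 / (k + snr) ** gamma` on its own, should give curves of the same kind as those in Figure A; varying `k` and `gamma` then shows how the P2 factor reshapes λ_t.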