LeiaLi / SRDiff


What does t_cond mean in the code? #2

Open Sebastian970107 opened 1 year ago

Sebastian970107 commented 1 year ago
def q_sample(self, x_start, t, noise=None):
    noise = default(noise, lambda: torch.randn_like(x_start))
    t_cond = (t[:, None, None, None] >= 0).float()
    t = t.clamp_min(0)
    return (
                   extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                   extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
           ) * t_cond + x_start * (1 - t_cond)

What does "t_cond" mean in this code? I could not find a corresponding mathematical expression in the diffusion model formulation.

whyandbecause commented 1 year ago

Hi, did you ever figure out what this means?

diamondxx commented 1 year ago

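In brief: t_cond is a 0/1 mask computed from t. If t >= 0 (t_cond = 1), the function returns “extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise”; if t < 0 (t_cond = 0), it returns x_start unchanged. The t.clamp_min(0) call only keeps the index passed to extract valid for those negative entries. I think it is a check on the bounds of the timestep, i.e. a guard against out-of-range values.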
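
To make the masking concrete, here is a minimal sketch of the same logic in plain Python, with one float per sample instead of a tensor. It is not the SRDiff implementation: `make_schedule`, `q_sample_scalar`, and the linear beta schedule are hypothetical stand-ins chosen only for illustration.

```python
import math
import random


def make_schedule(num_steps=10, beta_start=1e-4, beta_end=0.02):
    """Hypothetical linear beta schedule (an assumption, not taken from SRDiff).

    Returns per-step lists of sqrt(alpha_bar) and sqrt(1 - alpha_bar).
    """
    sqrt_ac, sqrt_om = [], []
    alpha_bar = 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / max(num_steps - 1, 1)
        alpha_bar *= 1.0 - beta
        sqrt_ac.append(math.sqrt(alpha_bar))
        sqrt_om.append(math.sqrt(1.0 - alpha_bar))
    return sqrt_ac, sqrt_om


def q_sample_scalar(x_start, t, noise, sqrt_ac, sqrt_om):
    """Scalar stand-in for the masked forward step quoted in the question."""
    out = []
    for x0, ti, eps in zip(x_start, t, noise):
        t_cond = 1.0 if ti >= 0 else 0.0  # plays the role of (t >= 0).float()
        ti = max(ti, 0)                    # plays the role of t.clamp_min(0)
        noised = sqrt_ac[ti] * x0 + sqrt_om[ti] * eps
        # The mask selects the noised value for t >= 0 and x_start otherwise.
        out.append(noised * t_cond + x0 * (1.0 - t_cond))
    return out


sqrt_ac, sqrt_om = make_schedule()
x_start = [0.5, -1.0, 2.0]
t = [3, -1, 0]  # the middle sample has a negative timestep
random.seed(0)
noise = [random.gauss(0.0, 1.0) for _ in x_start]
result = q_sample_scalar(x_start, t, noise, sqrt_ac, sqrt_om)
```

In this sketch the sample with t = -1 comes back equal to its x_start value, while the samples with t = 3 and t = 0 are noised according to the schedule. This is consistent with reading t_cond as a per-sample switch, not as a term from the diffusion equations.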