Closed Renzhihan closed 4 weeks ago
感谢您的工作和开源代码,但在阅读VSD loss的过程中,我产生了一些疑惑,具体在osediff.py的第238-241行
weighting_factor = torch.abs(latents - x0_pred_fix).mean(dim=[1, 2, 3], keepdim=True) grad = (x0_pred_update - x0_pred_fix) / weighting_factor loss = F.mse_loss(latents, (latents - grad).detach())
在附录的第13步,正则loss的计算是两个噪声的差乘以一个权重wt。根据我的理解,此处的x0_pred_update - x0_pred_fix对应噪声的差,但我没有想明白weighting_factor所对应的公式。另外,权重wt在论文里似乎没有说明含义,在代码里我也没有找到是如何实现的,就是一个常数1吗?您可以具体解释一下weighting_factor和wt的含义吗,谢谢
你好,这里分布loss的实现我们是参考dmd的。vsd loss是在噪声域约束分布的,dmd将其转换为在x0处约束分布,虽然两者在理论上是等价的,但是我们发现在x0处约束sr任务的效果更好一些。具体算法步骤请参考我们最近更新的文章版本
感谢您的工作和开源代码,但在阅读VSD loss的过程中,我产生了一些疑惑,具体在osediff.py的第238-241行
在附录的第13步,正则loss的计算是两个噪声的差乘以一个权重wt。根据我的理解,此处的x0_pred_update - x0_pred_fix对应噪声的差,但我没有想明白weighting_factor所对应的公式。另外,权重wt在论文里似乎没有说明含义,在代码里我也没有找到是如何实现的,就是一个常数1吗?您可以具体解释一下weighting_factor和wt的含义吗,谢谢