cswry / OSEDiff

[NeurlPS2024] One-Step Effective Diffusion Network for Real-World Image Super-Resolution
Apache License 2.0
215 stars 12 forks source link

Question about VSD loss #42

Closed Renzhihan closed 4 weeks ago

Renzhihan commented 1 month ago

感谢您的工作和开源代码,但在阅读VSD loss的过程中,我产生了一些疑惑,具体在osediff.py的第238-241行

weighting_factor = torch.abs(latents - x0_pred_fix).mean(dim=[1, 2, 3], keepdim=True)

grad = (x0_pred_update - x0_pred_fix) / weighting_factor
loss = F.mse_loss(latents, (latents - grad).detach())

在附录的第13步,正则loss的计算是两个噪声的差乘以一个权重wt。根据我的理解,此处的x0_pred_update - x0_pred_fix对应噪声的差,但我没有想明白weighting_factor所对应的公式。另外,权重wt在论文里似乎没有说明含义,在代码里我也没有找到是如何实现的,就是一个常数1吗?您可以具体解释一下weighting_factor和wt的含义吗,谢谢

cswry commented 1 month ago

你好,这里分布loss的实现我们是参考dmd的。vsd loss是在噪声域约束分布的,dmd将其转换为在x0处约束分布,虽然两者在理论上是等价的,但是我们发现在x0处约束sr任务的效果更好一些。具体算法步骤请参考我们最近更新的文章版本