zju-pi / diff-sampler

An open-source toolbox for fast sampling of diffusion models. Official implementations for our [CVPR-2024, ICML-2024] papers
Apache License 2.0
165 stars 8 forks source link

关于GITS做DP过程的cost_mat计算的一点疑问 #11

Closed Dr-Yao-Yao closed 2 months ago

Dr-Yao-Yao commented 2 months ago

感谢开源关于GITS的相关代码,这个工作非常有意思和也很有价值。

我在研读代码的时候,发现一个可能的小问题:在GITS的DP搜索中,对cost_mat的计算在every round都会重置(ref: https://github.com/zju-pi/diff-sampler/blob/30f76572f8476a95b2fdf9b556e2df7761f88f2d/gits-main/gits_utils.py#L83),并没有累计num_accumulation_rounds,而dist.all_reduce(cost_mat) 操作似乎只能合并不同device之间的batch,但不能合并不同rounds之间的 (ref: https://github.com/zju-pi/diff-sampler/blob/30f76572f8476a95b2fdf9b556e2df7761f88f2d/gits-main/gits_utils.py#L135)。

以上是特意设计的,还是一个小bug呢?请批评指正。

zhyzhouu commented 2 months ago

感谢指出!这是整理代码过程中的错误,不影响论文中的实验结果,已修复。