Closed Dr-Yao-Yao closed 2 months ago
感谢开源关于GITS的相关代码,这个工作非常有意思和也很有价值。
我在研读代码的时候,发现一个可能的小问题:在GITS的DP搜索中,对cost_mat的计算在every round都会重置(ref: https://github.com/zju-pi/diff-sampler/blob/30f76572f8476a95b2fdf9b556e2df7761f88f2d/gits-main/gits_utils.py#L83),并没有累计num_accumulation_rounds,而dist.all_reduce(cost_mat) 操作似乎只能合并不同device之间的batch,但不能合并不同rounds之间的 (ref: https://github.com/zju-pi/diff-sampler/blob/30f76572f8476a95b2fdf9b556e2df7761f88f2d/gits-main/gits_utils.py#L135)。
以上是特意设计的,还是一个小bug呢?请批评指正。
感谢指出!这是整理代码过程中的错误,不影响论文中的实验结果,已修复。
感谢开源关于GITS的相关代码,这个工作非常有意思和也很有价值。
我在研读代码的时候,发现一个可能的小问题:在GITS的DP搜索中,对cost_mat的计算在every round都会重置(ref: https://github.com/zju-pi/diff-sampler/blob/30f76572f8476a95b2fdf9b556e2df7761f88f2d/gits-main/gits_utils.py#L83),并没有累计num_accumulation_rounds,而dist.all_reduce(cost_mat) 操作似乎只能合并不同device之间的batch,但不能合并不同rounds之间的 (ref: https://github.com/zju-pi/diff-sampler/blob/30f76572f8476a95b2fdf9b556e2df7761f88f2d/gits-main/gits_utils.py#L135)。
以上是特意设计的,还是一个小bug呢?请批评指正。