chen-judge / RGBTCrowdCounting

Official Implement of CVPR 2021 paper “Cross-Modal Collaborative Representation Learning and a Large-Scale RGBT Benchmark for Crowd Counting”
64 stars 11 forks source link

请问贝叶斯loss部分的代码是否有误 #7

Open gaiyi7788 opened 2 years ago

gaiyi7788 commented 2 years ago

之前提问过是否能使用多batch训练,您这边回复说没有尝试过,我读了一下贝叶斯loss的代码,发现以下问题:

class Bay_Loss(Module):
    def __init__(self, use_background, device):
        super(Bay_Loss, self).__init__()
        self.device = device
        self.use_bg = use_background

    def forward(self, prob_list, pre_density):
        loss = 0
        for idx, prob in enumerate(prob_list):  # iterative through each sample
            if prob is None:  # image contains no annotation points
                pre_count = torch.sum(pre_density[idx])
                target = torch.zeros((1,), dtype=torch.float32, device=self.device)
            else:
                N = len(prob)
                if self.use_bg:
                    target = torch.ones((N,), dtype=torch.float32, device=self.device)
                    target[-1] = 0.0  # the expectation count of background should be zero
                else:
                    target = torch.ones((N,), dtype=torch.float32, device=self.device)
                pre_count = torch.sum(pre_density[idx].view((1, -1)) * prob, dim=1)  # flatten into vector

        loss += torch.sum(torch.abs(target - pre_count))
        loss = loss / len(prob_list)
        return loss

最后几行我认为+=位置写错,少一个缩进,应改为

            loss += torch.sum(torch.abs(target - pre_count))
        loss = loss / len(prob_list)
        return loss

loss加和写在了循环外面,所以多batch无法正常训练。贝叶斯loss的源代码似乎不存在这个问题。

目前修改后我可以进行多batch训练,以上个人理解,望指正,感谢!