megvii-research / TransMVSNet

(CVPR 2022) TransMVSNet: Global Context-aware Multi-view Stereo Network with Transformers.
MIT License

I don't quite understand how γ in the Focal Loss is set; it seems to just scale up entro_loss? #30

Status: Open. TOPthemaster opened this issue 1 year ago

TOPthemaster commented 1 year ago

Hello, while reading your code I ran into the following question:

```python
import torch
import torch.nn.functional as F

# entropy_loss(...) is defined elsewhere in the repository (not shown in this issue).

def trans_mvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
    depth_loss_weights = kwargs.get("dlossw", None)
    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
    total_entropy = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)

    for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys() if "stage" in k]:
        prob_volume = stage_inputs["prob_volume"]
        depth_values = stage_inputs["depth_values"]
        depth_gt = depth_gt_ms[stage_key]
        mask = mask_ms[stage_key]

        mask = mask > 0.5
        entropy_weight = 2.0  # one constant multiplier, applied to every stage's entro_loss

        entro_loss, depth_entropy = entropy_loss(prob_volume, depth_gt, mask, depth_values)
        entro_loss = entro_loss * entropy_weight
        # depth_loss is computed here but never added to total_loss;
        # only the last stage's value is returned
        depth_loss = F.smooth_l1_loss(depth_entropy[mask], depth_gt[mask], reduction='mean')
        total_entropy += entro_loss

        if depth_loss_weights is not None:
            stage_idx = int(stage_key.replace("stage", "")) - 1
            # the per-stage dlossw weights multiply entro_loss, not depth_loss
            total_loss += depth_loss_weights[stage_idx] * entro_loss
        else:
            total_loss += entro_loss

    return total_loss, depth_loss, total_entropy, depth_entropy
```
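The issue does not show `entropy_loss` itself. The following is a rough, hypothetical sketch of a cross-entropy over a depth probability volume. It is written in NumPy rather than PyTorch and is not the repository's implementation. It assumes a one-hot target at the depth hypothesis nearest the ground truth, and it assumes the second return value is the winner-take-all depth. The name `entropy_loss_sketch` is illustrative.

```python
import numpy as np

def entropy_loss_sketch(prob_volume, depth_gt, mask, depth_values):
    """Hypothetical cross-entropy over a depth probability volume (not the repo's code).

    prob_volume:  (D, H, W) softmax probabilities over D depth hypotheses
    depth_gt:     (H, W) ground-truth depth
    mask:         (H, W) boolean validity mask
    depth_values: (D, H, W) depth hypothesis for each bin and pixel
    Returns (mean cross-entropy over valid pixels, winner-take-all depth map).
    """
    # Assumed one-hot target: the hypothesis closest to the ground-truth depth.
    gt_index = np.abs(depth_values - depth_gt[None]).argmin(axis=0)      # (H, W)
    # Probability the network assigns to that ground-truth bin.
    p_true = np.take_along_axis(prob_volume, gt_index[None], axis=0)[0]  # (H, W)
    # With a one-hot target, cross-entropy reduces to -log(p_true); clip avoids log(0).
    ce = -np.log(np.clip(p_true, 1e-6, 1.0))
    loss = ce[mask].mean()
    # Assumed depth estimate: hypothesis with the highest probability per pixel.
    wta_index = prob_volume.argmax(axis=0)
    depth_wta = np.take_along_axis(depth_values, wta_index[None], axis=0)[0]
    return loss, depth_wta
```

Under these assumptions, every valid pixel contributes -log(p) for its ground-truth bin with the same weight. Multiplying the result by a constant such as `entropy_weight` therefore rescales all pixels equally.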

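If I read entropy_loss as a cross-entropy loss, then entropy_weight should play a role similar to the sample weighting in Focal Loss. However, the way they are combined here doesn't look much like Focal Loss. Have I misunderstood something? If you have time to see this comment, I'd be grateful for an explanation. Thank you very much.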
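The sketch below makes the distinction concrete. It is a minimal per-pixel example in pure Python, and the function names are illustrative. It contrasts a constant weight like `entropy_weight` with the standard focal-loss form -(1 - p_t)^γ · log(p_t), where p_t is the probability assigned to the ground-truth depth bin. A constant multiplier rescales every pixel's cross-entropy by the same factor. The focal factor (1 - p_t)^γ instead shrinks the loss of well-classified pixels much more than that of hard ones.

```python
import math

def cross_entropy(p_true: float) -> float:
    """Per-pixel cross-entropy given the probability of the ground-truth depth bin."""
    return -math.log(p_true)

def constant_weighted_ce(p_true: float, weight: float = 2.0) -> float:
    # Same shape as `entro_loss * entropy_weight`: one constant for every pixel.
    return weight * cross_entropy(p_true)

def focal_loss(p_true: float, gamma: float = 2.0) -> float:
    # Standard focal loss (without the alpha term): the modulating factor
    # (1 - p_true) ** gamma down-weights easy, confidently correct pixels.
    return (1.0 - p_true) ** gamma * cross_entropy(p_true)

easy, hard = 0.9, 0.1  # illustrative probabilities of the correct depth bin
ratio_ce = cross_entropy(hard) / cross_entropy(easy)
ratio_const = constant_weighted_ce(hard) / constant_weighted_ce(easy)
ratio_focal = focal_loss(hard) / focal_loss(easy)
```

Under the constant weight, the hard-to-easy ratio `ratio_const` equals the plain cross-entropy ratio, so the relative emphasis between pixels is unchanged. Under the focal factor, the ratio grows sharply. With γ = 0, the focal form reduces to plain cross-entropy.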

maxin0002 commented 1 year ago

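I have the same confusion as the OP. depth_loss_weights here should apply to the L1 loss on the depth values, yet in the end it is the cross-entropy loss that gets added.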
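The two readings can be compared with a small hypothetical helper. The helper name and the loss values are illustrative and do not come from the repository. It applies per-stage weights such as `dlossw` to either loss term. The code as quoted weights `entro_loss`, whereas the commenter expected `depth_loss` to be the weighted term.

```python
def weighted_stage_sum(stage_losses, stage_weights, term):
    """Hypothetical helper: sum one loss term across stages with per-stage weights.

    stage_losses:  {"stage1": {"entro_loss": ..., "depth_loss": ...}, ...}
    stage_weights: per-stage weights, indexed like dlossw (stage1 -> index 0)
    term:          name of the loss term the weights are applied to
    """
    total = 0.0
    for stage_key, losses in stage_losses.items():
        stage_idx = int(stage_key.replace("stage", "")) - 1
        total += stage_weights[stage_idx] * losses[term]
    return total


# Made-up per-stage values, only to illustrate the two accumulations.
stage_losses = {
    "stage1": {"entro_loss": 3.0, "depth_loss": 0.5},
    "stage2": {"entro_loss": 2.0, "depth_loss": 0.25},
}
dlossw = [1.0, 2.0]

as_quoted = weighted_stage_sum(stage_losses, dlossw, "entro_loss")  # 1*3.0 + 2*2.0
expected = weighted_stage_sum(stage_losses, dlossw, "depth_loss")   # 1*0.5 + 2*0.25
```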