dontLoveBugs / DORN_pytorch

PyTorch implementation of Deep Ordinal Regression Network for Monocular Depth Estimation

BerHuber implementation #28

Open jrodriguezpuigvert opened 4 years ago

jrodriguezpuigvert commented 4 years ago

Hi, thank you for your contribution. Regarding the berHu loss you implemented, I am not sure the implementation is correct. I would suggest something like this:

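    import torch
    import torch.nn as nn

    class berHuLoss(nn.Module):
        def __init__(self):
            super(berHuLoss, self).__init__()

        def forward(self, pred, target):
            assert pred.dim() == target.dim(), "inconsistent dimensions"
            # ignore pixels without ground truth (target depth <= 0)
            valid_mask = (target > 0).detach()
            x_abs = (pred - target).abs()
            x_abs = x_abs[valid_mask]
            # threshold c: 20% of the largest absolute residual in the batch
            c = 0.2 * torch.max(x_abs)
            # L1 for residuals <= c, scaled L2 for residuals > c
            loss = torch.where(x_abs > c, (x_abs ** 2 + c ** 2) / (2 * c), x_abs).mean()
            return loss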
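For reference, the reverse Huber (berHu) loss that the snippet computes is usually written as follows, where the maximum is taken over the valid pixels $i$ of the batch (this is the `0.2 * torch.max(x_abs)` threshold in the code):

$$
\mathcal{B}(x) =
\begin{cases}
|x|, & |x| \le c,\\[4pt]
\dfrac{x^2 + c^2}{2c}, & |x| > c,
\end{cases}
\qquad
c = \tfrac{1}{5}\max_i \lvert \tilde{y}_i - y_i \rvert .
$$

At $|x| = c$ both branches give $\frac{c^2 + c^2}{2c} = c$, and both have slope $1$ (since $\frac{d}{dx}\frac{x^2 + c^2}{2c} = \frac{x}{c}$), so the loss is continuous and has a continuous first derivative at the threshold.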

What do you think?
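A minimal sketch of how the proposed module could be hardened against two edge cases while keeping the same forward values otherwise. If every valid residual is zero, `c` becomes 0 and the unused branch of `torch.where` evaluates 0/0, which makes the gradients NaN. If a batch has no valid pixels, `torch.max` raises on the empty tensor. The class name `BerHuLossSafe` and the `threshold`/`eps` parameters are hypothetical, and detaching `c` is a design choice of this sketch, not something stated in the post.

```python
import torch
import torch.nn as nn


class BerHuLossSafe(nn.Module):
    """Hypothetical hardened variant of the berHu loss proposed above.

    Assumptions (not from the original post):
    - the threshold c is detached, so no gradient flows through max();
    - c is clamped to eps to avoid 0/0 when all residuals are zero;
    - an empty valid mask yields a zero loss instead of an error.
    """

    def __init__(self, threshold: float = 0.2, eps: float = 1e-6):
        super().__init__()
        self.threshold = threshold  # fraction of the max residual used as c
        self.eps = eps  # lower bound on c

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        assert pred.shape == target.shape, "inconsistent shapes"
        valid_mask = target > 0  # pixels without ground truth are assumed to be 0
        x_abs = (pred - target).abs()[valid_mask]
        if x_abs.numel() == 0:
            # no valid pixels: return a zero that is still connected to the graph
            return pred.sum() * 0.0
        c = (self.threshold * x_abs.max()).detach().clamp(min=self.eps)
        l2_part = (x_abs ** 2 + c ** 2) / (2 * c)
        # L1 for residuals <= c, scaled L2 for residuals > c
        return torch.where(x_abs > c, l2_part, x_abs).mean()


loss_fn = BerHuLossSafe()
pred = torch.tensor([[1.0, 2.0, 3.0, 0.5]])
target = torch.tensor([[1.0, 1.0, 0.0, 10.5]])  # third pixel has no ground truth
print(loss_fn(pred, target))
```

In the example the valid residuals are 0, 1 and 10, so c = 2; the first two stay in the L1 branch and the last one gives (100 + 4) / 4 = 26, for a mean of 27 / 3 = 9. Whenever c is above `eps`, the forward value matches the `berHuLoss` proposed above.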