GuardSkill opened this issue 6 years ago.
Maybe it should be this?

```python
import torch

def total_variation_loss(image, mask):
    hole_mask = 1 - mask
    # Weight each 1-pixel difference by the hole mask so only hole pixels contribute.
    loss = torch.sum(torch.abs(hole_mask[:, :, :, :-1] * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
        torch.sum(torch.abs(hole_mask[:, :, :-1, :] * (image[:, :, :-1, :] - image[:, :, 1:, :])))
    return loss
```
More importantly, it should be the following code rather than the above, because the above code does not account for the differences at the topmost/leftmost pixels of the dilated region:
```python
import torch
import torch.nn.functional as F

def dialation_holes(hole_mask):
    # 1-pixel (3x3) dilation of the hole mask (1 = hole), applied per channel.
    ch = hole_mask.shape[1]
    weight = torch.ones(ch, 1, 3, 3, device=hole_mask.device, dtype=hole_mask.dtype)
    with torch.no_grad():
        output_mask = F.conv2d(hole_mask, weight, padding=1, groups=ch)
    updated_holes = output_mask != 0
    return updated_holes.float()

def total_variation_loss(image, mask):
    hole_mask = 1 - mask
    dilated_holes = dialation_holes(hole_mask)
    # A 1-pixel pair contributes only if both of its pixels lie in P (the dilated holes).
    colomns_in_Pset = dilated_holes[:, :, :, 1:] * dilated_holes[:, :, :, :-1]
    rows_in_Pset = dilated_holes[:, :, 1:, :] * dilated_holes[:, :, :-1, :]
    loss = torch.sum(torch.abs(colomns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
        torch.sum(torch.abs(rows_in_Pset * (image[:, :, 1:, :] - image[:, :, :-1, :])))
    return loss
```
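The difference between the two proposals shows up on a single-row toy mask with one hole pixel. This sketch uses a 3×3 max pool as the binary dilation (an assumption standing in for the convolution above) and lists which horizontal pairs `(j, j+1)` each version weights. The first version misses the pair whose left pixel is known and whose right pixel is the hole.

```python
import torch
import torch.nn.functional as F

# 1 = known pixel, 0 = hole (same convention as `mask` in the thread).
mask = torch.tensor([[[[1., 1., 0., 1., 1.]]]])  # shape (1, 1, 1, 5)
hole = 1 - mask

# First proposal: horizontal pair (j, j+1) is weighted by hole[j] only.
first_weights = hole[..., :-1]

# Second proposal: dilate the hole by one pixel (max pool as dilation, an assumption),
# then keep pairs whose two pixels both lie in the dilated region P.
dilated = (F.max_pool2d(hole, kernel_size=3, stride=1, padding=1) > 0).float()
pair_weights = dilated[..., 1:] * dilated[..., :-1]

print(first_weights.flatten().tolist())  # [0.0, 0.0, 1.0, 0.0]
print(pair_weights.flatten().tolist())   # [0.0, 1.0, 1.0, 0.0]
```

Pair `(1, 2)`, from the last known pixel into the hole, is counted only by the dilated version.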
Have you tried the code you proposed? Has it brought any improvement to the results compared with the repository author's version?
Hi @GuardSkill, shouldn't it be `mean` instead of `sum` in the `total_variation_loss` function?
```python
loss = torch.mean(torch.abs(colomns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
    torch.mean(torch.abs(rows_in_Pset * (image[:, :, 1:, :] - image[:, :, :-1, :])))
```
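Switching between `sum` and `mean` mainly changes the scale of the loss. The sketch below uses random stand-in tensors (hypothetical, not real image data) to show that `torch.mean` equals each term's sum divided by that term's element count. That count includes positions zeroed by the mask, and the horizontal and vertical terms get slightly different factors.

```python
import torch

torch.manual_seed(0)
b, c, h, w = 1, 3, 8, 10
# Hypothetical stand-ins for the masked absolute differences in the code above:
# horizontal terms have shape (b, c, h, w - 1), vertical terms (b, c, h - 1, w).
col_terms = torch.rand(b, c, h, w - 1)
row_terms = torch.rand(b, c, h - 1, w)

mean_loss = col_terms.mean() + row_terms.mean()

# torch.mean divides each term by its own element count (masked-out zeros included),
# so for a fixed image size it is a constant rescaling of each summed term.
rescaled = col_terms.sum() / col_terms.numel() + row_terms.sum() / row_terms.numel()
print(col_terms.numel(), row_terms.numel())  # 216 210
print(torch.allclose(mean_loss, rescaled))   # True
```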
I would argue that, while not exactly the same loss as the one proposed in the paper (L_tv), the `total_variation_loss()` implemented here should behave in the same way.

Both `total_variation_loss()` and L_tv are computed on I_comp (`output_comp` in the code) and not on I_out (`output`). I_comp contains:

- I_out in the masked (hole) region, and
- the ground truth I_gt (`input`) in the unmasked part.

Since the ground truth does not change with I_out, any 1-pixel pair whose two pixels are both outside the hole always contributes the same amount to the total variation. Only pairs with at least one pixel inside the hole depend on I_out, and all such pairs lie within the 1-pixel dilation P of the mask.
As such, the loss implemented here is L_tv + constant, so the two functions share the same gradient with respect to I_out.
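The argument can be checked numerically. The sketch below assumes `output_comp = mask * input + (1 - mask) * output` with `mask` = 1 on known pixels, builds P with a 3×3 max pool (standing in for the convolution-based dilation earlier in the thread), and uses hypothetical helpers `tv_full` and `tv_in_P`. For several random outputs, the gap between the whole-image TV and the TV restricted to P should stay the same.

```python
import torch
import torch.nn.functional as F

def tv_full(img):
    # Sum of absolute 1-pixel differences over the whole image.
    return (img[..., :, 1:] - img[..., :, :-1]).abs().sum() + \
           (img[..., 1:, :] - img[..., :-1, :]).abs().sum()

def tv_in_P(img, mask):
    # Same sum restricted to pairs whose two pixels both lie in P,
    # the 1-pixel (3x3) dilation of the hole region (mask: 1 = known, 0 = hole).
    hole = 1 - mask
    P = (F.max_pool2d(hole, 3, stride=1, padding=1) > 0).float()
    cols = P[..., :, 1:] * P[..., :, :-1]
    rows = P[..., 1:, :] * P[..., :-1, :]
    return (cols * (img[..., :, 1:] - img[..., :, :-1])).abs().sum() + \
           (rows * (img[..., 1:, :] - img[..., :-1, :])).abs().sum()

torch.manual_seed(0)
gt = torch.rand(1, 3, 16, 16)
mask = torch.ones(1, 3, 16, 16)
mask[..., 5:10, 4:12] = 0  # rectangular hole

gaps = []
for _ in range(3):
    output = torch.rand(1, 3, 16, 16)       # a different (random) network output each time
    comp = mask * gt + (1 - mask) * output  # I_comp: ground truth outside the hole
    gaps.append((tv_full(comp) - tv_in_P(comp, mask)).item())

print(gaps)  # the three gaps agree up to float rounding
```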
It also seems to me that the current implementation is slightly more efficient, as it requires neither computing the dilated mask nor masking the image.
I find that this does not conform to the original paper's method. I think L_tv should be the sum of the absolute values, and the TV loss is not the global difference over the whole picture: it is computed only around the hole areas (P is the region of 1-pixel dilation of the hole region).
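Written out as a formula, the definition described here (the sum of absolute 1-pixel differences of I_comp, restricted to pairs whose two pixels both lie in P) reads as follows. Any normalization constant the paper applies is omitted in this sketch.

```latex
L_{tv} = \sum_{(i,j)\in P,\;(i,j+1)\in P} \left\| I_{comp}^{i,j+1} - I_{comp}^{i,j} \right\|_1
       + \sum_{(i,j)\in P,\;(i+1,j)\in P} \left\| I_{comp}^{i+1,j} - I_{comp}^{i,j} \right\|_1
```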