naoto0804 / pytorch-inpainting-with-partial-conv

Unofficial pytorch implementation of 'Image Inpainting for Irregular Holes Using Partial Convolutions' [Liu+, ECCV2018]
MIT License

Is there a misunderstanding in the total variation loss? #27

Open GuardSkill opened 6 years ago

GuardSkill commented 6 years ago

I find that this does not conform to the original paper's method. I think the loss (L_tv) should use the sum of the absolute values, and the TV loss is not the difference over the whole picture: it is computed only around the hole areas (P is the region of 1-pixel dilation of the hole region).
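For reference, here is the TV term as I read it in the paper. P is the 1-pixel dilation of the hole region and N_{I_comp} is the number of elements in I_comp; this is my transcription, so the normalization is worth checking against the paper itself:

```latex
L_{tv} = \sum_{(i,j)\in P,\,(i,j+1)\in P} \frac{\lVert I_{comp}^{i,j+1} - I_{comp}^{i,j} \rVert_1}{N_{I_{comp}}}
       + \sum_{(i,j)\in P,\,(i+1,j)\in P} \frac{\lVert I_{comp}^{i+1,j} - I_{comp}^{i,j} \rVert_1}{N_{I_{comp}}}
```

Only neighbouring pairs whose two pixels both lie in P contribute, which is what the masked versions below try to reproduce.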

```python
import torch

def total_variation_loss(image):
    # shift one pixel and get difference (for both x and y direction)
    loss = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
        torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))
    return loss
```
GuardSkill commented 6 years ago

Maybe it should be this?

```python
import torch

def total_variation_loss(image, mask):
    hole_mask = 1 - mask  # mask: 1 = valid pixel, 0 = hole
    loss = torch.sum(torch.abs(hole_mask[:, :, :, :-1] * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
        torch.sum(hole_mask[:, :, :-1, :] * torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))
    return loss
```

GuardSkill commented 6 years ago

More rigorously, it should be the following code rather than the code above (the code above ignores the differences involving the topmost/leftmost pixels of the 1-pixel dilated border):

```python
import torch
import torch.nn as nn

def dialation_holes(hole_mask):
    # Dilate the hole region by one pixel: a 3x3 all-ones convolution is
    # non-zero wherever any pixel in the 3x3 neighbourhood is a hole.
    b, ch, h, w = hole_mask.shape
    dilation_conv = nn.Conv2d(ch, ch, 3, padding=1, bias=False).to(hole_mask.device)
    torch.nn.init.constant_(dilation_conv.weight, 1.0)
    with torch.no_grad():
        output_mask = dilation_conv(hole_mask)
    updated_holes = output_mask != 0
    return updated_holes.float()

def total_variation_loss(image, mask):
    hole_mask = 1 - mask  # mask: 1 = valid pixel, 0 = hole
    dilated_holes = dialation_holes(hole_mask)
    # keep only neighbouring pairs whose two pixels both lie in P
    colomns_in_Pset = dilated_holes[:, :, :, 1:] * dilated_holes[:, :, :, :-1]
    rows_in_Pset = dilated_holes[:, :, 1:, :] * dilated_holes[:, :, :-1, :]
    loss = torch.sum(torch.abs(colomns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
        torch.sum(torch.abs(rows_in_Pset * (image[:, :, 1:, :] - image[:, :, :-1, :])))
    return loss
```
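The convolution-plus-`!= 0` trick in `dialation_holes` is a binary dilation by one pixel. A sketch, assuming the hole mask is identical in every channel, showing that 3x3 max-pooling gives the same mask without building an `nn.Conv2d` on every call. `dilate_conv` and `dilate_maxpool` are hypothetical helper names, not part of the repository:

```python
import torch
import torch.nn.functional as F

def dilate_conv(hole_mask):
    # all-ones 3x3 convolution summed over channels, as in dialation_holes
    ch = hole_mask.shape[1]
    weight = torch.ones(ch, ch, 3, 3, dtype=hole_mask.dtype, device=hole_mask.device)
    return (F.conv2d(hole_mask, weight, padding=1) != 0).float()

def dilate_maxpool(hole_mask):
    # a pixel is set if any pixel in its 3x3 neighbourhood is a hole
    return F.max_pool2d(hole_mask, kernel_size=3, stride=1, padding=1)

torch.manual_seed(0)
hole = (torch.rand(1, 1, 10, 10) > 0.9).float()
hole = hole.expand(1, 3, 10, 10).contiguous()  # same hole in every channel
print(torch.equal(dilate_conv(hole), dilate_maxpool(hole)))  # prints True
```

If the mask differed per channel, the two would diverge: the convolution mixes channels (union of holes), while max-pooling dilates each channel independently.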
Daisy007girl commented 5 years ago

Have you tried the code you think it should be? Has it brought any improvement to the results compared with the repository author's version?

Xavier31 commented 3 years ago

Hi! @GuardSkill, shouldn't it be mean instead of sum in the total_variation_loss function?

```python
loss = torch.mean(torch.abs(colomns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
    torch.mean(torch.abs(rows_in_Pset * (image[:, :, 1:, :] - image[:, :, :-1, :])))
```
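On the mean-vs-sum question: `torch.mean` on the masked difference still divides by the number of all neighbouring pairs, not by the number of pairs in P. So on a square image the two reductions differ only by a constant factor, which amounts to rescaling the TV loss weight. A small sketch; `masked_tv` is a hypothetical helper and the random `p` is only a stand-in for the dilated hole mask:

```python
import torch

def masked_tv(image, p, reduce):
    # TV restricted to pixel pairs whose two pixels both lie in region p
    cols = p[:, :, :, 1:] * p[:, :, :, :-1]
    rows = p[:, :, 1:, :] * p[:, :, :-1, :]
    dx = torch.abs(cols * (image[:, :, :, 1:] - image[:, :, :, :-1]))
    dy = torch.abs(rows * (image[:, :, 1:, :] - image[:, :, :-1, :]))
    return reduce(dx) + reduce(dy)

torch.manual_seed(0)
image = torch.rand(2, 3, 8, 8)
p = (torch.rand(2, 3, 8, 8) > 0.5).float()  # arbitrary stand-in for P

loss_sum = masked_tv(image, p, torch.sum)
loss_mean = masked_tv(image, p, torch.mean)
n = image[:, :, :, 1:].numel()  # 2*3*8*7 = 336, same count for the vertical term
print(torch.allclose(loss_mean * n, loss_sum))  # prints True
```

On non-square images the horizontal and vertical terms are divided by slightly different counts, so their relative weights shift a little.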

fgiobergia commented 2 years ago

I would argue that, while it is not exactly the same loss as the one proposed in the paper (L_tv), the total_variation_loss() implemented here should behave in exactly the same way.

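Both total_variation_loss() and L_tv are computed on I_comp (output_comp in the code) rather than on I_out (output). I_comp is I_out with the non-hole pixels set to the ground truth, i.e. I_comp = M * I_gt + (1 - M) * I_out (element-wise, with M = 1 on valid pixels).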

Since the ground-truth pixels do not change with I_out, every pair of neighbouring pixels lying entirely outside the hole always contributes the same total variation. Only pairs with at least one pixel inside the hole depend on I_out, and both pixels of such a pair always lie in P (the hole plus its 1-pixel dilation).

As such, the loss implemented here is L_tv + constant: the two functions thus share the same gradient.
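A quick numerical check of this argument, as a sketch: it compares the gradient with respect to the raw output of the repository's global TV and of a P-restricted TV, both reduced with mean. The names `tv_global` and `tv_dilated`, the max-pooling dilation (used instead of the all-ones convolution), the mask convention (1 = valid, 0 = hole) and the hole layout are all assumptions for illustration:

```python
import torch
import torch.nn.functional as F

def tv_global(image):
    # repository-style TV: mean absolute difference over all neighbouring pairs
    return torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
        torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))

def tv_dilated(image, mask):
    # paper-style TV: only pairs whose two pixels lie in P (1-pixel dilated hole)
    p = F.max_pool2d(1 - mask, kernel_size=3, stride=1, padding=1)
    cols = p[:, :, :, 1:] * p[:, :, :, :-1]
    rows = p[:, :, 1:, :] * p[:, :, :-1, :]
    return torch.mean(torch.abs(cols * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
        torch.mean(torch.abs(rows * (image[:, :, 1:, :] - image[:, :, :-1, :])))

torch.manual_seed(0)
gt = torch.rand(1, 3, 16, 16)
mask = torch.ones(1, 3, 16, 16)
mask[:, :, 4:9, 6:12] = 0  # rectangular hole
output = torch.rand(1, 3, 16, 16, requires_grad=True)

grads = []
for loss_fn in (tv_global, lambda x: tv_dilated(x, mask)):
    output.grad = None
    comp = mask * gt + (1 - mask) * output  # I_comp
    loss_fn(comp).backward()
    grads.append(output.grad.clone())

print(torch.allclose(grads[0], grads[1]))  # prints True
```

The two loss values differ, but only by the TV of pairs outside P, which does not depend on `output`.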

It also seems to me that the current implementation is slightly more efficient, as it requires neither computing the dilated mask nor masking the image.