MarcoForte / FBA_Matting

Official repository for the paper F, B, Alpha Matting
MIT License

Mistake in the code? #18

Open Twice22 opened 4 years ago

Twice22 commented 4 years ago

Hello!

Thank you for releasing your implementation. However, it looks like fba_fusion doesn't do what you intended. Or am I missing something?

Indeed, before calling the fba_fusion function, you define alpha, F (foreground) and B (background) as follows:

        alpha = torch.clamp(output[:, 0][:, None], 0, 1)

        F = torch.sigmoid(output[:, 1:4])
        B = torch.sigmoid(output[:, 4:7])

        alpha, F, B = fba_fusion(alpha, img, F, B)

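So alpha has size (B, 1, H, W): indexing with [:, 0] drops the channel dimension and [:, None] adds it back as a singleton. F and B each have size (B, 3, H, W). (In these shapes, B is the batch size, not the background tensor.)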
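
These shapes can be confirmed by running the same slicing on a dummy tensor. The sketch below assumes `output` has 7 channels (1 for alpha, 3 for F, 3 for B), as the indexing implies; the batch size and resolution are arbitrary.

```python
import torch

# Dummy network output: batch of 2, 7 channels (alpha, F RGB, B RGB), 4x5 pixels.
# Sizes are arbitrary; only the channel layout follows the indexing above.
output = torch.randn(2, 7, 4, 5)

alpha = torch.clamp(output[:, 0][:, None], 0, 1)  # [:, 0] drops dim 1, [:, None] re-inserts it
F = torch.sigmoid(output[:, 1:4])
B = torch.sigmoid(output[:, 4:7])

print(alpha.shape)  # torch.Size([2, 1, 4, 5])
print(F.shape)      # torch.Size([2, 3, 4, 5])
print(B.shape)      # torch.Size([2, 3, 4, 5])
```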

Now, here is how alpha is computed in the fba_fusion function:

import torch

def fba_fusion(alpha, img, F, B):
    F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
    B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F)

    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)
    la = 0.1
    # torch.sum(..., 1) without keepdim=True returns (B, H, W), not (B, 1, H, W)
    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1)) / (torch.sum((F - B) * (F - B), 1) + la)
    alpha = torch.clamp(alpha, 0, 1)
    return alpha, F, B
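
One way to read the alpha line is as an interpretation, not the authors' stated derivation. At each pixel, with I, F and B as RGB 3-vectors, α as the incoming alpha and λ = la = 0.1, the line is the closed-form minimiser of a compositing residual plus a penalty that pulls the estimate towards α. Setting the derivative with respect to a to zero gives:

```latex
\hat{\alpha} = \arg\min_{a} \; \lVert I - B - a\,(F - B) \rVert^2 + \lambda\,(a - \alpha)^2
\quad\Longrightarrow\quad
\hat{\alpha} = \frac{\lambda\,\alpha + (I - B)\cdot(F - B)}{(F - B)\cdot(F - B) + \lambda}
```

The dot products are sums over the three colour channels, i.e. `torch.sum(..., 1)`, so each pixel gets a single scalar. Keeping that scalar in the same (B, 1, H, W) layout as α is exactly what `keepdim=True` does.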

So, applying PyTorch's broadcasting rules (shapes are aligned from the trailing dimension), we get:

    size = ((B, 1, H, W) * scalar + sum((B, 3, H, W), dim=1)) / (sum((B, 3, H, W), dim=1) + scalar)
         = ((B, 1, H, W) + (B, H, W)) / (B, H, W)
         = ((B, 1, H, W) + (1, B, H, W)) / (B, H, W)
         = (B, B, H, W) / (B, H, W)
         = (B, B, H, W)

So, in the end, alpha has size (B, B, H, W) instead of (B, 1, H, W).
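
The shape argument can be checked numerically. The sketch below isolates the alpha line of fba_fusion in a hypothetical helper, `alpha_update` (not part of the repository), which makes `keepdim` a switch. It feeds the helper random tensors. The last call also suggests why the problem can go unnoticed with a batch of one, e.g. when running inference on one image at a time. In that case (1, 1, H, W) + (1, H, W) broadcasts back to (1, 1, H, W), and the values match the keepdim=True version.

```python
import torch

def alpha_update(alpha, img, F, B, keepdim, la=0.1):
    """Hypothetical helper: only the alpha line of fba_fusion, with keepdim switchable."""
    num = alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=keepdim)
    den = torch.sum((F - B) * (F - B), 1, keepdim=keepdim) + la
    return torch.clamp(num / den, 0, 1)

torch.manual_seed(0)
batch, H, W = 4, 8, 8
alpha = torch.rand(batch, 1, H, W)
img, F, B = (torch.rand(batch, 3, H, W) for _ in range(3))

print(alpha_update(alpha, img, F, B, keepdim=False).shape)  # torch.Size([4, 4, 8, 8])
print(alpha_update(alpha, img, F, B, keepdim=True).shape)   # torch.Size([4, 1, 8, 8])

# With a batch of one, the version without keepdim happens to give the expected shape.
print(alpha_update(alpha[:1], img[:1], F[:1], B[:1], keepdim=False).shape)  # torch.Size([1, 1, 8, 8])
```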

Weren't you supposed to pass keepdim=True to torch.sum? Did your final .pth model use this flawed operation?

I hope you can answer my questions. Thank you!

raphychek commented 4 years ago

Well, there actually is a keepdim=True in the torch.sum calls. In networks/models.py, at line 256, the code reads:

def fba_fusion(alpha, img, F, B):
    F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
    B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F)

    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)
    la = 0.1
    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
    alpha = torch.clamp(alpha, 0, 1)
    return alpha, F, B
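
Plugging the corrected function into the call site from the original report gives one alpha value per pixel. The sketch below uses random inputs. The batch size of 3 and the 16×16 resolution are arbitrary, and `img` is assumed to be an RGB image scaled to [0, 1].

```python
import torch

def fba_fusion(alpha, img, F, B):
    # Corrected version quoted above, reproduced as-is.
    F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
    B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F)
    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)
    la = 0.1
    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
    alpha = torch.clamp(alpha, 0, 1)
    return alpha, F, B

# Dummy 7-channel network output and input image (assumed RGB in [0, 1]).
output = torch.randn(3, 7, 16, 16)
img = torch.rand(3, 3, 16, 16)

alpha = torch.clamp(output[:, 0][:, None], 0, 1)
F = torch.sigmoid(output[:, 1:4])
B = torch.sigmoid(output[:, 4:7])
alpha, F, B = fba_fusion(alpha, img, F, B)

print(alpha.shape, F.shape, B.shape)
# torch.Size([3, 1, 16, 16]) torch.Size([3, 3, 16, 16]) torch.Size([3, 3, 16, 16])
```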

Twice22 commented 4 years ago

Oh, OK. I hadn't seen this because I was working from the implementation as it was before the last commit.

MarcoForte commented 4 years ago

Hi, thanks for your interest and for taking the time to let me know about this issue. As raphychek pointed out, it has already been corrected; see https://github.com/MarcoForte/FBA_Matting/issues/7