Open Twice22 opened 4 years ago
Well, there actually is a `keepdim=True` in the `torch.sum` call. In networks/models.py, on line 256, the code is as follows:
```python
def fba_fusion(alpha, img, F, B):
    F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
    B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F)

    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)
    la = 0.1
    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
    alpha = torch.clamp(alpha, 0, 1)
    return alpha, F, B
```
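To confirm that the corrected version keeps `alpha` at the expected size, here is a minimal shape check; this is a sketch using NumPy as a stand-in for torch (`np.clip`/`np.sum(..., keepdims=True)` mirror `torch.clamp`/`torch.sum(..., keepdim=True)`), with illustrative batch and image sizes:

```python
import numpy as np

def fba_fusion_np(alpha, img, F, B, la=0.1):
    # Same update order as the repo's torch version: F first, then B using the new F,
    # then both clamped to [0, 1].
    F = alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B
    B = (1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F
    F = np.clip(F, 0, 1)
    B = np.clip(B, 0, 1)
    # keepdims=True preserves the reduced channel axis, so alpha stays (batch, 1, H, W)
    alpha = (alpha * la + np.sum((img - B) * (F - B), axis=1, keepdims=True)) / (
        np.sum((F - B) * (F - B), axis=1, keepdims=True) + la)
    return np.clip(alpha, 0, 1), F, B

b, H, W = 2, 8, 8
rng = np.random.default_rng(0)
alpha = rng.random((b, 1, H, W))   # broadcast alpha: (B, 1, H, W)
img = rng.random((b, 3, H, W))     # image:           (B, 3, H, W)
F = rng.random((b, 3, H, W))       # foreground:      (B, 3, H, W)
B = rng.random((b, 3, H, W))       # background:      (B, 3, H, W)

a2, F2, B2 = fba_fusion_np(alpha, img, F, B)
print(a2.shape)  # (2, 1, 8, 8) — channel axis preserved
```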
Oh, ok. I hadn't seen this because I was working on an implementation from before your last commit.
Hi, thanks for your interest and for taking the time to inform me of this issue. As raphychek pointed out, it has already been corrected, see https://github.com/MarcoForte/FBA_Matting/issues/7
Hello!
Thank you for releasing your implementation. However, it looks like `fba_fusion` doesn't do what you intend. Or am I missing something?
Indeed, before calling the `fba_fusion` function, you've defined `alpha`, `fg`, and `bg` so that `alpha` is broadcast to size (B, 1, H, W), while `F` and `B` are each of size (B, 3, H, W).
Now, if we look at how you compute `alpha` in the `fba_fusion` module, applying the broadcasting rules to the `torch.sum` lines (which lack `keepdim=True`) gives:

```
size = ((B, 1, H, W) * scalar + sum((B, 3, H, W), dim=1)) / (sum((B, 3, H, W), dim=1) + scalar)
     = ((B, 1, H, W) + (B, H, W)) / ((B, H, W) + scalar)
     = ((B, 1, H, W) + (1, B, H, W)) / (1, B, H, W)
     = (B, B, H, W) / (B, B, H, W)
     = (B, B, H, W)
```
So, in the end, alpha is of size (B, B, H, W)
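This shape inflation can be reproduced in a few lines; a sketch using NumPy, which follows the same broadcasting rules as PyTorch (batch/height/width sizes here are illustrative):

```python
import numpy as np

b, H, W = 2, 4, 4
alpha = np.ones((b, 1, H, W))   # broadcast alpha: (B, 1, H, W)
F = np.ones((b, 3, H, W))       # foreground:      (B, 3, H, W)
Bg = np.ones((b, 3, H, W))      # background:      (B, 3, H, W)
la = 0.1

# Without keepdims, summing over the channel axis drops it: (B, H, W).
# Broadcasting then right-aligns (B, 1, H, W) with (B, H, W), treating the
# latter as (1, B, H, W), and the batch axis gets duplicated.
num = alpha * la + np.sum(F * Bg, axis=1)
print(num.shape)  # (2, 2, 4, 4) — wrong

# With keepdims=True the channel axis is preserved and shapes stay sane.
num_ok = alpha * la + np.sum(F * Bg, axis=1, keepdims=True)
print(num_ok.shape)  # (2, 1, 4, 4) — correct
```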
Weren't you supposed to add `keepdim=True` in `torch.sum`? And was your final .pth model trained with this flawed operation?

Hope you can answer my questions. Thank you!