MarcoForte / FBA_Matting

Official repository for the paper F, B, Alpha Matting
MIT License

A bug in the code, influencing the training when batch size > 1 #7

Closed: xymsh closed this issue 4 years ago

xymsh commented 4 years ago

https://github.com/MarcoForte/FBA-Matting/blob/76751dd752d4d3b40bf58c64185fc77c0195cbeb/networks/models.py#L263

Hi, I think you forgot to set the "keepdim" parameter to True in the torch.sum() operations. The correct line should be:

    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)

Without keepdim, torch.sum drops the channel dimension, so the two summed terms have shape [batch, height, width]; broadcasting them against alpha's [batch, 1, height, width] then produces an output alpha of size [batch, batch, height, width], while it is supposed to be [batch, 1, height, width].
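A minimal sketch of the wrong broadcast (the tensor sizes here are arbitrary, chosen just for illustration, and la stands in for the regularisation weight):

```python
import torch

batch, C, H, W = 4, 3, 8, 8         # hypothetical sizes
alpha = torch.rand(batch, 1, H, W)  # predicted alpha
img = torch.rand(batch, C, H, W)    # input image
F = torch.rand(batch, C, H, W)      # predicted foreground
B = torch.rand(batch, C, H, W)      # predicted background
la = 0.1                            # regularisation weight

# Without keepdim the channel sums have shape [4, 8, 8]; broadcasting them
# against alpha's [4, 1, 8, 8] inflates the result to [4, 4, 8, 8].
bad = (alpha * la + torch.sum((img - B) * (F - B), 1)) / (
    torch.sum((F - B) * (F - B), 1) + la)
print(bad.shape)   # torch.Size([4, 4, 8, 8])

# With keepdim=True the sums stay [4, 1, 8, 8] and the shape is preserved.
good = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (
    torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
print(good.shape)  # torch.Size([4, 1, 8, 8])
```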

When the batch size is > 1, the size of the alpha prediction is therefore incorrect, and the loss calculation is negatively affected because the alpha, fg, and bg predictions are concatenated together in the last step of the forward pass. https://github.com/MarcoForte/FBA-Matting/blob/76751dd752d4d3b40bf58c64185fc77c0195cbeb/networks/models.py#L361

For example, set the batch size to 4. With this bug, we get an alpha prediction of size [4, 4, height, width], while the fg and bg predictions are each of size [4, 3, height, width]. After concatenation, we get an output of size [4, 10, height, width] instead of the expected [4, 7, height, width], as the sketch below shows.
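A short sketch of the concatenation, with the same shapes as above (the (alpha, fg, bg) channel order is assumed from the slicing indices further down):

```python
import torch

H, W = 8, 8
alpha_buggy = torch.rand(4, 4, H, W)  # wrongly broadcast alpha prediction
fg = torch.rand(4, 3, H, W)           # fg prediction
bg = torch.rand(4, 3, H, W)           # bg prediction

output = torch.cat((alpha_buggy, fg, bg), 1)
print(output.shape)  # torch.Size([4, 10, 8, 8]) instead of [4, 7, 8, 8]
```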

When calculating the loss, we slice the output by channel indices to extract the alpha, fg, and bg predictions:

    alpha_pred = output[:, 0:1, :, :]
    fg_pred = output[:, 1:4, :, :]
    bg_pred = output[:, 4:7, :, :]

Here fg_pred is actually part of the alpha prediction, because the first 4 channels of the output hold the wrongly broadcast alpha instead of only 1 channel; bg_pred is misaligned in the same way, picking up the actual fg channels. The losses computed for the fg and bg predictions are therefore meaningless.
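A self-contained sketch of the misaligned slicing under that assumed channel layout:

```python
import torch

# [4, 10, 8, 8] output from the buggy concatenation above
output = torch.cat((torch.rand(4, 4, 8, 8),      # buggy alpha: channels 0-3
                    torch.rand(4, 3, 8, 8),      # fg: channels 4-6
                    torch.rand(4, 3, 8, 8)), 1)  # bg: channels 7-9

alpha_pred = output[:, 0:1, :, :]  # only the first of the 4 alpha channels
fg_pred = output[:, 1:4, :, :]     # actually alpha channels 1-3, not fg
bg_pred = output[:, 4:7, :, :]     # actually the fg prediction, not bg
```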

My experimental results for models trained with batch size > 1 confirm this bug: the errors are extremely high. I'm wondering whether it had a negative influence on your experiments.

MarcoForte commented 4 years ago

Hi, thanks for pointing this out; I will update the code to your implementation. I did not use the FBA fusion for my models with batch size above one, and I also did not use it during training, so my way of coding it should not have influenced the results negatively.