The trimap values are not constrained to 0/1/2. Should I apply torch.argmax() to the trimap?
If I apply torch.argmax() to trimap, the code for net_M makes sense:
```python
# Why are the bg/fg/unsure values not constrained to 0/1?
bg, fg, unsure = torch.split(trimap, 1, dim=1)
# Why are the trimap values not constrained to 0/1/2?
m_net_input = torch.cat((input, trimap), 1)
alpha_r = self.m_net(m_net_input)
alpha_p = fg + unsure * alpha_r
```
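To make the question concrete, here is a minimal sketch of what the fusion computes if torch.argmax() is applied first. It assumes trimap holds raw scores of shape (N, 3, H, W) with channels in the bg/fg/unsure order used by torch.split above, and alpha_r has shape (N, 1, H, W). The helper name fuse_with_hard_trimap is hypothetical. With hard one-hot masks, alpha_p is exactly 0 in the background, 1 in the foreground, and alpha_r in the unsure region.

```python
import torch
import torch.nn.functional as F


def fuse_with_hard_trimap(trimap_logits: torch.Tensor, alpha_r: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: assumes channel order bg, fg, unsure (as in torch.split above).
    labels = torch.argmax(trimap_logits, dim=1)              # (N, H, W), values in {0, 1, 2}
    one_hot = F.one_hot(labels, num_classes=3)               # (N, H, W, 3), values in {0, 1}
    one_hot = one_hot.permute(0, 3, 1, 2).to(alpha_r.dtype)  # back to (N, 3, H, W)
    bg, fg, unsure = torch.split(one_hot, 1, dim=1)
    # bg pixels -> 0, fg pixels -> 1, unsure pixels -> alpha_r
    return fg + unsure * alpha_r


# One row of three pixels: the 1st scores highest on bg, the 2nd on fg, the 3rd on unsure.
logits = torch.tensor([[[[5.0, 0.0, 0.0]],    # bg channel
                        [[0.0, 5.0, 0.0]],    # fg channel
                        [[0.0, 0.0, 5.0]]]])  # unsure channel
alpha_r = torch.full((1, 1, 1, 3), 0.3)       # stand-in for the m_net output

alpha_p = fuse_with_hard_trimap(logits, alpha_r)
# pixel values: 0.0 (bg), 1.0 (fg), 0.3 (unsure, taken from alpha_r)
```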
The snippet above is from network.py. Is my understanding correct?
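For comparison, a softmax over the channel dimension is another way to bound the values (an assumption here, not something the snippet above does): each of bg/fg/unsure then lies in [0, 1] and they sum to 1 per pixel, so the fusion behaves like a soft version of the argmax case. Unlike torch.argmax(), softmax is differentiable, so gradients from alpha_p can still reach whatever network produced trimap. The sketch below, with the hypothetical helper fuse_with_soft_trimap and random stand-in tensors, checks both properties.

```python
import torch


def fuse_with_soft_trimap(trimap_logits: torch.Tensor, alpha_r: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: softmax over the channel dim makes, per pixel,
    # bg + fg + unsure == 1 with each value in [0, 1].
    probs = torch.softmax(trimap_logits, dim=1)
    bg, fg, unsure = torch.split(probs, 1, dim=1)
    return fg + unsure * alpha_r


torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4, requires_grad=True)  # stand-in for trimap scores
alpha_r = torch.rand(2, 1, 4, 4)                       # stand-in for m_net output in [0, 1)

soft = fuse_with_soft_trimap(logits, alpha_r)
soft.sum().backward()
print(logits.grad is not None)    # True: gradients reach the trimap scores

hard_labels = torch.argmax(logits, dim=1)
print(hard_labels.requires_grad)  # False: argmax returns an integer tensor with no gradient
```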