class CriterionPixelWise(nn.Module):
    """Pixel-wise knowledge-distillation loss.

    Computes the cross-entropy between the teacher's softmax distribution
    and the student's log-softmax, summed over all pixels and classes and
    divided by W*H (NOTE: not by N — the loss scales with batch size, which
    matches the original implementation and is preserved here).
    """

    def __init__(self, ignore_index=255, use_weight=True, reduce=True):
        super(CriterionPixelWise, self).__init__()
        self.ignore_index = ignore_index
        # NOTE(review): self.criterion is never used in forward(); kept for
        # interface compatibility. `reduce=` is deprecated in modern PyTorch
        # in favour of `reduction=` — TODO migrate when callers allow.
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduce=reduce)
        if not reduce:
            print("disabled the reduce.")

    def forward(self, preds_S, preds_T):
        """Return the distillation loss between student and teacher logits.

        Args:
            preds_S: sequence whose first element is the student logits,
                shape (N, C, W, H) — presumably (N, C, H, W) by PyTorch
                convention; only the product W*H matters here.
            preds_T: sequence whose first element is the teacher logits,
                same shape as the student's.
        """
        # BUG FIX: detach() is not in-place — the original discarded its
        # result, so gradients leaked into the teacher. Bind and use it.
        target = preds_T[0].detach()
        assert preds_S[0].shape == target.shape, 'the output dim of teacher and student differ'
        N, C, W, H = preds_S[0].shape
        # Flatten to (N*W*H, C); dim=1 is then the class axis, so softmax
        # over dim=1 is correct (dim=1 == dim=-1 for a 2-D tensor).
        softmax_pred_T = F.softmax(target.permute(0, 2, 3, 1).contiguous().view(-1, C), dim=1)
        logsoftmax = nn.LogSoftmax(dim=1)
        log_pred_S = logsoftmax(preds_S[0].permute(0, 2, 3, 1).contiguous().view(-1, C))
        loss = (torch.sum(-softmax_pred_T * log_pred_S)) / W / H
        return loss
Since you permuted the NxCxWxH tensor with (0,2,3,1), you get NxWxHxC, and view(-1,C) then gives an NWHxC tensor. That result is 2-D, so dim=1 already refers to the last (class) dimension — dim=1 and dim=-1 are equivalent here, and the softmax/LogSoftmax computation is correct as written. Using dim=-1 would merely be more explicit.