Hello,

First of all, cool work! :)

Now let me get to the point. I found the following bug in your code:

```python
if mask is not None:
    # Here is the problem: flat_input and flat_target are already one-hot
    # encoded, i.e. shape (N, C), while mask has shape (N,). Broadcasting
    # aligns mask with the last (class) dimension, so the multiplication
    # fails with a shape mismatch (or, if N == C, masks the wrong axis).
    mask = mask.float()
    flat_input = flat_input * mask
    flat_target = flat_target * mask
else:
    mask = torch.ones_like(target)
```

An easy fix is the following:

```python
if mask is not None:
    mask = mask.float()
    # Transpose so mask lines up with the sample dimension, then transpose back.
    flat_input = (flat_input.t() * mask).t()
    flat_target = (flat_target.t() * mask).t()
else:
    mask = torch.ones_like(target)
```
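To make the shape problem concrete, here is a minimal standalone sketch. The shapes (N = 4 samples, C = 3 classes), the use of `F.one_hot` to build the one-hot targets, and the tensor values are assumptions for illustration only, not taken from the repository's code:

```python
import torch
import torch.nn.functional as F

# Assumed shapes for illustration: N = 4 samples, C = 3 classes.
N, C = 4, 3
target = torch.tensor([0, 2, 1, 2])
flat_target = F.one_hot(target, num_classes=C).float()  # shape (N, C)
mask = torch.tensor([1, 0, 1, 1]).float()               # shape (N,)

# Buggy version: mask (N,) is broadcast against the last dim (C),
# so this raises a RuntimeError because N != C.
try:
    flat_target * mask
    buggy_failed = False
except RuntimeError:
    buggy_failed = True

# Proposed fix: transpose so mask lines up with the sample dimension.
fixed_transpose = (flat_target.t() * mask).t()

# Equivalent alternative: give mask a trailing singleton dim, shape (N, 1).
fixed_unsqueeze = flat_target * mask.unsqueeze(1)

print(buggy_failed)                                   # True
print(torch.equal(fixed_transpose, fixed_unsqueeze))  # True
print(fixed_transpose[1].tolist())                    # [0.0, 0.0, 0.0]
```

`mask.unsqueeze(1)` (or `mask[:, None]`) gives the same result as the double transpose and avoids the extra views, but either fix zeroes out exactly the masked samples.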