vandit15 / Class-balanced-loss-pytorch

PyTorch implementation of the paper "Class-Balanced Loss Based on Effective Number of Samples"
MIT License
781 stars · 120 forks

why sum weights? #15

Open mmxuan18 opened 3 years ago

mmxuan18 commented 3 years ago

Why add this sum? It makes all the class weights the same. (image attached)

mjkvaak commented 1 year ago

The construct above, up to the highlighted line, selects the weight for each label. For example, if your target label is (one-hot encoded) class 1, those lines pick out the weight at that index. One could simplify the lines

weights = torch.tensor(weights).float()
weights = weights.unsqueeze(0)
weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot
weights = weights.sum(1)

by simply saying

weights = weights[labels_one_hot.argmax(1)]  # pick the correct weight for each label
# weights = weights[labels.long()]  # this would also do
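
Both forms give the same result for one-hot targets: each row of labels_one_hot contains exactly one nonzero entry, so argmax(1) simply recovers the original label index.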

Here's a snippet you can use to verify:

import torch

# gen dummy labels
num_classes = 5
labels = torch.randint(num_classes, (100, ))
labels_one_hot = torch.eye(num_classes)[labels]
# gen dummy weights
weights = torch.randn(num_classes)
# original method
weights0 = weights.unsqueeze(0)
weights0 = weights0.repeat(labels_one_hot.shape[0],1) * labels_one_hot
weights0 = weights0.sum(1)
# method 1
weights1 = weights[labels_one_hot.argmax(1)]
# method 2
weights2 = weights[labels.long()]
print(torch.equal(weights0, weights1), torch.equal(weights0, weights2))
>> True True
GYDDHPY commented 2 months ago

This seems only useful for multi-class (single-label) problems.
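
In particular, the argmax shortcut assumes exactly one active class per sample. With multi-label (multi-hot) targets, the original repeat/sum construct and the shortcut diverge, so only the former generalizes. A minimal sketch (the multi-hot labels below are just an illustration, not from the repo):

import torch

num_classes = 5
weights = torch.randn(num_classes)

# two samples: the first has active classes {0, 2}, the second only {4}
labels_multi_hot = torch.tensor([[1., 0., 1., 0., 0.],
                                 [0., 0., 0., 0., 1.]])

# original construct: broadcasting replaces the explicit repeat;
# summing over dim 1 adds up the weights of all active classes
per_sample = (weights.unsqueeze(0) * labels_multi_hot).sum(1)

# argmax shortcut: keeps only one class per sample (here it drops class 2)
shortcut = weights[labels_multi_hot.argmax(1)]

print(per_sample)  # tensor([weights[0] + weights[2], weights[4]])
print(shortcut)    # tensor([weights[0], weights[4]])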