Closed — Yi-Qi638 closed this issue 3 years ago.
You can use this implementation as a reference.
from collections import Counter

import numpy as np
import torch
import torch.nn.functional as F
def get_eql_class_weights(lambda_, num_classes=1000,
                          ann_file='datasets/imagenet/annotations/ImageNet_LT_train.txt',
                          verbose=True):
    """Build the per-class "not ignored" weights used by Equalization Loss.

    Reads an annotation file in which every non-empty line is
    ``<image path> <integer label>``, counts the images per class, and marks
    a class as head (weight 1) when it has more than ``lambda_`` images, or
    as tail (weight 0) otherwise. Classes that never appear in the file keep
    weight 0.

    Args:
        lambda_: image-count threshold separating head classes (> lambda_)
            from tail classes (<= lambda_).
        num_classes: length of the returned weight vector.
        ann_file: path of the annotation file to read.
        verbose: if true, print one line per class, most frequent first.

    Returns:
        ``np.ndarray`` of shape ``(num_classes,)`` containing 0.0 / 1.0.
    """
    class_weights = np.zeros(num_classes)
    label_count = Counter()
    with open(ann_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            # rsplit keeps image paths that contain spaces intact; the label
            # must be an int so it can index ``class_weights``.
            _, label = line.rsplit(maxsplit=1)
            label_count[int(label)] += 1
    # Sorting only affects the order of the verbose report.
    for idx, (label, count) in enumerate(sorted(label_count.items(), key=lambda x: -x[1])):
        class_weights[label] = 1 if count > lambda_ else 0
        if verbose:
            print('idx: {}, cls: {} img: {}, weight: {}'.format(idx, label, count, class_weights[label]))
    return class_weights
def replace_masked_values(tensor, mask, replace_with):
    """Blend ``tensor`` with a constant according to ``mask``.

    Positions where ``mask`` is 1 keep their value from ``tensor``; positions
    where ``mask`` is 0 become ``replace_with``. ``mask`` is used
    arithmetically, so it is expected to hold 0/1 values and to have the same
    number of dimensions as ``tensor``.
    """
    assert mask.dim() == tensor.dim(), '{} vs {}'.format(tensor.shape, mask.shape)
    kept = tensor * mask
    filler = replace_with * (1 - mask)
    return kept + filler
class SoftmaxEQL(object):
    """Softmax Equalization Loss for long-tailed classification.

    For each sample, logits of tail classes (other than the ground-truth
    class) are randomly masked out before the softmax cross-entropy, so that
    those classes receive no negative gradient from that sample.

    Args:
        lambda_: image-count threshold forwarded to ``get_eql_class_weights``;
            classes with at most ``lambda_`` training images are tail classes.
        ignore_prob: probability that a tail-class logit is masked out for a
            given sample (i.e. that its negative gradient is dropped).
    """

    # Value written into masked logits. It is clamped to the input dtype's
    # finite minimum in ``__call__`` so fp16 inputs do not overflow to -inf.
    mask_value = -1e7

    def __init__(self, lambda_, ignore_prob):
        self.lambda_ = lambda_
        self.ignore_prob = ignore_prob
        # Fall back to CPU so the loss can be built without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.class_weight = torch.as_tensor(
            get_eql_class_weights(self.lambda_), dtype=torch.float32, device=device)

    def __call__(self, input, target):
        N, C = input.shape
        device = input.device
        # (1, C) head-class mask; broadcasting covers the batch dimension.
        not_ignored = self.class_weight.to(device).view(1, C) > 0
        # Each logit survives with probability 1 - ignore_prob.
        over_prob = torch.rand(input.shape, device=device) > self.ignore_prob
        # The ground-truth logit is never masked.
        is_gt = torch.zeros((N, C), dtype=torch.bool, device=device)
        is_gt[torch.arange(N, device=device), target] = True
        keep = not_ignored | over_prob | is_gt
        # masked_fill overwrites masked logits outright; a multiply-add would
        # turn a non-finite logit into NaN (inf * 0).
        fill = max(self.mask_value, torch.finfo(input.dtype).min)
        logits = input.masked_fill(~keep, fill)
        return F.cross_entropy(logits, target)
@tztztztztz Hi, what do these two parameters, lambda_ and ignore_prob, mean?
@tztztztztz Hi, what do these two parameters, lambda_ and ignore_prob, mean?
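lambda_: the image-count threshold that defines the tail classes (classes with at most lambda_ training images get weight 0). ignore_prob: the probability of dropping the negative gradient of a tail class for a given sample.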
Why do you use -1e7? Why not 0 to remove all tail logits?
Actually, I implemented SEQL according to the paper, but I got a NaN loss during training. Can I refer to the official implementation? Thank you.