tztztztztz / eql.detectron2

The official implementation of Equalization Loss for Long-Tailed Object Recognition (CVPR 2020) based on Detectron2. https://arxiv.org/abs/2003.05176
Apache License 2.0

How to implement Softmax equalization loss #9

Closed Yi-Qi638 closed 3 years ago

Yi-Qi638 commented 4 years ago

Actually, I implemented SEQL myself according to the paper, but I got a NaN loss during training. Can I refer to the official implementation? Thank you.

tztztztztz commented 4 years ago

You can use this implementation as a reference.

import numpy as np
import torch
import torch.nn.functional as F
from collections import Counter

def get_eql_class_weights(lambda_):
    # Weight 1 marks head classes (more than lambda_ training images);
    # weight 0 marks tail classes, whose negative gradients may be dropped.
    class_weights = np.zeros(1000)
    labels = []
    with open('datasets/imagenet/annotations/ImageNet_LT_train.txt', 'r') as f:
        for line in f:
            _, label = line.split()
            labels.append(int(label))
    label_count = Counter(labels)
    for idx, (label, count) in enumerate(sorted(label_count.items(), key=lambda x: -x[1])):
        class_weights[label] = 1 if count > lambda_ else 0
        print('idx: {}, cls: {} img: {}, weight: {}'.format(idx, label, count, class_weights[label]))
    return class_weights

def replace_masked_values(tensor, mask, replace_with):
    # Keep entries where mask == 1; overwrite entries where mask == 0.
    assert tensor.dim() == mask.dim(), '{} vs {}'.format(tensor.shape, mask.shape)
    one_minus_mask = 1 - mask
    values_to_add = replace_with * one_minus_mask
    return tensor * mask + values_to_add

class SoftmaxEQL(object):
    def __init__(self, lambda_, ignore_prob):
        self.lambda_ = lambda_
        self.ignore_prob = ignore_prob
        self.class_weight = torch.Tensor(get_eql_class_weights(self.lambda_)).cuda()

    def __call__(self, input, target):
        N, C = input.shape
        not_ignored = self.class_weight.view(1, C).repeat(N, 1)  # head classes are always kept
        over_prob = (torch.rand(input.shape).cuda() > self.ignore_prob).float()  # randomly keep some tail classes
        is_gt = target.new_zeros((N, C)).float()
        is_gt[torch.arange(N), target] = 1  # the ground-truth class is always kept

        weights = ((not_ignored + over_prob + is_gt) > 0).float()
        input = replace_masked_values(input, weights, -1e7)  # masked logits vanish from the softmax
        loss = F.cross_entropy(input, target)
        return loss
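
A hypothetical usage sketch (the batch size, class count, and hyperparameter values are illustrative assumptions, not values from the repo; it needs a CUDA device and the annotation file read by get_eql_class_weights):

import torch

criterion = SoftmaxEQL(lambda_=100, ignore_prob=0.9)  # illustrative hyperparameters

logits = torch.randn(32, 1000, device='cuda', requires_grad=True)  # batch of 32, 1000 classes
targets = torch.randint(0, 1000, (32,), device='cuda')
loss = criterion(logits, targets)
loss.backward()  # gradients at masked-out logits are exactly zero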
bitwangdan commented 3 years ago

@tztztztztz Hi, what do these two parameters, lambda_ and ignore_prob, mean?

tztztztztz commented 3 years ago

> @tztztztztz Hi, what do these two parameters, lambda_ and ignore_prob, mean?

lambda_: the image-count threshold that defines which classes are treated as tail classes. ignore_prob: the probability of dropping a negative gradient for a tail class.
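
To make that concrete, a small self-contained sketch (the class counts and threshold below are made-up numbers, not ImageNet-LT statistics):

from collections import Counter
import numpy as np

# Toy example: 4 classes with made-up image counts.
labels = [0] * 500 + [1] * 120 + [2] * 30 + [3] * 5
lambda_ = 100  # illustrative threshold

counts = Counter(labels)
weights = np.array([1 if counts[c] > lambda_ else 0 for c in range(4)])
print(weights)  # [1 1 0 0] -> classes 2 and 3 are tail classes whose negative gradients may be dropped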

meijie0401 commented 3 years ago

Why do you use -1e7? Why not 0 to remove all the tail logits?
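
For context on the masking value: a logit of 0 still contributes exp(0) = 1 to the softmax denominator, so zeroing a logit does not remove the class from the normalization, while a very large negative logit contributes essentially nothing. A minimal numeric check (the logit values are arbitrary):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 3.0]])

zero_masked = logits.clone()
zero_masked[0, 2] = 0.0   # class 2 still receives ~9% probability
neg_masked = logits.clone()
neg_masked[0, 2] = -1e7   # class 2's probability is effectively 0

print(F.softmax(zero_masked, dim=1))
print(F.softmax(neg_masked, dim=1))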