megvii-research / BBN

The official PyTorch implementation of the paper "BBN: Bilateral-Branch Network with Cumulative Learning for Long-Tailed Visual Recognition"
https://arxiv.org/abs/1912.02413
MIT License

Confusions about reverse sampler #13

Open Hwang64 opened 4 years ago

Hwang64 commented 4 years ago

The second step of the data sampler is "Randomly sample a class according to Pi". If the class is sampled uniformly at random, there seems to be no point in calculating Pi for each category. Judging from the code at https://github.com/Megvii-Nanjing/BBN/blob/7992e908842f5934f0d1ee3f430d796621e81975/lib/dataset/imbalance_cifar.py#L59, each category appears to have an equal probability of being selected for training, which cannot be described as "reverse sampling". Am I misunderstanding something? Thanks for your reply.
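For reference, the paper describes the reversed sampler as: compute w_i = N_max / N_i from the per-class counts, normalize to P_i = w_i / sum_j w_j, sample a class according to P_i, then uniformly pick an instance from that class. A minimal sketch of that rule (the function name and the use of random.choices are illustrative, not the repository's implementation):

import random

def reverse_sample(class_counts, per_class_indices, rng=random):
    # w_i = N_max / N_i: the rarest class gets the largest weight
    weights = [max(class_counts) / n for n in class_counts]
    total = sum(weights)
    # P_i = w_i / sum_j w_j: normalized reversed sampling probability
    probs = [w / total for w in weights]
    # step 2: sample a class according to P_i (not uniformly)
    cls = rng.choices(range(len(class_counts)), weights=probs, k=1)[0]
    # step 3: uniformly sample an instance index from the chosen class
    return rng.choice(per_class_indices[cls])

# toy usage: class 2 (50 samples) should be drawn most often
counts = [500, 200, 50]
indices = [list(range(500)), list(range(500, 700)), list(range(700, 750))]
print(reverse_sample(counts, indices))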

John-Yao commented 3 years ago

@ZhouBoyan @Hwang64 Hi, I also noticed that self.cfg.TRAIN.SAMPLER.DUAL_SAMPLER.TYPE == "reverse" cannot be described as reverse sampling. Here is some code that shows the frequency of each class after "reverse sampling". Could you point out what we are missing? Thanks!

import random
import numpy as np

num_classes = 10
class_nsamples = [100 + int(random.uniform(-1, 1) * 30) for i in range(num_classes)]
class_weights = [np.max(class_nsamples)/class_nsamples[i] for i in range(num_classes)]
# note: sorting breaks the correspondence between each class and its weight
class_weights = sorted(class_weights)
sum_weights = np.sum(class_weights)

print(class_nsamples)
# print(class_weights)
sampled_examples = []
for _ in range(np.sum(class_nsamples)):
    rand_number, now_sum = random.random() * sum_weights, 0
    for i in range(num_classes):
        now_sum += class_weights[i]
        if rand_number <= now_sum:
            # print(i)
            sampled_examples.append(i)
            break
_, class_nresamples = np.unique(sampled_examples, return_counts=True)
print('==> class frequency in actual data(origin)')
# note: this prints the normalized (sorted) weights, not the raw class frequencies
print([x/np.sum(class_weights) for x in class_weights])
print('==> class frequency in resample data(reverse)')
print([x/np.sum(class_nresamples) for x in class_nresamples])

==> class frequency in actual data(origin)
[0.08581392013475014, 0.08887870299670551, 0.08967941203271185, 0.09132490583147722, 0.09759230132971584, 0.09954414735631015, 0.10054964379425269, 0.11060460817367797, 0.11311834926853427, 0.12289400908186439]
==> class frequency in resample data(reverse)
[0.07837301587301587, 0.07738095238095238, 0.09126984126984126, 0.09821428571428571, 0.08928571428571429, 0.09424603174603174, 0.11507936507936507, 0.11607142857142858, 0.12103174603174603, 0.11904761904761904]

John-Yao commented 3 years ago

There are some updates to the code: the sorted() call is removed so each class keeps its own weight, and the imbalance is larger. It seems that the reverse sampling is implemented correctly.

import random
import numpy as np

num_classes = 10
class_nsamples = [1000 + int(random.uniform(-1, 1) * 600) for i in range(num_classes)]
# w_i = N_max / N_i: rarer classes get larger sampling weights (no sorting this time)
class_weights = [np.max(class_nsamples)/class_nsamples[i] for i in range(num_classes)]
sum_weights = np.sum(class_weights)
sampled_examples = []
for _ in range(np.sum(class_nsamples)):
    rand_number, now_sum = random.random() * sum_weights, 0
    for i in range(num_classes):
        now_sum += class_weights[i]
        if rand_number <= now_sum:
            sampled_examples.append(i)
            break
_, class_nresamples = np.unique(sampled_examples, return_counts=True)
print('==> class samples in actual data(origin)')
print(class_nsamples)
print('==> class frequency in actual data(origin)')
print([x/np.sum(class_nsamples) for x in class_nsamples])
print('==> class frequency in resample data(reverse)')
print([x/np.sum(class_nresamples) for x in class_nresamples])

The log is shown here:

==> class samples in actual data(origin)
[966, 1190, 1152, 594, 1131, 1173, 1332, 558, 690, 1487]
==> class frequency in actual data(origin)
[0.09403290178136864, 0.11583763262922224, 0.11213861578896135, 0.0578214737661832, 0.11009442227197508, 0.11418280930594762, 0.12966027450598658, 0.05431714202277815, 0.0671663584152633, 0.14474836951231385]
==> class frequency in resample data(reverse)
[0.09617443784678283, 0.08040494500146014, 0.08089165774359973, 0.16003114961549694, 0.08313053635744183, 0.07641390051591551, 0.06434342451085369, 0.1604205198092086, 0.13949187189720627, 0.058697556702034456]
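As a quick sanity check (a sketch that reuses class_weights and sum_weights from the script above), the empirical reverse frequencies can be compared against the theoretical targets P_i = w_i / sum_j w_j:

# theoretical reverse-sampling targets P_i = w_i / sum_j w_j
target_probs = [w / sum_weights for w in class_weights]
print('==> theoretical reverse frequency P_i')
print(target_probs)
# with enough draws the resampled frequencies should approach target_probs,
# e.g. the largest class (1487 samples) should get the smallest frequency,
# which matches the log above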