yjh0410 / CenterNet-plus

A Simple Baseline for Object Detection
Class imbalance #18

Open YashRunwal opened 3 years ago

YashRunwal commented 3 years ago

@yjh0410,

My dataset is highly imbalanced. The class counts in the training dataset (2700 samples) are as follows: {'1': 306, '2': 60, '3': 133, '4': 14197, '6': 3886, '7': 69}. Each key is a class label and each value is the number of samples of that class in the dataset.

I came across PyTorch's WeightedRandomSampler, which can be used to give the highest sampling weight to the class_id with the fewest samples.

I was wondering how we would integrate this with gt_tensor.
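
For reference, here is a minimal sketch of how such per-class weights could be derived from the counts above. It assumes plain inverse-frequency weighting, which is one common choice and not something this repository prescribes. The normalisation step is only for readability, since WeightedRandomSampler does not require its weights to sum to one.

```python
# Class counts copied from the issue text (class label -> count).
class_counts = {'1': 306, '2': 60, '3': 133, '4': 14197, '6': 3886, '7': 69}

# Inverse-frequency weighting (an assumption): rarer classes get larger weights.
class_weights = {label: 1.0 / count for label, count in class_counts.items()}

# Normalise so the weights sum to 1; optional, done here only for readability.
total = sum(class_weights.values())
class_weights = {label: w / total for label, w in class_weights.items()}

# The class with the fewest samples ends up with the largest weight.
rarest = max(class_weights, key=class_weights.get)
most_common = min(class_weights, key=class_weights.get)

for label in sorted(class_weights):
    print(f"class {label}: weight {class_weights[label]:.4f}")
```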

yjh0410 commented 3 years ago

Sorry, I haven't tried the WeightedRandomSampler from PyTorch, but I don't think there is any relation between the sampler and the gt_tensor produced by the gt_creator function.
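
One way to read this reply: a sampler only decides which images are drawn into a batch, so the weights would be attached to images rather than to entries of gt_tensor. The sketch below illustrates that idea under stated assumptions. `image_sampling_weights` and `image_labels` are hypothetical names, not part of this repository, and scoring an image by its rarest class is just one possible heuristic.

```python
from typing import Dict, List


def image_sampling_weights(image_labels: List[List[str]],
                           class_counts: Dict[str, int]) -> List[float]:
    """Return one sampling weight per image (hypothetical helper).

    Each image is scored by the inverse frequency of its rarest class,
    so images containing rare classes are drawn more often.
    """
    class_weights = {label: 1.0 / count for label, count in class_counts.items()}
    weights = []
    for labels in image_labels:
        if labels:
            weights.append(max(class_weights[label] for label in labels))
        else:
            # Image without objects: fall back to the smallest class weight.
            weights.append(min(class_weights.values()))
    return weights


# Class counts copied from the issue text.
class_counts = {'1': 306, '2': 60, '3': 133, '4': 14197, '6': 3886, '7': 69}

# Toy annotation list (made up for illustration): class labels per image.
image_labels = [['4', '4', '6'], ['2', '4'], ['4']]

weights = image_sampling_weights(image_labels, class_counts)

# With PyTorch installed, the weights could then be handed to the sampler:
#   sampler = torch.utils.data.WeightedRandomSampler(
#       weights, num_samples=len(weights), replacement=True)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=16, sampler=sampler)
# The dataset would still build gt_tensor for each drawn image as before.
```

Under this reading, gt_tensor itself stays unchanged; only the frequency with which each image is visited during an epoch changes.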

YashRunwal commented 3 years ago

@yjh0410 I can get the weights for the different classes; that is not an issue. But how do we assign these weights to our classes inside the gt_tensor? Example: https://discuss.pytorch.org/t/some-problems-with-weightedrandomsampler/23242/40?u=duddal

YashRunwal commented 3 years ago

@yjh0410 How would you suggest addressing the class imbalance problem? I don't understand how we would use WeightedRandomSampler with our targets.