Open liyang-good opened 2 years ago
class CifarLTDataset(CifarDataset):
def __init__(self, root_dir, transform=None, imb_factor=0.01, isTrain=True):
"""
:param root_dir:
:param transform:
:param imb_type:
:param imb_factor: float, 值越小,数量下降越快,0.1表示最少的类是最多的类的0.1倍,如500:5000
:param isTrain:
"""
super(CifarLTDataset, self).__init__(root_dir, transform=transform)
self.imb_factor = imb_factor
if isTrain:
self.nums_per_cls = self._get_img_num_per_cls() # 计算每个类的样本数
self._select_img() # 采样获得符合长尾分布的数据量
else:
# 非训练状态,可采用均衡数据集测试
self.nums_per_cls = []
for n in range(self.cls_num):
label_list = [label for p, label in self.img_info] # 获取每个标签
self.nums_per_cls.append(label_list.count(n)) # 统计每个类别数量
def _select_img(self):
"""
根据每个类需要的样本数进行挑选
:return:
"""
new_lst = []
for n, img_num in enumerate(self.nums_per_cls):
lst_tmp = [info for info in self.img_info if info[1] == n] # 获取第n类别数据信息
random.shuffle(lst_tmp)
lst_tmp = lst_tmp[:img_num]
new_lst.extend(lst_tmp)
random.shuffle(new_lst)
self.img_info = new_lst
def _get_img_num_per_cls(self):
"""
依长尾分布计算每个类别应有多少张样本
:return:
"""
img_max = len(self.img_info) / self.cls_num
img_num_per_cls = []
for cls_idx in range(self.cls_num):
num = img_max * (self.imb_factor ** (cls_idx / (self.cls_num - 1.0))) # 列出公式就知道了
img_num_per_cls.append(int(num))
return img_num_per_cls
你好 这里没有进行直接划分 是在读取数据的时候用DataLoader模拟长尾分布 可以参考这个代码 datasets/cifar_longtail.py