zz0320 / Flower102_CifarLT

1 stars 0 forks source link

想问一下，仓库中包含 CIFAR-10-LT 划分方法的代码吗？ #1

Open liyang-good opened 2 years ago

zz0320 commented 2 years ago

你好，这里没有事先对数据集进行直接划分，而是在读取数据时（在 Dataset 类中按类别采样）模拟长尾分布，可以参考这个代码：datasets/cifar_longtail.py

zz0320 commented 2 years ago
class CifarLTDataset(CifarDataset):
    """CIFAR dataset re-sampled to follow a long-tailed (exponential) class profile.

    In training mode the sample list inherited from ``CifarDataset``
    (``self.img_info``) is sub-sampled per class, so that class ``i`` keeps
    ``int(img_max * imb_factor ** (i / (cls_num - 1)))`` samples, where
    ``img_max = len(img_info) / cls_num``. In evaluation mode the data is left
    untouched and only the per-class sample counts are recorded.

    Attributes set here:
        imb_factor: ratio between the smallest and the largest class size.
        nums_per_cls: number of samples for each class index.
        img_info: replaced (training mode only) by the shuffled long-tailed list.

    NOTE(review): ``img_info`` appears to be a list of ``(path, label)`` pairs and
    ``cls_num`` an int set by ``CifarDataset`` -- confirm against the parent class.
    """

    def __init__(self, root_dir, transform=None, imb_factor=0.01, isTrain=True):
        """
        :param root_dir: dataset root directory, forwarded to ``CifarDataset``.
        :param transform: optional transform, forwarded to ``CifarDataset``.
        :param imb_factor: float; the smaller the value, the faster class sizes
            decay. 0.1 means the smallest class has 0.1x as many samples as the
            largest one, e.g. 500:5000.
        :param isTrain: if True, sub-sample ``img_info`` into a long-tailed
            distribution; otherwise keep the data as loaded (e.g. a balanced
            test set) and only count the samples of each class.
        """
        super().__init__(root_dir, transform=transform)
        self.imb_factor = imb_factor
        if isTrain:
            self.nums_per_cls = self._get_img_num_per_cls()  # target size of each class
            self._select_img()  # sub-sample img_info down to those sizes
        else:
            # Evaluation: the data can stay balanced; just record the class counts.
            # Build the label list once rather than once per class.
            label_list = [label for _, label in self.img_info]
            self.nums_per_cls = [label_list.count(n) for n in range(self.cls_num)]

    def _select_img(self):
        """
        Sub-sample ``self.img_info`` so that class ``n`` keeps at most
        ``self.nums_per_cls[n]`` randomly chosen samples, then shuffle the result.

        Randomness comes from the module-level ``random`` state; seed it
        beforehand if the split must be reproducible.
        :return: None; ``self.img_info`` is replaced.
        """
        # Group samples by label in one pass (original order is kept within each
        # class) instead of rescanning the full list for every class.
        # Assumes labels are hashable class indices.
        by_cls = {}
        for info in self.img_info:
            by_cls.setdefault(info[1], []).append(info)

        new_lst = []
        for n, img_num in enumerate(self.nums_per_cls):
            lst_tmp = by_cls.get(n, [])  # samples of class n
            random.shuffle(lst_tmp)
            new_lst.extend(lst_tmp[:img_num])
        random.shuffle(new_lst)  # mix classes so the list is not ordered by label
        self.img_info = new_lst

    def _get_img_num_per_cls(self):
        """
        Compute the long-tailed target size of every class.

        Class ``i`` gets ``int(img_max * imb_factor ** (i / (cls_num - 1)))``
        samples, with ``img_max = len(self.img_info) / self.cls_num`` (the
        per-class size when the loaded data is balanced). Class 0 thus keeps
        ``img_max`` samples, the last class ``img_max * imb_factor``, and sizes
        in between decay exponentially. ``int`` truncates the fractional part.
        :return: list of ints, one per class index.
        """
        img_max = len(self.img_info) / self.cls_num
        # With a single class the exponent's denominator would be zero; the lone
        # class simply keeps all of its samples (exponent 0 -> factor 1).
        denom = max(self.cls_num - 1.0, 1.0)
        return [
            int(img_max * (self.imb_factor ** (cls_idx / denom)))
            for cls_idx in range(self.cls_num)
        ]