ain-soph / trojanzoo

TrojanZoo provides a universal PyTorch platform for conducting security research (especially backdoor attacks and defenses) on image classification in deep learning.
https://ain-soph.github.io/trojanzoo
GNU General Public License v3.0

Inefficient loading in get_class_subset function #159

Open TDteach opened 2 years ago

TDteach commented 2 years ago

Currently, the get_class_subset method of trojanzoo.datasets.Dataset directly calls the get_class_subset function in trojanzoo.utils.data.py. However, that function is inefficient, especially for ImageNet data: it loads the whole dataset, images and labels alike, through this line, although only the labels are used afterwards.
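To make the cost concrete, here is a minimal sketch using a toy dataset. `CountingDataset`, `labels_by_iteration`, and `labels_by_attribute` are hypothetical names made up for illustration; they are not trojanzoo code and only mimic the pattern described above, under the assumption that labels are also stored in a `targets` attribute.

```python
class CountingDataset:
    """Hypothetical stand-in for an image dataset; counts expensive loads."""

    def __init__(self, targets):
        self.targets = list(targets)  # labels are cheap and already in memory
        self.loads = 0                # number of simulated image decodes

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        self.loads += 1               # stands in for reading and decoding a file
        image = [0.0] * 4             # placeholder "image"
        return image, self.targets[index]


def labels_by_iteration(dataset):
    # Touches every sample, so every image is loaded just to read its label.
    return [dataset[i][1] for i in range(len(dataset))]


def labels_by_attribute(dataset):
    # Reads the label list directly; no sample is loaded.
    return list(dataset.targets)


dataset = CountingDataset([0, 1, 2, 1, 0])
slow = labels_by_iteration(dataset)
loads_after_iteration = dataset.loads
fast = labels_by_attribute(dataset)
loads_after_attribute = dataset.loads - loads_after_iteration
```

Both helpers return the same labels, but only the first one triggers a load per sample. On ImageNet each such load means reading and decoding an image file.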

I suggest replacing this function with the following code to avoid the unnecessary loading.

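    import numpy as np
    from torch.utils.data import Subset

    def get_class_subset(dataset, class_list):
        class_list = [class_list] if isinstance(class_list, int) else class_list
        indices = np.arange(len(dataset))
        # Unwrap a Subset so labels are read from the underlying dataset.
        if isinstance(dataset, Subset):
            idx = np.array(dataset.indices)
            indices = idx[indices]
            dataset = dataset.dataset

        # Read labels only; no images are loaded.
        if dataset.target_transform is not None:
            targets = [dataset.target_transform(t) for t in dataset.targets]
        else:
            targets = dataset.targets
        targets = np.asarray(targets)
        idx_bool = np.isin(targets, class_list)
        idx = np.arange(len(dataset))[idx_bool]
        # Keep only matching indices that also belong to the original subset.
        idx = np.intersect1d(idx, indices)
        return Subset(dataset, idx)

ain-soph commented 2 years ago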

Your implementation assumes that dataset is a torchvision.datasets.VisionDataset rather than a generic torch.utils.data.Dataset (especially for Subset), and this assumption does not always hold.

Not every dataset has a target_transform or targets attribute.

Especially since the PyTorch team is gradually deprecating the dataset convention in favor of the new datapipe style, I don't think it's a good idea to make trojanzoo functions rely on such concrete internal methods and attributes.

But your claim is correct: the current implementation is too slow for ImageNet.
It would be perfect if we could find a solution that also works for the future ImageNet dataset. The old ImageNet dataset (ImageFolder style) will be deprecated next year, after PyTorch 2.0.
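
One possible compromise, sketched below under stated assumptions, is to use the `targets` attribute when a dataset happens to expose it and to fall back to iterating samples otherwise. `get_class_indices` and the toy classes are hypothetical and not part of trojanzoo. The sketch uses duck typing instead of importing torch, treats any object with `dataset` and `indices` attributes as Subset-like, and does not address datapipe-style datasets that lack `__getitem__`.

```python
def get_class_indices(dataset, class_list):
    """Return (base_dataset, indices) for samples whose label is in class_list.

    Hypothetical helper, not part of trojanzoo.
    """
    class_set = {class_list} if isinstance(class_list, int) else set(class_list)

    # Unwrap one level of a Subset-like wrapper (assumed: .dataset and .indices).
    if hasattr(dataset, "dataset") and hasattr(dataset, "indices"):
        base, indices = dataset.dataset, list(dataset.indices)
    else:
        base, indices = dataset, list(range(len(dataset)))

    targets = getattr(base, "targets", None)
    if targets is not None:
        # Fast path: labels are available without loading any sample.
        transform = getattr(base, "target_transform", None)

        def label_of(i):
            t = targets[i]
            return transform(t) if transform is not None else t
    else:
        # Slow path: a generic dataset only exposes labels via __getitem__.
        def label_of(i):
            return base[i][1]

    return base, [i for i in indices if int(label_of(i)) in class_set]


class LabeledToy:
    """Toy dataset exposing `targets`; loading a sample counts as an error."""

    def __init__(self, targets):
        self.targets = list(targets)
        self.target_transform = None

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        raise RuntimeError("sample loaded unexpectedly")


class GenericToy:
    """Toy dataset without a `targets` attribute."""

    def __init__(self, labels):
        self._labels = list(labels)

    def __len__(self):
        return len(self._labels)

    def __getitem__(self, index):
        return None, self._labels[index]


class SubsetLike:
    """Toy wrapper mimicking the attributes of torch.utils.data.Subset."""

    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, list(indices)

    def __len__(self):
        return len(self.indices)


labels = [0, 1, 2, 1, 0, 2]
_, fast_idx = get_class_indices(LabeledToy(labels), [1, 2])
_, slow_idx = get_class_indices(GenericToy(labels), [1, 2])
_, sub_idx = get_class_indices(SubsetLike(LabeledToy(labels), [0, 1, 5]), 2)
```

The returned pair could then be passed to `torch.utils.data.Subset(base, indices)`. Generic datasets keep the current slow behavior, so nothing breaks, while datasets that expose `targets` skip image loading entirely.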