wyc-ruiker / HKD

Code for ICCV 2021 paper "Distilling Holistic Knowledge with Graph Neural Networks"

About the TinyImageNet dataset #2

Closed · cvv-student closed this issue 2 years ago

cvv-student commented 2 years ago

Excellent work! Where can I get the code or pre-trained models for the TinyImageNet dataset, please?

wyc-ruiker commented 2 years ago

> Excellent work! Where can I get the code or pre-trained models for the TinyImageNet dataset, please?

Thank you for your interest in our work! I found the TinyImageNet experiment code in my old files; you can add it as ./dataset/tinyimagenet.py:

from __future__ import print_function

import os
import numpy as np
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import extract_archive, check_integrity, download_url, verify_str_arg
from torch.utils.data import DataLoader
from torchvision import transforms

class TinyImageNet(VisionDataset):
    """`tiny-imageNet <http://cs231n.stanford.edu/tiny-imagenet-200.zip>`_ Dataset.
        Args:
            root (string): Root directory of the dataset.
            split (string, optional): The dataset split, supports ``train``, or ``val``.
            transform (callable, optional): A function/transform that takes in a PIL image
               and returns a transformed version. E.g., ``transforms.RandomCrop``.
            target_transform (callable, optional): A function/transform that takes in the
               target and transforms it.
            download (bool, optional): If true, downloads the dataset from the internet and
               puts it in root directory. If dataset is already downloaded, it is not
               downloaded again.
    """
    base_folder = 'tiny-imagenet-200/'
    url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    filename = 'tiny-imagenet-200.zip'
    md5 = '90528d7ca1a48142e341f4ef8d21d0de'

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(TinyImageNet, self).__init__(root, transform=transform, target_transform=target_transform)

        self.dataset_path = os.path.join(root, self.base_folder)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val",))

        if self._check_integrity():
            print('Files already downloaded and verified.')
        elif download:
            self._download()
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')
        if not os.path.isdir(self.dataset_path):
            print('Extracting...')
            extract_archive(os.path.join(root, self.filename))

        _, class_to_idx = find_classes(os.path.join(self.dataset_path, 'wnids.txt'))

        self.data = make_dataset(self.root, self.base_folder, self.split, class_to_idx)

    def _download(self):
        print('Downloading...')
        download_url(self.url, root=self.root, filename=self.filename)
        print('Extracting...')
        extract_archive(os.path.join(self.root, self.filename))

    def _check_integrity(self):
        return check_integrity(os.path.join(self.root, self.filename), self.md5)

    def __getitem__(self, index):
        img_path, target = self.data[index]
        image = self.loader(img_path)

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.data)

def find_classes(class_file):
    with open(class_file) as r:
        classes = list(map(lambda s: s.strip(), r.readlines()))

    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}

    return classes, class_to_idx

def make_dataset(root, base_folder, dirname, class_to_idx):
    images = []
    dir_path = os.path.join(root, base_folder, dirname)

    if dirname == 'train':
        for fname in sorted(os.listdir(dir_path)):
            cls_fpath = os.path.join(dir_path, fname)
            if os.path.isdir(cls_fpath):
                cls_imgs_path = os.path.join(cls_fpath, 'images')
                for imgname in sorted(os.listdir(cls_imgs_path)):
                    path = os.path.join(cls_imgs_path, imgname)
                    item = (path, class_to_idx[fname])
                    images.append(item)
    else:
        imgs_path = os.path.join(dir_path, 'images')
        imgs_annotations = os.path.join(dir_path, 'val_annotations.txt')

        with open(imgs_annotations) as r:
            data_info = map(lambda s: s.split('\t'), r.readlines())

        cls_map = {line_data[0]: line_data[1] for line_data in data_info}

        for imgname in sorted(os.listdir(imgs_path)):
            path = os.path.join(imgs_path, imgname)
            item = (path, class_to_idx[cls_map[imgname]])
            images.append(item)

    return images

class TinyImageNetInstance(TinyImageNet):
    def __getitem__(self, index):
        img_path, target = self.data[index]
        image = self.loader(img_path)

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target, index

def get_tiny_imagenet_dataloaders(batch_size=128, num_workers=8, is_instance=True):
    """Tiny ImageNet train/val loaders. With ``is_instance=True`` the train loader
    yields (image, target, index) and the number of training samples is also returned."""

    train_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    test_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor()
    ])

    # NOTE: the data root below is hard-coded to the author's machine; change it to your own path.
    if is_instance:
        train_dataset = TinyImageNetInstance('/disk/wyc/tiny-imagenet', split='train', transform=train_transform)
    else:
        train_dataset = TinyImageNet('/disk/wyc/tiny-imagenet', split='train', transform=train_transform)

    test_dataset = TinyImageNet('/disk/wyc/tiny-imagenet', split='val', transform=test_transform)

    n_data = len(train_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=int(batch_size/2),
                             shuffle=False,
                             num_workers=int(num_workers/2))
    if is_instance:
        return train_loader, test_loader, n_data
    else:
        return train_loader, test_loader

class TinyImageNetInstanceSample(TinyImageNet):
    def __init__(self, root, split, transform=None, target_transform=None,
                 download=False, k=4096, mode='exact', is_sample=True, percent=1.0):
        super().__init__(root=root, split=split, download=download,
                         transform=transform, target_transform=target_transform)
        self.k = k
        self.mode = mode
        self.is_sample = is_sample

        num_classes = 200
        num_samples = len(self.data)
        label = []
        for i in range(num_samples):
            _, l = self.data[i]
            label.append(l)

        self.cls_positive = [[] for i in range(num_classes)]
        for i in range(num_samples):
            self.cls_positive[label[i]].append(i)

        self.cls_negative = [[] for i in range(num_classes)]
        for i in range(num_classes):
            for j in range(num_classes):
                if j == i:
                    continue
                self.cls_negative[i].extend(self.cls_positive[j])

        self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
        self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]

        if 0 < percent < 1:
            n = int(len(self.cls_negative[0]) * percent)
            self.cls_negative = [np.random.permutation(self.cls_negative[i])[0:n]
                                 for i in range(num_classes)]

        self.cls_positive = np.asarray(self.cls_positive)
        self.cls_negative = np.asarray(self.cls_negative)

    def __getitem__(self, index):
        img_path, target = self.data[index]
        img = self.loader(img_path)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        if not self.is_sample:
            # directly return
            return img, target, index
        else:
            # sample contrastive examples
            if self.mode in ['exact', 'hkd']:
                pos_idx = index
            elif self.mode == 'relax':
                pos_idx = np.random.choice(self.cls_positive[target], 1)
                pos_idx = pos_idx[0]
            else:
                raise NotImplementedError(self.mode)
            replace = True if self.k > len(self.cls_negative[target]) else False
            neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
            if self.mode == 'hkd':
                sample_idx = neg_idx
            else:
                sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
            return img, target, index, sample_idx

def get_tiny_imagenet_dataloaders_sample(batch_size=128, num_workers=8, k=4096, mode='exact',
                                         is_sample=True, percent=1.0):
    train_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    test_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor()
    ])

    train_dataset = TinyImageNetInstanceSample('/disk/wyc/tiny-imagenet', split='train', transform=train_transform,
                                               k=k, mode=mode, is_sample=is_sample, percent=percent)
    test_dataset = TinyImageNet('/disk/wyc/tiny-imagenet', split='val', transform=test_transform)

    n_data = len(train_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=int(batch_size/2),
                             shuffle=False,
                             num_workers=int(num_workers/2))
    return train_loader, test_loader, n_data

And you can use it like the CIFAR-100 data loaders, for example:

    if opt.dataset == 'cifar100':
        train_loader, val_loader, n_data = get_cifar100_dataloaders(batch_size=opt.batch_size,
                                                                     num_workers=opt.num_workers,
                                                                     is_instance=True)
    elif opt.dataset == 'tiny-imagenet':
        if opt.distill in ['hkd']:
            train_loader, val_loader, n_data = get_tiny_imagenet_dataloaders_sample(batch_size=opt.batch_size,
                                                                                     num_workers=opt.num_workers,
                                                                                     k=opt.nce_k,
                                                                                     mode='hkd')
        else:
            train_loader, val_loader, n_data = get_tiny_imagenet_dataloaders(batch_size=opt.batch_size,
                                                                             num_workers=opt.num_workers)
cvv-student commented 2 years ago

Thank you for your patient answer! I would like to reproduce your work. If I retrain the TinyImageNet teacher model myself, the results may not align with yours. Where can I get your pre-trained teacher models on TinyImageNet?

wyc-ruiker commented 2 years ago

> Thank you for your patient answer! I would like to reproduce your work. If I retrain the TinyImageNet teacher model myself, the results may not align with yours. Where can I get your pre-trained teacher models on TinyImageNet?

resnet50: https://drive.google.com/file/d/1ChIdWutl20JBG_0n2NzkmP2kp9U73PVq/view?usp=sharing
resnet32x4: https://drive.google.com/file/d/1ziDdzilsA0O99JaH_cM3zx3T2MhIYK9F/view?usp=sharing
vgg13: https://drive.google.com/file/d/1_lpJjKWgVrk2iU55Nsgfh3w8nb5JRG2F/view?usp=sharing

Because I don't have computing resources at the moment, I'm not sure whether these pre-trained parameters are consistent with those in the paper. Please feel free to let me know if they are inconsistent or unavailable.
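
If you want to check whether a downloaded teacher matches the paper before distilling, a minimal evaluation sketch along these lines should do. It makes several assumptions beyond what is shown above: that the repo exposes model builders such as resnet32x4 from a models package and that they accept num_classes, that the checkpoint is a state_dict saved with torch.save (possibly wrapped under a 'model' key, as in CRD-style training scripts), and the checkpoint file name is just a placeholder:

    import torch
    from models import resnet32x4                      # assumed model builder from this repo
    from dataset.tinyimagenet import get_tiny_imagenet_dataloaders

    # Plain loaders (no instance indices needed for evaluation).
    _, val_loader = get_tiny_imagenet_dataloaders(batch_size=128, num_workers=8, is_instance=False)

    model = resnet32x4(num_classes=200)
    ckpt = torch.load('teacher_resnet32x4_tinyimagenet.pth', map_location='cpu')   # placeholder path
    state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
    model.load_state_dict(state_dict)
    model.eval()

    # Top-1 accuracy on the val split.
    correct = total = 0
    with torch.no_grad():
        for images, targets in val_loader:
            preds = model(images).argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    print('val top-1: {:.2f}%'.format(100.0 * correct / total))
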

cvv-student commented 2 years ago

> > Thank you for your patient answer! I would like to reproduce your work. If I retrain the TinyImageNet teacher model myself, the results may not align with yours. Where can I get your pre-trained teacher models on TinyImageNet?
>
> resnet50: https://drive.google.com/file/d/1ChIdWutl20JBG_0n2NzkmP2kp9U73PVq/view?usp=sharing
> resnet32x4: https://drive.google.com/file/d/1ziDdzilsA0O99JaH_cM3zx3T2MhIYK9F/view?usp=sharing
> vgg13: https://drive.google.com/file/d/1_lpJjKWgVrk2iU55Nsgfh3w8nb5JRG2F/view?usp=sharing
>
> Because I don't have computing resources at the moment, I'm not sure whether these pre-trained parameters are consistent with those in the paper. Please feel free to let me know if they are inconsistent or unavailable.

The teacher models fully match the accuracy reported in the paper. Thank you for your generous help!