TTN-YKK / Clustering_friendly_representation_learning

About using ImageNet and ImageNet-Dog Dataset #3

Open V1oletM opened 2 years ago

V1oletM commented 2 years ago

Hi, thank you for sharing your code! The results and the idea are really helpful. I'm trying to use ImageNet-10 by overriding the dataset class so that it loads the images through datasets.ImageFolder(), but it doesn't work:

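import numpy as np
from PIL import Image
from skimage import io
from torch.utils.data import Dataset
from torchvision import datasets

class ImageNet10(Dataset):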
    base_folder = 'imagenet-10'
    class_names_file = 'class_names.txt'
    train_list = [
        ['ImageNet10_112.h5', '918c2871b30a85fa023e0c44e0bee87f'],
    ]
    splits = ('train', 'test')

    def __init__(self, split='train',
                 transform=None, target_transform=None, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))

        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # 'train' or 'test'

        self.data, self.targets = self.__loadfile()
        print("Dataset Loaded.")

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img_size = (img.shape[0], img.shape[1])
        img = Image.fromarray(np.uint8(img)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, index

    def __len__(self):
        return len(self.data)

    def __loadfile(self):
        datas,labels = [],[]
        source_dataset = datasets.ImageFolder(root='datasets/imagenet-10/')

        for line,tar in zip(source_dataset.imgs,source_dataset.targets):
            try:
                img = io.imread(line[0])
                # img = color.gray2rgb(img)
            except Exception:  # skip unreadable files, but do not swallow KeyboardInterrupt
                print(line[0])
                continue
            else:
                datas.append(img)
                labels.append(tar)

        return datas, labels

    def extra_repr(self):
        return "Split: {split}".format(**self.__dict__)

I would really appreciate more details on how to use the ImageNet datasets and which parameters need to be adjusted. Thank you!

TTN-YKK commented 2 years ago

If your folder structure is supported by the ImageFolder class, try using the following ImageFolder subclass, which also returns the sample index:

from torchvision import datasets

class ImageFolder(datasets.ImageFolder):
    def __getitem__(self, index):
        # Return the sample index alongside the image and label.
        img, target = super().__getitem__(index)
        return img, target, index

For the ImageNet datasets, the crop size is changed to 96x96. The other parameters are unchanged.
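
For readers who would rather not subclass, the same "image, target, index" contract can be sketched as a small wrapper around any map-style dataset. This is only an illustration, not code from the repository: the name `WithIndex` and the toy list standing in for an image folder are hypothetical.

```python
class WithIndex:
    """Hypothetical wrapper: makes a map-style dataset also return the item index."""

    def __init__(self, dataset):
        # `dataset` is anything with __len__ and __getitem__ that
        # yields (img, target) pairs, e.g. a torchvision ImageFolder.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, target = self.dataset[index]
        # Append the index so the training loop can tell which sample it got.
        return img, target, index


# Toy stand-in for an image dataset: (image, label) pairs.
toy = [("img_a", 0), ("img_b", 1), ("img_c", 1)]
wrapped = WithIndex(toy)

for img, target, index in (wrapped[i] for i in range(len(wrapped))):
    print(index, img, target)
```

With torchvision installed, the wrapped object would presumably be something like `datasets.ImageFolder('datasets/imagenet-10/', transform=...)`, with the 96x96 crop applied inside the transform (for example `transforms.RandomResizedCrop(96)`, which is an assumption about how the crop is configured, not something stated above).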