Status: Open — kekekeke8 opened this issue 3 months ago.
def get_dataset(dataset, data_path):
    """Build the train and test splits of a supported torchvision image dataset.

    Supported names (case-sensitive): 'MNIST', 'FashionMNIST', 'SVHN',
    'CIFAR10', 'CIFAR100'. Both splits use the same preprocessing:
    ``ToTensor`` followed by ``Normalize`` with the per-channel mean/std
    listed below; no augmentation is applied.

    Args:
        dataset: Name of the dataset to load.
        data_path: Root directory passed to the torchvision dataset
            constructors (``download=True`` is also passed).

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
            Previously such a name fell through every branch and the call
            silently did nothing.
    """
    # Per-dataset metadata and the torchvision class used for both splits.
    if dataset == 'MNIST':
        channel, im_size, num_classes = 1, (28, 28), 10
        mean, std = [0.1307], [0.3081]
        dataset_cls = datasets.MNIST
    elif dataset == 'FashionMNIST':
        channel, im_size, num_classes = 1, (28, 28), 10
        mean, std = [0.2861], [0.3530]
        dataset_cls = datasets.FashionMNIST
    elif dataset == 'SVHN':
        channel, im_size, num_classes = 3, (32, 32), 10
        mean, std = [0.4377, 0.4438, 0.4728], [0.1980, 0.2010, 0.1970]
        dataset_cls = datasets.SVHN
    elif dataset == 'CIFAR10':
        channel, im_size, num_classes = 3, (32, 32), 10
        # NOTE(review): the commonly cited global per-channel CIFAR-10 std is
        # roughly [0.247, 0.243, 0.261]. The values below are kept unchanged
        # so existing results stay reproducible; confirm they are intended.
        mean, std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
        dataset_cls = datasets.CIFAR10
    elif dataset == 'CIFAR100':
        channel, im_size, num_classes = 3, (32, 32), 100
        mean, std = [0.5071, 0.4866, 0.4409], [0.2673, 0.2564, 0.2762]
        dataset_cls = datasets.CIFAR100
    else:
        # Fail loudly instead of leaving every value above unset.
        raise ValueError(f"unknown dataset: {dataset!r}")

    # Identical preprocessing for both splits (no augmentation).
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
    )

    # SVHN selects its split with ``split=``; the other classes use ``train=``.
    if dataset == 'SVHN':
        train_split, test_split = {'split': 'train'}, {'split': 'test'}
    else:
        train_split, test_split = {'train': True}, {'train': False}
    dst_train = dataset_cls(data_path, download=True, transform=transform, **train_split)
    dst_test = dataset_cls(data_path, download=True, transform=transform, **test_split)

    # MNIST and SVHN use the digit strings as class names; the other datasets
    # take their names from the training set's ``classes`` attribute.
    if dataset in ('MNIST', 'SVHN'):
        class_names = [str(c) for c in range(num_classes)]
    else:
        class_names = dst_train.classes

    # NOTE(review): the code as shown has no ``return``. All of the values
    # above (channel, im_size, num_classes, class_names, mean, std, dst_train,
    # dst_test) are discarded and the caller receives None. The pasted
    # snippet may be truncated; confirm against the full source which values
    # callers expect before relying on this function's result.
def get_dataset(dataset, data_path): if dataset == 'MNIST': channel = 1 im_size = (28, 28) num_classes = 10 mean = [0.1307] std = [0.3081] transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform) # no augmentation dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform) class_names = [str(c) for c in range(num_classes)]