VICO-UoE / DatasetCondensation

Dataset Condensation (ICLR21 and ICML21)
MIT License
473 stars 91 forks source link

How do you calculate the mean and std? The values seem to be different when I calculate them from the dataset myself. #36

Open kekekeke8 opened 3 months ago

kekekeke8 commented 3 months ago

def get_dataset(dataset, data_path): if dataset == 'MNIST': channel = 1 im_size = (28, 28) num_classes = 10 mean = [0.1307] std = [0.3081] transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform) # no augmentation dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform) class_names = [str(c) for c in range(num_classes)]

elif dataset == 'FashionMNIST':
    channel = 1
    im_size = (28, 28)
    num_classes = 10
    mean = [0.2861]
    std = [0.3530]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    dst_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform) # no augmentation
    dst_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)
    class_names = dst_train.classes

elif dataset == 'SVHN':
    channel = 3
    im_size = (32, 32)
    num_classes = 10
    mean = [0.4377, 0.4438, 0.4728]
    std = [0.1980, 0.2010, 0.1970]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    dst_train = datasets.SVHN(data_path, split='train', download=True, transform=transform)  # no augmentation
    dst_test = datasets.SVHN(data_path, split='test', download=True, transform=transform)
    class_names = [str(c) for c in range(num_classes)]

elif dataset == 'CIFAR10':
    channel = 3
    im_size = (32, 32)
    num_classes = 10
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform) # no augmentation
    dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)
    class_names = dst_train.classes

elif dataset == 'CIFAR100':
    channel = 3
    im_size = (32, 32)
    num_classes = 100
    mean = [0.5071, 0.4866, 0.4409]
    std = [0.2673, 0.2564, 0.2762]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    dst_train = datasets.CIFAR100(data_path, train=True, download=True, transform=transform) # no augmentation
    dst_test = datasets.CIFAR100(data_path, train=False, download=True, transform=transform)
    class_names = dst_train.classes