ReSSL: Relational Self-Supervised Learning with Weak Augmentation

Pretrain on CIFAR-10, CIFAR-100, STL-10, Tiny ImageNet #5

Closed: gyfastas closed this issue 3 years ago

gyfastas commented 3 years ago

Thank you for your great work! I notice that results on small and medium datasets (i.e., CIFAR-10, CIFAR-100, STL-10, Tiny ImageNet) are provided in your paper. Could you provide the pretraining configs for these datasets?

mingkai-zheng commented 3 years ago

Sure, for small and medium datasets we use the following configuration.

Please note that our experiments on small and medium datasets were performed on a single GPU, so to simulate the shuffle BN trick we simply follow the implementation from MoCo-Cifar.
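
For reference, here is a minimal sketch of that single-GPU batch-shuffle idea (illustrative only; function and variable names are mine, see the MoCo CIFAR demo for the exact code):

import torch

# Illustrative single-GPU shuffle BN: shuffle the batch before the
# teacher/key forward pass so BatchNorm statistics are computed on a
# different sample ordering, then restore the original order afterwards.

def batch_shuffle_single_gpu(x):
    idx_shuffle = torch.randperm(x.shape[0], device=x.device)
    idx_unshuffle = torch.argsort(idx_shuffle)
    return x[idx_shuffle], idx_unshuffle

def batch_unshuffle_single_gpu(x, idx_unshuffle):
    return x[idx_unshuffle]

# inside the training step (teacher branch):
# with torch.no_grad():
#     im_k, idx_unshuffle = batch_shuffle_single_gpu(im_k)
#     k = teacher(im_k)
#     k = batch_unshuffle_single_gpu(k, idx_unshuffle)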

Contrastive Augmentation

import random

from PIL import ImageFilter
from torchvision import transforms

class GaussianBlur(object):
    """Gaussian blur with a sigma sampled uniformly from [sigma[0], sigma[1]]."""
    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x

def get_contrastive_augment(dataset):
    """Strong (MoCo v2 / SimCLR style) augmentation pipeline."""
    size = 32
    if dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif dataset == 'stl10':
        mean = (0.4408, 0.4279, 0.3867)
        std = (0.2682, 0.2610, 0.2686)
        size = 64
    elif dataset == 'tinyimagenet':
        mean = (0.4802, 0.4481, 0.3975)
        std = (0.2302, 0.2265, 0.2262)
        size = 64
    else:  # cifar100
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)

    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
            transforms.ToTensor(),
            normalize,
        ])

    return train_transform

Weak Augmentation

def get_weak_augment(dataset):
    """Weak augmentation pipeline (random crop + horizontal flip only)."""
    size = 32
    if dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif dataset == 'stl10':
        mean = (0.4408, 0.4279, 0.3867)
        std = (0.2682, 0.2610, 0.2686)
        size = 64
    elif dataset == 'tinyimagenet':
        mean = (0.4802, 0.4481, 0.3975)
        std = (0.2302, 0.2265, 0.2262)
        size = 64
    else:  # cifar100
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)

    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=size, scale=(0.2, 1)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform
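
The two pipelines are applied to the same image to produce the two views; in the paper, the weakly augmented view goes to the teacher (momentum) network and the contrastive view to the student. A hypothetical wrapper (not from this repo; names are mine) could pair them like this:

from torch.utils.data import Dataset

# Hypothetical two-view wrapper: returns the strongly augmented view and
# the weakly augmented view of the same image.

class TwoViewDataset(Dataset):
    def __init__(self, base_dataset, strong_transform, weak_transform):
        self.base = base_dataset        # dataset yielding (PIL image, label)
        self.strong = strong_transform  # e.g. get_contrastive_augment('cifar10')
        self.weak = weak_transform      # e.g. get_weak_augment('cifar10')

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, _ = self.base[idx]         # labels are unused during pretraining
        return self.strong(img), self.weak(img)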

Hyperparameters

lr = 0.06
weight_decay = 1e-4
total_epochs = 200
warmup_epochs = 5        # warmup, then a cosine scheduler
momentum = 0.9           # SGD momentum
batch_size = 256
teacher_temperature = 0.04
student_temperature = 0.1

# Projection head (dim_in = encoder output dimension)
nn.Sequential(
    nn.Linear(dim_in, 2048),
    nn.ReLU(True),
    nn.Linear(2048, 128)
)

# small dataset
memory_buffer_size = 4096
ema_momentum = 0.99

# medium dataset
memory_buffer_size = 16384
ema_momentum = 0.996
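
For completeness, here is a sketch of how the two temperatures and the memory buffer enter the ReSSL loss (my own illustrative code following the paper's formulation; the released repo may differ in details):

import torch.nn.functional as F

# Illustrative ReSSL relational loss: the teacher's sharper (tau = 0.04)
# similarity distribution over the memory buffer supervises the student's
# (tau = 0.1) distribution. z_s, z_t, and queue are L2-normalized embeddings.

def ressl_loss(z_s, z_t, queue, t_student=0.1, t_teacher=0.04):
    logits_s = z_s @ queue.t() / t_student    # (batch, memory_buffer_size)
    logits_t = z_t @ queue.t() / t_teacher

    p_t = F.softmax(logits_t, dim=1).detach()  # teacher target, no gradient
    log_p_s = F.log_softmax(logits_s, dim=1)
    return -(p_t * log_p_s).sum(dim=1).mean()  # cross-entropy between the two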

mingkai-zheng commented 2 years ago

Just a kind reminder: I have released the code and pre-trained models for CIFAR-10/100 and STL-10. You can download them from this link.