sthalles / PyTorch-BYOL

PyTorch implementation of Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning

Training on CIFAR10 #8

akhauriyash opened this issue 3 years ago

akhauriyash commented 3 years ago

Hello,

Thank you for this excellent repository!

Do you have any suggestions of changes to make to train BYOL on the CIFAR10 dataset?

Here is how I am doing it in main.py (I am also training my own custom models, but I do not think that is too relevant):

from torchvision import datasets
# data_transform and MultiViewDataInjector come from this repo's data utilities, as set up earlier in main.py.

DATASET = 'CIFAR10'  # can change to STL10

if DATASET == 'STL10':
    train_dataset = datasets.STL10('/workspace/STLDataset', split='train+unlabeled', download=True,
                                   transform=MultiViewDataInjector([data_transform, data_transform]))
elif DATASET == 'CIFAR10':
    train_dataset = datasets.CIFAR10('/workspace/CIFAR10Dataset', train=True, download=True,
                                     transform=MultiViewDataInjector([data_transform, data_transform]))
else:
    print("Error: dataset not supported, choose CIFAR10 or STL10")
    exit(1)  # exit with a non-zero code on error

I also change the config to have: input_shape: (32,32,3).
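For 32x32 inputs I am also shrinking the augmentations down, roughly like this (just my own sketch of SimCLR/BYOL-style transforms at CIFAR10 resolution to replace data_transform; the function name and parameters are mine, and dropping Gaussian blur at this image size is a common choice rather than anything this repo prescribes):

from torchvision import transforms

def cifar10_byol_transform(size=32, s=0.5):
    # SimCLR/BYOL-style augmentations scaled to 32x32 images.
    # Gaussian blur is omitted since its kernel would cover most of the image.
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    return transforms.Compose([
        transforms.RandomResizedCrop(size=size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ])

data_transform = cifar10_byol_transform()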

Further, I may not have taken a very deep look into this codebase, but how do we reproduce the 'STL10 Top 1' accuracy (75.2%) after training the model on the self-supervised task? Do we take the trained model and fine-tune it on the labelled STL10 training set? I assume that code is not included in this repository?
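For concreteness, this is the kind of evaluation I have in mind: freeze the trained online encoder and fit a linear classifier on top of its features (a rough sketch only; the ResNet-18 stand-in, the feature dimension, and the commented-out checkpoint loading are my assumptions, not code from this repo):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Stand-in encoder: a ResNet-18 trunk with the classification head removed.
# In practice its weights would come from the BYOL checkpoint saved during training.
encoder = models.resnet18()
encoder.fc = nn.Identity()
# encoder.load_state_dict(...)  # load the trained online network's backbone here
encoder.to(device).eval()
for p in encoder.parameters():
    p.requires_grad = False

eval_transform = transforms.ToTensor()
train_set = datasets.STL10('/workspace/STLDataset', split='train', download=True, transform=eval_transform)
test_set = datasets.STL10('/workspace/STLDataset', split='test', download=True, transform=eval_transform)
train_loader = DataLoader(train_set, batch_size=256, shuffle=True)
test_loader = DataLoader(test_set, batch_size=256)

linear = nn.Linear(512, 10).to(device)  # 512 = ResNet-18 feature size, 10 classes
optimizer = torch.optim.Adam(linear.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

# Train only the linear layer on frozen features.
for epoch in range(50):
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            feats = encoder(x)
        loss = criterion(linear(feats), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Report top-1 accuracy on the test split.
correct, total = 0, 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        preds = linear(encoder(x)).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)
print(f'Top-1 accuracy: {100 * correct / total:.1f}%')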

Thank you!

khangt1k25 commented 3 years ago

Hi akhauriyash, you can just modify the input shape and the name of the dataset. I am testing the model as well, but it doesn't work well on CIFAR10 (~54% top-1 accuracy). I wonder whether the config, in particular the learning rate, should stay the same or be changed? Thank you!
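One thing worth checking (just a guess on my side, not something this repo documents): the standard ResNet stem downsamples 32x32 images very aggressively, so CIFAR10 reproductions usually replace the 7x7 stride-2 conv with a 3x3 stride-1 conv and drop the first max-pool. A minimal sketch, assuming a torchvision ResNet-18 backbone:

import torch.nn as nn
from torchvision import models

def cifar_resnet18():
    # CIFAR-friendly stem: 3x3 stride-1 conv, no initial max-pool.
    net = models.resnet18()
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()
    net.fc = nn.Identity()  # keep the trunk only; the BYOL projection head goes on top
    return net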