sthalles / SimCLR

PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
https://sthalles.github.io/simple-self-supervised-learning/
MIT License
2.18k stars 454 forks

How do I train the SimCLR model with my local dataset? #38

Closed: bestalllen closed this issue 2 years ago

bestalllen commented 3 years ago

Dear researcher, thank you for the open-source code you provided; it has been a great help to me in understanding contrastive learning. I am still confused about how to train the SimCLR model with my local dataset. Could you give me some guidance or tips? I would appreciate a reply to this issue.

pengzhangzhi commented 2 years ago

Hi! I am not the author, but I would like to give you some advice. To use your own dataset, the only part you need to modify is here:

```python
def get_dataset(self, name, n_views):
    valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
                                                          transform=ContrastiveLearningViewGenerator(
                                                              self.get_simclr_pipeline_transform(32),
                                                              n_views),
                                                          download=True),

                      'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
                                                      transform=ContrastiveLearningViewGenerator(
                                                          self.get_simclr_pipeline_transform(96),
                                                          n_views),
                                                      download=True)}
```

These lines create the dataset; the author uses the public datasets that ship with PyTorch (torchvision). In your case, you should write your own PyTorch dataset class and use it in place of this code. Don't forget to pass it `transform=ContrastiveLearningViewGenerator(self.get_simclr_pipeline_transform(your image size), n_views)`.
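As a rough illustration of such a dataset class, here is a minimal sketch. It assumes your images are stored as `root/<class_folder>/<image file>`; the class name `FolderPerClassDataset`, the `loader` argument, and the `'mydata'` key in the closing comment are all made up for this example and are not part of the repository. It only implements `__len__` and `__getitem__`, which is the map-style protocol a PyTorch `DataLoader` expects, and it returns `(views, label)` pairs like the built-in datasets do.

```python
import os


class FolderPerClassDataset:
    """Map-style dataset over root/<class_folder>/<image file>.

    Hypothetical helper for illustration; not part of the SimCLR repository.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, root, transform=None, loader=None):
        self.transform = transform
        # Default to a PIL-based loader; any callable taking a path also works.
        self.loader = loader if loader is not None else self._pil_loader
        # One sub-directory per class; sorted so label ids are stable across runs.
        self.classes = sorted(d for d in os.listdir(root)
                              if os.path.isdir(os.path.join(root, d)))
        self.samples = []
        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root, class_name)
            for file_name in sorted(os.listdir(class_dir)):
                if file_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(class_dir, file_name), label))

    @staticmethod
    def _pil_loader(path):
        # Imported lazily so the class itself only needs the standard library.
        from PIL import Image
        with Image.open(path) as img:
            return img.convert('RGB')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        image = self.loader(path)
        if self.transform is not None:
            # With ContrastiveLearningViewGenerator this yields a list of n_views images.
            image = self.transform(image)
        return image, label


# Assumed wiring inside get_dataset ('mydata' and the size 224 are placeholders):
# 'mydata': lambda: FolderPerClassDataset(
#     self.root_folder,
#     transform=ContrastiveLearningViewGenerator(
#         self.get_simclr_pipeline_transform(224), n_views)),
```

If the script restricts which dataset names it accepts, you will probably have to allow the new name there as well. The labels are ignored during contrastive pre-training, so if your data has no classes you could return a dummy label instead.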

Hope this can help~

bestalllen commented 2 years ago

OK, thanks for your reply. I have since run this experiment on my own dataset, but thanks for your help all the same!

LifeIsBright-heihei commented 2 years ago

Hello, how did you change the dataset? Could you give me some guidance or tips? I would appreciate a reply to this issue.

pengzhangzhi commented 2 years ago

I would like to help. Could you describe the specific problem you came across?

ElerCode commented 1 year ago

Hello, I want to use my own dataset. The current structure is like this:

There are 10 categories, from data0 to data9.
How do I generate a dataset like the author's, with class_names.txt, fold_indices.txt, test_X.bin, and so on? I would appreciate a reply to this issue.