PyTorch-Lightning Implementation of Self-Supervised Learning Methods

This is a PyTorch Lightning implementation of the following self-supervised representation learning methods: MoCo v2, SimCLR, BYOL, and VICReg.

Supported datasets: ImageNet, STL-10, and CIFAR-10.

During training, the top1/top5 accuracies (out of 1+K examples) are reported where possible. During validation, an sklearn linear classifier is trained on half the test set and validated on the other half. The top1 accuracy is logged as train_class_acc / valid_class_acc.
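
As an illustration of this evaluation protocol, the sketch below fits an sklearn logistic-regression classifier on half of the test-set embeddings and scores it on the other half. The embeddings and labels arrays and the linear_eval_accuracy name are illustrative placeholders, not the repository's actual code.

import numpy as np
from sklearn.linear_model import LogisticRegression

def linear_eval_accuracy(embeddings: np.ndarray, labels: np.ndarray) -> float:
    # Fit on the first half of the test set, report top-1 accuracy on the second half.
    n = len(labels) // 2
    clf = LogisticRegression(max_iter=1000)
    clf.fit(embeddings[:n], labels[:n])
    return clf.score(embeddings[n:], labels[n:])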

Installing

Make sure you're in a fresh conda or venv environment, then run:

git clone https://github.com/untitled-ai/self_supervised
cd self_supervised
pip install -r requirements.txt

Replicating our BYOL blog post

We found some surprising results about the role of batch norm in BYOL. See the blog post Understanding self-supervised and contrastive learning with "Bootstrap Your Own Latent" (BYOL) for more details about our experiments.

You can replicate the results of our blog post by running python train_blog.py. The cosine similarity between z and z' is reported as step_neg_cos (for negative examples) and step_pos_cos (for positive examples). Classification accuracy is reported as valid_class_acc.
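
For reference, the sketch below shows one way to compute these two metrics from a pair of projection outputs; z and z_prime are assumed to be (batch, dim) tensors for the two augmented views, and the function mirrors what step_pos_cos / step_neg_cos measure rather than the script's exact logging code.

import torch
import torch.nn.functional as F

def pos_neg_cosine(z: torch.Tensor, z_prime: torch.Tensor):
    z = F.normalize(z, dim=1)
    z_prime = F.normalize(z_prime, dim=1)
    sim = z @ z_prime.t()                  # pairwise cosine similarities
    n = sim.shape[0]
    pos_cos = sim.diag().mean()            # same-image (positive) pairs
    neg_cos = (sim.sum() - sim.diag().sum()) / (n * (n - 1))  # cross-image (negative) pairs
    return pos_cos, neg_cos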

Getting started with MoCo v2

To get started with training a ResNet-18 with MoCo v2 on STL-10 (the default configuration):

import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import ModelParams

os.environ["DATA_PATH"] = "~/data"

params = ModelParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")

For convenience, you can instead pass these parameters as keyword args, for example with model = SelfSupervisedMethod(batch_size=128).

VICReg

To train VICReg rather than MoCo v2, use the following parameters:

import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import VICRegParams

os.environ["DATA_PATH"] = "~/data"

params = VICRegParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")

Note that we have not tuned these parameters for STL-10, and the parameters used for ImageNet are slightly different. See the comment on VICRegParams for details.

BYOL

To train BYOL rather than MoCo v2, use the following parameters:

import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import BYOLParams

os.environ["DATA_PATH"] = "~/data"

params = BYOLParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")

SimCLR

To train SimCLR rather than MoCo v2, use the following parameters:

import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import SimCLRParams

os.environ["DATA_PATH"] = "~/data"

params = SimCLRParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")

Note for multi-GPU setups: this currently only uses negatives on the same GPU, and will not sync negatives across multiple GPUs.

Evaluating a trained model

To train a linear classifier on the result:

import pytorch_lightning as pl
from linear_classifier import LinearClassifierMethod
linear_model = LinearClassifierMethod.from_moco_checkpoint("example.ckpt")
trainer = pl.Trainer(gpus=1, max_epochs=100)
trainer.fit(linear_model)

Results on STL-10 and ImageNet

Training a ResNet-18 for 320 epochs on STL-10 achieved 85% linear classification accuracy on the test set (1 fold of 5000). This used all default parameters.

Training a ResNet-50 for 200 epochs on ImageNet achieved 65.6% linear classification accuracy on the test set. This used 8 GPUs with DDP and the following parameters:

hparams = ModelParams(
    encoder_arch="resnet50",
    shuffle_batch_norm=True,
    embedding_dim=2048,
    mlp_hidden_dim=2048,
    dataset_name="imagenet",
    batch_size=32,
    lr=0.03,
    max_epochs=200,
    transform_crop_size=224,
    num_data_workers=32,
    gather_keys_for_queue=True,
)

(the batch_size differs from the MoCo documentation because of the way PyTorch-Lightning handles multi-GPU training with DDP: batch_size is per GPU, so the effective batch size is 32 × 8 = 256). Note that for ImageNet we suggest passing val_percent_check=0.1 to pl.Trainer to reduce the time spent fitting the sklearn model.
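
A sketch of the corresponding multi-GPU launch is below. The exact DDP flag depends on your PyTorch-Lightning version (distributed_backend="ddp" on older releases, accelerator="ddp" on later 1.x releases), so treat this as a starting point rather than a drop-in command.

import pytorch_lightning as pl
from moco import SelfSupervisedMethod

model = SelfSupervisedMethod(hparams)
trainer = pl.Trainer(
    gpus=8,
    distributed_backend="ddp",  # or accelerator="ddp", depending on PL version
    max_epochs=hparams.max_epochs,
    val_percent_check=0.1,      # reduces time spent fitting the sklearn classifier
)
trainer.fit(model)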

All training options

All possible hparams for SelfSupervisedMethod, along with defaults:

from typing import List, Optional

class ModelParams:
    # encoder model selection
    encoder_arch: str = "resnet18"
    shuffle_batch_norm: bool = False
    embedding_dim: int = 512  # must match embedding dim of encoder

    # data-related parameters
    dataset_name: str = "stl10"
    batch_size: int = 256

    # MoCo parameters
    K: int = 65536  # number of examples in queue
    dim: int = 128
    m: float = 0.996
    T: float = 0.2

    # eqco parameters
    eqco_alpha: int = 65536
    use_eqco_margin: bool = False
    use_negative_examples_from_batch: bool = False

    # optimization parameters
    lr: float = 0.5
    momentum: float = 0.9
    weight_decay: float = 1e-4
    max_epochs: int = 320
    final_lr_schedule_value: float = 0.0

    # transform parameters
    transform_s: float = 0.5
    transform_apply_blur: bool = True

    # Change these to make more like BYOL
    use_momentum_schedule: bool = False
    loss_type: str = "ce"
    use_negative_examples_from_queue: bool = True
    use_both_augmentations_as_queries: bool = False
    optimizer_name: str = "sgd"
    lars_warmup_epochs: int = 1
    lars_eta: float = 1e-3
    exclude_matching_parameters_from_lars: List[str] = []  # set to [".bias", ".bn"] to match paper
    loss_constant_factor: float = 1

    # Change these to make more like VICReg
    use_vicreg_loss: bool = False
    use_lagging_model: bool = True
    use_unit_sphere_projection: bool = True
    invariance_loss_weight: float = 25.0
    variance_loss_weight: float = 25.0
    covariance_loss_weight: float = 1.0
    variance_loss_epsilon: float = 1e-04

    # MLP parameters
    projection_mlp_layers: int = 2
    prediction_mlp_layers: int = 0
    mlp_hidden_dim: int = 512

    mlp_normalization: Optional[str] = None
    prediction_mlp_normalization: Optional[str] = "same"  # if same will use mlp_normalization
    use_mlp_weight_standardization: bool = False

    # data loader parameters
    num_data_workers: int = 4
    drop_last_batch: bool = True
    pin_data_memory: bool = True
    gather_keys_for_queue: bool = False

A few options require more explanation; the example configurations below show how they are used together.

Training with custom options

You can train using any settings of the above parameters. This configuration represents the settings from BYOL:

hparams = ModelParams(
    prediction_mlp_layers=2,
    mlp_normalization="bn",
    loss_type="ip",
    use_negative_examples_from_queue=False,
    use_both_augmentations_as_queries=True,
    use_momentum_schedule=True,
    optimizer_name="lars",
    exclude_matching_parameters_from_lars=[".bias", ".bn"],
    loss_constant_factor=2,
)

Or here is our recommended way to modify VICReg for CIFAR-10:

from model_params import VICRegParams

hparams = VICRegParams(
    dataset_name="cifar10",
    transform_apply_blur=False,
    mlp_hidden_dim=2048,
    dim=2048,
    batch_size=256,
    lr=0.3,
    final_lr_schedule_value=0,
    weight_decay=1e-4,
    lars_warmup_epochs=10,
    lars_eta=0.02,
)
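
Either hparams object above is then used exactly like the defaults, along the lines of the earlier single-GPU examples:

import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod

os.environ["DATA_PATH"] = "~/data"

model = SelfSupervisedMethod(hparams)
trainer = pl.Trainer(gpus=1, max_epochs=hparams.max_epochs)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")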