This is a PyTorch Lightning implementation of the following self-supervised representation learning methods: MoCo, MoCo v2, SimCLR, BYOL, EqCo, and VICReg.
Supported datasets: ImageNet, STL-10, and CIFAR-10.
During training, the top1/top5 accuracies (out of 1+K examples) are reported where possible. During validation, an sklearn linear classifier is trained on half the test set and validated on the other half. The top1 accuracy is logged as train_class_acc / valid_class_acc.
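The validation protocol above can be sketched roughly as follows. This is a minimal illustration, not the repository's code: the choice of `LogisticRegression` as the sklearn linear classifier, the random half/half split, and the helper name `half_split_linear_probe` are all assumptions.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def half_split_linear_probe(embeddings: np.ndarray, labels: np.ndarray, seed: int = 0):
    """Hypothetical helper: fit a linear classifier on one half of the
    (frozen) embeddings and report top-1 accuracy on both halves.

    Returns (train_class_acc, valid_class_acc).
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(labels))  # assumption: a random split
    half = len(labels) // 2
    train_idx, valid_idx = idx[:half], idx[half:]

    # Assumption: logistic regression stands in for "an sklearn linear classifier".
    clf = LogisticRegression(max_iter=1000)
    clf.fit(embeddings[train_idx], labels[train_idx])

    train_acc = clf.score(embeddings[train_idx], labels[train_idx])
    valid_acc = clf.score(embeddings[valid_idx], labels[valid_idx])
    return train_acc, valid_acc
```

Because only a linear model is trained, the score reflects how linearly separable the learned representation is, which is why it serves as a cheap proxy for representation quality during training.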
Make sure you're in a fresh conda or venv environment, then run:
git clone https://github.com/untitled-ai/self_supervised
cd self_supervised
pip install -r requirements.txt
We found some surprising results about the role of batch norm in BYOL. See the blog post Understanding self-supervised and contrastive learning with "Bootstrap Your Own Latent" (BYOL) for more details about our experiments.
You can replicate the results of our blog post by running python train_blog.py. The cosine similarity between z and z' is reported as step_neg_cos (for negative examples) and step_pos_cos (for positive examples). Classification accuracy is reported as valid_class_acc.
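To make the two cosine metrics concrete, here is a minimal sketch of how mean positive and negative cosine similarities between z and z' could be computed for one batch. It assumes that row i of z and row i of z' are a positive pair and treats every other row of the batch as a negative; the repository may draw negatives differently (for example from the MoCo queue), and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def pos_neg_cosine(z: torch.Tensor, z_prime: torch.Tensor):
    """Hypothetical helper returning (mean positive cos, mean negative cos).

    z, z_prime: (batch, dim) projections of two augmentations of the same images.
    """
    # Cosine similarity matrix between every z_i and every z'_j.
    sim = F.normalize(z, dim=1) @ F.normalize(z_prime, dim=1).T  # (batch, batch)
    # Diagonal entries compare two views of the same image: positives.
    pos = sim.diagonal().mean()
    # Off-diagonal entries compare different images: negatives.
    off_diag = ~torch.eye(sim.shape[0], dtype=torch.bool, device=sim.device)
    neg = sim[off_diag].mean()
    return pos, neg
```

A collapsed representation shows up as both values approaching 1, while a useful one keeps the positive similarity well above the negative similarity.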
To get started with training a ResNet-18 with MoCo v2 on STL-10 (the default configuration):
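import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import ModelParams
# Directory the datasets are read from (STL-10 and CIFAR-10 are downloaded here if missing)
os.environ["DATA_PATH"] = "~/data"
# Default hyperparameters: ResNet-18 trained with MoCo v2 on STL-10
params = ModelParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
# Save the weights so a linear classifier can be trained on them later
trainer.save_checkpoint("example.ckpt")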
For convenience, you can instead pass these parameters as keyword args, for example with model = SelfSupervisedMethod(batch_size=128).
To train VICReg rather than MoCo v2, use the following parameters:
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import VICRegParams
os.environ["DATA_PATH"] = "~/data"
params = VICRegParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
Note that we have not tuned these parameters for STL-10, and the parameters used for ImageNet are slightly different. See the comment on VICRegParams for details.
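For reference, the three VICReg terms weighted by invariance_loss_weight, variance_loss_weight, and covariance_loss_weight (with variance_loss_epsilon inside the square root) can be sketched as below. This follows the loss as written in the VICReg paper, using the ModelParams defaults as arguments; we have not checked that it matches this repository's implementation line for line, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def off_diagonal(m: torch.Tensor) -> torch.Tensor:
    """Return the off-diagonal entries of a square matrix as a flat tensor."""
    n = m.shape[0]
    return m[~torch.eye(n, dtype=torch.bool, device=m.device)]


def vicreg_loss(z_a, z_b, invariance_weight=25.0, variance_weight=25.0,
                covariance_weight=1.0, eps=1e-4):
    """Sketch of the VICReg loss; returns (total, invariance, variance, covariance)."""
    n, d = z_a.shape
    # Invariance: embeddings of the two augmentations should match.
    invariance = F.mse_loss(z_a, z_b)
    # Variance: hinge that keeps the std of every dimension above 1 (prevents collapse).
    std_a = torch.sqrt(z_a.var(dim=0) + eps)
    std_b = torch.sqrt(z_b.var(dim=0) + eps)
    variance = F.relu(1 - std_a).mean() + F.relu(1 - std_b).mean()
    # Covariance: push off-diagonal covariance entries toward 0 (decorrelates dimensions).
    za = z_a - z_a.mean(dim=0)
    zb = z_b - z_b.mean(dim=0)
    cov_a = (za.T @ za) / (n - 1)
    cov_b = (zb.T @ zb) / (n - 1)
    covariance = off_diagonal(cov_a).pow(2).sum() / d + off_diagonal(cov_b).pow(2).sum() / d
    total = (invariance_weight * invariance
             + variance_weight * variance
             + covariance_weight * covariance)
    return total, invariance, variance, covariance
```

For a fully collapsed batch (all embeddings equal), the invariance and covariance terms are 0 but the variance term is close to 2, so the loss stays high; this is what lets VICReg train without negatives or a momentum encoder.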
To train BYOL rather than MoCo v2, use the following parameters:
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import BYOLParams
os.environ["DATA_PATH"] = "~/data"
params = BYOLParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
To train SimCLR rather than MoCo v2, use the following parameters:
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import SimCLRParams
os.environ["DATA_PATH"] = "~/data"
params = SimCLRParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
Note for multi-GPU setups: this currently only uses negatives on the same GPU, and will not sync negatives across multiple GPUs.
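As an illustration of what "negatives on the same GPU" means, here is a minimal sketch of a SimCLR-style (NT-Xent) loss in which each example's negatives are the other examples of the local batch. It is a generic version of the technique, not this repository's code; the temperature default mirrors T in ModelParams as an assumption, and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.2) -> torch.Tensor:
    """Hypothetical SimCLR-style loss with in-batch negatives only.

    z1, z2: (n, dim) projections of two augmentations of the same n images.
    """
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2n, dim)
    logits = z @ z.T / temperature                     # (2n, 2n) scaled cosine similarities
    # An example must never count as its own negative, so mask self-similarity.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Row i's positive is its other view: index i + n in the first half, i - n in the second.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(logits, targets)
```

Since the function only sees the tensors passed to it, each GPU contributes 2n - 2 negatives per example; syncing negatives across GPUs would require gathering z from all processes before computing the logits.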
To train a linear classifier on the result:
import pytorch_lightning as pl
from linear_classifier import LinearClassifierMethod
linear_model = LinearClassifierMethod.from_moco_checkpoint("example.ckpt")
trainer = pl.Trainer(gpus=1, max_epochs=100)
trainer.fit(linear_model)
Training a ResNet-18 for 320 epochs on STL-10 achieved 85% linear classification accuracy on the test set (1 fold of 5000). This used all default parameters.
Training a ResNet-50 for 200 epochs on ImageNet achieved 65.6% linear classification accuracy on the test set. This used 8 GPUs with ddp and the following parameters:
hparams = ModelParams(
encoder_arch="resnet50",
shuffle_batch_norm=True,
embedding_dim=2048,
mlp_hidden_dim=2048,
dataset_name="imagenet",
batch_size=32,
lr=0.03,
max_epochs=200,
transform_crop_size=224,
num_data_workers=32,
gather_keys_for_queue=True,
)
(The batch_size differs from the MoCo documentation because of the way PyTorch Lightning handles multi-GPU training in ddp: batch_size is the per-GPU batch size, so the effective batch size is 8 × 32 = 256.) For ImageNet, we suggest passing val_percent_check=0.1 to pl.Trainer to reduce the time spent fitting the sklearn model (newer PyTorch Lightning versions call this argument limit_val_batches).
All possible hparams for SelfSupervisedMethod, along with defaults:
from typing import List, Optional
class ModelParams:
# encoder model selection
encoder_arch: str = "resnet18"
shuffle_batch_norm: bool = False
embedding_dim: int = 512 # must match embedding dim of encoder
# data-related parameters
dataset_name: str = "stl10"
batch_size: int = 256
# MoCo parameters
K: int = 65536 # number of examples in queue
dim: int = 128
m: float = 0.996
T: float = 0.2
# eqco parameters
eqco_alpha: int = 65536
use_eqco_margin: bool = False
use_negative_examples_from_batch: bool = False
# optimization parameters
lr: float = 0.5
momentum: float = 0.9
weight_decay: float = 1e-4
max_epochs: int = 320
final_lr_schedule_value: float = 0.0
# transform parameters
transform_s: float = 0.5
transform_apply_blur: bool = True
# Change these to make more like BYOL
use_momentum_schedule: bool = False
loss_type: str = "ce"
use_negative_examples_from_queue: bool = True
use_both_augmentations_as_queries: bool = False
optimizer_name: str = "sgd"
lars_warmup_epochs: int = 1
lars_eta: float = 1e-3
exclude_matching_parameters_from_lars: List[str] = [] # set to [".bias", ".bn"] to match paper
loss_constant_factor: float = 1
# Change these to make more like VICReg
use_vicreg_loss: bool = False
use_lagging_model: bool = True
use_unit_sphere_projection: bool = True
invariance_loss_weight: float = 25.0
variance_loss_weight: float = 25.0
covariance_loss_weight: float = 1.0
variance_loss_epsilon: float = 1e-04
# MLP parameters
projection_mlp_layers: int = 2
prediction_mlp_layers: int = 0
mlp_hidden_dim: int = 512
mlp_normalization: Optional[str] = None
prediction_mlp_normalization: Optional[str] = "same" # if same will use mlp_normalization
use_mlp_weight_standardization: bool = False
# data loader parameters
num_data_workers: int = 4
drop_last_batch: bool = True
pin_data_memory: bool = True
gather_keys_for_queue: bool = False
A few options require more explanation:
encoder_arch can be any torchvision model, or one of the ResNet models with weight standardization defined in ws_resnet.py.
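Weight standardization normalizes each convolution filter to zero mean and unit variance before it is applied. The sketch below shows the general technique as a drop-in `nn.Conv2d` subclass; it is not taken from ws_resnet.py, and the class name, the epsilon value, and the use of the population variance are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WSConv2d(nn.Conv2d):
    """Hypothetical Conv2d with weight standardization (zero-padding only)."""

    def __init__(self, *args, eps: float = 1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps  # assumption: small constant to avoid division by zero

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight
        # Standardize each output filter over its (in_channels, kh, kw) entries.
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = w.var(dim=(1, 2, 3), keepdim=True, unbiased=False)
        w = (w - mean) / torch.sqrt(var + self.eps)
        return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
```

A side effect of the standardization is that the layer's output is (almost) unchanged when the raw weights are rescaled, which is one reason weight-standardized networks pair well with batch-independent normalization such as group norm.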
dataset_name can be imagenet, stl10, or cifar10. os.environ["DATA_PATH"] will be used as the path to the data. STL-10 and CIFAR-10 will be downloaded if they do not already exist.
loss_type can be ce (cross entropy) with one of the use_negative_examples_* options set to True, corresponding to MoCo, or ip (inner product) with both use_negative_examples_* options set to False, corresponding to BYOL. It can also be bce, which is similar to ip but applies the binary cross entropy loss function to the result, or vic for the VICReg loss.
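The ce and ip options can be sketched as follows, together with one plausible way to compute the top1/top5 training accuracies out of 1+K candidates. These are generic forms of the MoCo InfoNCE loss and the BYOL normalized regression loss; the exact scaling and constants used in this repository may differ, the queue is assumed to hold already-normalized keys, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def moco_ce_loss(q, k, queue, temperature=0.2):
    """Sketch of loss_type="ce" with queue negatives (MoCo InfoNCE).

    q, k: (batch, dim) query and key projections (keys come from the momentum encoder).
    queue: (K, dim) normalized keys from earlier batches, used as negatives.
    Returns (loss, logits) where logits has shape (batch, 1 + K).
    """
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1).detach()        # no gradient flows into the key encoder
    l_pos = (q * k).sum(dim=1, keepdim=True)  # (batch, 1) similarity to the positive key
    l_neg = q @ queue.T                       # (batch, K) similarities to the negatives
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    labels = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)  # positive is column 0
    return F.cross_entropy(logits, labels), logits


def ip_loss(p, z_target):
    """Sketch of loss_type="ip": 2 - 2 * cosine(prediction, target), as in the BYOL paper."""
    return (2 - 2 * F.cosine_similarity(p, z_target.detach(), dim=1)).mean()


def topk_accuracy(logits, k):
    """Fraction of rows whose positive (column 0) is among the k largest of the 1 + K logits."""
    topk = logits.topk(k, dim=1).indices
    return (topk == 0).any(dim=1).float().mean().item()
```

The ce loss only makes sense with negatives (otherwise there is a single logit per row and the cross entropy is trivially zero), which is why it is paired with the use_negative_examples_* options, whereas ip works without any negatives.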
optimizer_name can currently be either sgd or lars.
exclude_matching_parameters_from_lars removes weight decay and the LARS learning-rate scaling from matching parameters. Set it to [".bias", ".bn"] to match the BYOL paper implementation.
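To show what "removing weight decay and LARS" means for a parameter, here is a minimal sketch of a single momentum-free LARS step. It uses the trust-ratio form eta * ||w|| / ||g + wd * w|| found in several common LARS implementations, and it assumes name matching is a plain substring test; neither is confirmed for this repository, lars_warmup_epochs is not modeled, and the function name is hypothetical.

```python
import torch
import torch.nn as nn


def lars_updates(model: nn.Module, lr: float, weight_decay: float = 1e-4,
                 eta: float = 1e-3, exclude=()):
    """Hypothetical helper returning {param_name: update} for one plain LARS step."""
    updates = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        g = p.grad
        if any(pattern in name for pattern in exclude):
            # Excluded parameters: no weight decay and no layer-wise LR scaling.
            updates[name] = -lr * g
            continue
        g = g + weight_decay * p.detach()
        w_norm, g_norm = p.detach().norm(), g.norm()
        # Trust ratio: scales the step so its size is proportional to the weight norm.
        if w_norm > 0 and g_norm > 0:
            trust = (eta * w_norm / g_norm).item()
        else:
            trust = 1.0
        updates[name] = -lr * trust * g
    return updates
```

With this form, the norm of a non-excluded parameter's update is lr * eta * ||w||, independent of the gradient's magnitude; biases and batch norm parameters are usually excluded because their norms are small or not meaningful as a scale reference.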
mlp_normalization can be None for no normalization, bn for batch normalization, ln for layer norm, gn for group norm, or br for batch renormalization.
prediction_mlp_normalization defaults to same, which uses the same normalization as mlp_normalization, but it can be set to any of the values above to use a different normalization.
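A minimal sketch of how these options could map onto PyTorch layers when building the projection and prediction MLPs is shown below. The helper names, the Linear/norm/ReLU layer order, and the 32 groups used for gn are assumptions rather than the repository's code; br is left out because PyTorch has no built-in batch renormalization layer.

```python
from typing import Optional

import torch
import torch.nn as nn


def norm_layer(kind: Optional[str], dim: int) -> nn.Module:
    """Hypothetical mapping from an mlp_normalization value to a layer."""
    if kind is None:
        return nn.Identity()
    if kind == "bn":
        return nn.BatchNorm1d(dim)
    if kind == "ln":
        return nn.LayerNorm(dim)
    if kind == "gn":
        return nn.GroupNorm(num_groups=32, num_channels=dim)  # group count is an assumption
    # "br" (batch renormalization) has no built-in PyTorch layer, so it is not sketched here.
    raise ValueError(f"unsupported normalization: {kind!r}")


def resolve_prediction_norm(prediction_norm: Optional[str], mlp_norm: Optional[str]):
    """prediction_mlp_normalization="same" falls back to mlp_normalization."""
    return mlp_norm if prediction_norm == "same" else prediction_norm


def make_mlp(in_dim: int, hidden_dim: int, out_dim: int, num_layers: int,
             normalization: Optional[str] = None) -> nn.Module:
    """num_layers Linear layers with norm + ReLU between them; 0 layers means no MLP."""
    if num_layers == 0:
        return nn.Identity()
    layers = []
    dim = in_dim
    for _ in range(num_layers - 1):
        layers += [nn.Linear(dim, hidden_dim), norm_layer(normalization, hidden_dim), nn.ReLU(inplace=True)]
        dim = hidden_dim
    layers.append(nn.Linear(dim, out_dim))
    return nn.Sequential(*layers)


if __name__ == "__main__":
    # Example: a BYOL-like 2-layer projection head with batch norm.
    head = make_mlp(512, 2048, 128, num_layers=2, normalization="bn")
    print(head(torch.randn(8, 512)).shape)  # torch.Size([8, 128])
```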
shuffle_batch_norm and gather_keys_for_queue are both related to multi-GPU training. shuffle_batch_norm shuffles the key images among GPUs, which is needed for training if batch norm is used. gather_keys_for_queue gathers the key projections (z' in the blog post) from all GPUs and adds them to the MoCo queue.
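The idea behind shuffle_batch_norm, as described in the MoCo paper, is that a query and its positive key should not share batch norm statistics, since the model could otherwise use those statistics as a shortcut. The sketch below simulates this in a single process by treating chunks of the batch as "GPUs"; a real multi-GPU implementation would all-gather and scatter the keys instead, and the function name is hypothetical.

```python
import torch
import torch.nn as nn


def shuffled_key_forward(encoder: nn.Module, keys: torch.Tensor, num_gpus: int) -> torch.Tensor:
    """Hypothetical single-process simulation of shuffle BN for the key encoder."""
    perm = torch.randperm(keys.shape[0], device=keys.device)
    inverse = torch.argsort(perm)  # undoes the permutation
    shuffled = keys[perm]
    # Each chunk stands in for one GPU, so BatchNorm statistics are computed per chunk
    # over a different mix of images than the one the matching queries saw.
    outputs = torch.cat([encoder(chunk) for chunk in shuffled.chunk(num_gpus)], dim=0)
    # Restore the original order so key i still lines up with query i.
    return outputs[inverse]
```

For an encoder without batch statistics the result is identical to encoding the keys directly; only layers like BatchNorm see a difference.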
You can train with any combination of the parameters above. For example, this configuration represents the settings from BYOL:
from model_params import ModelParams
hparams = ModelParams(
prediction_mlp_layers=2,
mlp_normalization="bn",
loss_type="ip",
use_negative_examples_from_queue=False,
use_both_augmentations_as_queries=True,
use_momentum_schedule=True,
optimizer_name="lars",
exclude_matching_parameters_from_lars=[".bias", ".bn"],
loss_constant_factor=2
)
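The BYOL configuration above sets use_momentum_schedule=True. As a point of reference, the BYOL paper raises the target-network momentum from its base value to 1 with a cosine schedule and updates the target weights as an exponential moving average; a minimal sketch follows. We assume, without having checked, that this flag enables something similar and that m (0.996 by default) is the base value; the function names are hypothetical.

```python
import math

import torch
import torch.nn as nn


def momentum_at(step: int, total_steps: int, base_m: float = 0.996) -> float:
    """BYOL-paper cosine schedule: base_m at step 0, rising to 1.0 at the last step."""
    return 1 - (1 - base_m) * (math.cos(math.pi * step / total_steps) + 1) / 2


@torch.no_grad()
def update_target(online: nn.Module, target: nn.Module, m: float) -> None:
    """Exponential moving average: target <- m * target + (1 - m) * online."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(m).add_(p_online, alpha=1 - m)
```

Pushing the momentum toward 1 late in training makes the target network change more and more slowly, which stabilizes the regression targets as the online network converges.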
Or here is our recommended way to modify VICReg for CIFAR-10:
from model_params import VICRegParams
hparams = VICRegParams(
dataset_name="cifar10",
transform_apply_blur=False,
mlp_hidden_dim=2048,
dim=2048,
batch_size=256,
lr=0.3,
final_lr_schedule_value=0,
weight_decay=1e-4,
lars_warmup_epochs=10,
lars_eta=0.02
)
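The lars_warmup_epochs and final_lr_schedule_value settings above suggest a learning rate that ramps up and then decays. A common pattern is a linear warmup followed by cosine decay, sketched below under the assumptions that final_lr_schedule_value is a fraction of the base learning rate and that the schedule is evaluated per epoch; the repository's exact schedule may differ, and the function name is hypothetical.

```python
import math


def lr_at(epoch: float, base_lr: float, max_epochs: int,
          warmup_epochs: int = 10, final_fraction: float = 0.0) -> float:
    """Hypothetical schedule: linear warmup to base_lr, then cosine decay
    to final_fraction * base_lr at max_epochs."""
    if epoch < warmup_epochs:
        return base_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))  # goes from 1 down to 0
    return base_lr * (final_fraction + (1 - final_fraction) * cosine)
```

For example, with lr=0.3 and lars_warmup_epochs=10, this sketch gives 0.15 at epoch 5, peaks at 0.3 at epoch 10, and decays to 0 by the final epoch. The warmup matters with LARS because the layer-wise scaling can produce large steps early in training.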