Implementation of Pixel-level Contrastive Learning, proposed in the paper "Propagate Yourself", in PyTorch. In addition to contrastive learning at the pixel level, the online network passes its pixel-level representations through a Pixel Propagation Module and applies a similarity loss against the target network. The authors report that this approach outperforms previous unsupervised and supervised pre-training methods on downstream segmentation tasks.
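As described in the paper, the Pixel Propagation Module replaces each pixel feature x_i with a similarity-weighted sum y_i = Σ_j s(x_i, x_j) · g(x_j), where s(x_i, x_j) = max(cos(x_i, x_j), 0)^γ. The PixPro loss is then -cos(y_i, x'_j) - cos(y_j, x'_i), averaged over positive pixel pairs (i, j), with x' taken from the target network. The dependency-free sketch below illustrates that computation; it is not this library's implementation. It assumes the learned transform g (whose depth `ppm_num_layers` controls) is the identity, and the function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity, returning 0.0 for zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def pixel_propagation(features: List[Vector], gamma: float = 2.0) -> List[Vector]:
    """y_i = sum_j max(cos(x_i, x_j), 0) ** gamma * g(x_j), with g assumed to be the identity."""
    propagated = []
    for xi in features:
        yi = [0.0] * len(xi)
        for xj in features:
            s = max(cosine(xi, xj), 0.0) ** gamma  # sharpened, non-negative similarity
            for k, v in enumerate(xj):
                yi[k] += s * v
        propagated.append(yi)
    return propagated


def pixpro_loss(
    online_1: List[Vector],
    online_2: List[Vector],
    target_1: List[Vector],
    target_2: List[Vector],
    positive_pairs: List[Tuple[int, int]],
    gamma: float = 2.0,
) -> float:
    """Symmetric loss -cos(y_i, x'_j) - cos(y_j, x'_i), averaged over pairs (i in view 1, j in view 2)."""
    if not positive_pairs:
        return 0.0
    y1 = pixel_propagation(online_1, gamma)
    y2 = pixel_propagation(online_2, gamma)
    total = 0.0
    for i, j in positive_pairs:
        total += -cosine(y1[i], target_2[j]) - cosine(y2[j], target_1[i])
    return total / len(positive_pairs)


if __name__ == "__main__":
    basis = [[1.0, 0.0], [0.0, 1.0]]  # two orthogonal "pixels"
    print(pixpro_loss(basis, basis, basis, basis, [(0, 0), (1, 1)]))  # -2.0, the minimum
```

The `max(·, 0)` clamp keeps dissimilar pixels from contributing, and raising the similarity to the power γ (`ppm_gamma`) sharpens it so that mostly very similar pixels are propagated.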
$ pip install pixel-level-contrastive-learning
Below is an example of how to use the framework for self-supervised training of a ResNet, using the output of `layer4` (an 8 x 8 feature map for 256 x 256 inputs) as the "pixels".
import torch
from pixel_level_contrastive_learning import PixelCL
from torchvision import models
from tqdm import tqdm

resnet = models.resnet50(pretrained=True)

learner = PixelCL(
    resnet,
    image_size = 256,
    hidden_layer_pixel = 'layer4',  # 8 x 8 feature map (for 256 x 256 inputs) used for pixel-level learning
    hidden_layer_instance = -2,     # penultimate child module (global average pooling in ResNet) used for instance-level learning
    projection_size = 256,          # size of the projection output, 256 was used in the paper
    projection_hidden_size = 2048,  # size of the projection hidden dimension, the paper used 2048
    moving_average_decay = 0.99,    # exponential moving average decay of the target encoder
    ppm_num_layers = 1,             # number of layers of the transform in the pixel propagation module, the paper found 1 optimal
    ppm_gamma = 2,                  # sharpness of the similarity in the pixel propagation module, 2 is the optimal value
    distance_thres = 0.7,           # 0.7 as in the paper, where distances are normalized by the diagonal length of a feature-map pixel (bin)
    similarity_temperature = 0.3,   # temperature for the cosine similarity in the pixel contrastive loss
    alpha = 1.,                     # weight of the pixel propagation loss (pixpro) vs pixel CL loss
    use_pixpro = True,              # use the pixel propagation (pixpro) loss instead of the pixel contrast loss; the default, since it performed best
    cutout_ratio_range = (0.6, 0.8) # a random ratio is selected from this range for the random cutout
).cuda()

opt = torch.optim.Adam(learner.parameters(), lr=1e-4)

def sample_batch_images():
    # placeholder for a real data loader: a batch of 10 random 256 x 256 RGB images
    return torch.randn(10, 3, 256, 256).cuda()

for _ in tqdm(range(100000)):
    images = sample_batch_images()
    loss = learner(images)  # if there are no positive pixel pairs, the loss reduces to the instance-level loss
    opt.zero_grad()
    loss.backward()
    print(loss.item())
    opt.step()
    learner.update_moving_average()  # update the moving average of the target encoder

# after sufficient training, save the improved model for use on downstream tasks
torch.save(resnet, 'improved-resnet.pt')
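`learner.update_moving_average()` keeps the target encoder as an exponential moving average of the online encoder, with `moving_average_decay` as the decay rate τ: θ_target ← τ · θ_target + (1 - τ) · θ_online. The minimal sketch below shows that update on plain lists of floats; it is not the library's code, and the function name `ema_update` is hypothetical.

```python
from typing import List


def ema_update(target: List[float], online: List[float], decay: float = 0.99) -> List[float]:
    """Return the new target parameters: decay * target + (1 - decay) * online."""
    if len(target) != len(online):
        raise ValueError("parameter lists must have the same length")
    return [decay * t + (1.0 - decay) * o for t, o in zip(target, online)]


if __name__ == "__main__":
    target = [0.0, 0.0]
    online = [1.0, -1.0]
    for _ in range(3):
        target = ema_update(target, online, decay=0.99)
    print(target)  # the target drifts slowly toward the online weights
```

With a decay of 0.99, each call moves the target only 1% of the way toward the online weights, which is why the update is called once per optimizer step. In PyTorch the same rule would be applied in place to each pair of parameters under `torch.no_grad()`.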
You can also return the number of positive pixel pairs from the forward pass, for logging or other purposes.
loss, positive_pairs = learner(images, return_positive_pairs = True)
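The positive pair count depends on how pixels in the two augmented views are matched. The paper warps each feature-map pixel back to the original image, measures the distance between pixels from the two views, normalizes it by the diagonal length of a feature-map bin, and treats pairs closer than `distance_thres` (0.7) as positive. The plain-Python sketch below illustrates that rule; it is not the library's code. The `(left, top, width, height)` crop representation and the names `cell_centers` and `positive_pixel_pairs` are hypothetical, and normalizing by the first view's bin diagonal is an assumption.

```python
import math
from typing import List, Tuple

Crop = Tuple[float, float, float, float]  # (left, top, width, height) in original-image coordinates


def cell_centers(crop: Crop, size: int) -> List[Tuple[float, float]]:
    """Centers of a size x size feature-map grid, mapped back into original-image coordinates."""
    left, top, width, height = crop
    return [
        (left + (c + 0.5) * width / size, top + (r + 0.5) * height / size)
        for r in range(size)
        for c in range(size)
    ]


def positive_pixel_pairs(
    crop_1: Crop, crop_2: Crop, size: int = 8, distance_thres: float = 0.7
) -> List[Tuple[int, int]]:
    """Index pairs (i in view 1, j in view 2) whose normalized distance is below distance_thres."""
    centers_1 = cell_centers(crop_1, size)
    centers_2 = cell_centers(crop_2, size)
    _, _, w1, h1 = crop_1
    bin_diagonal = math.hypot(w1 / size, h1 / size)  # assumption: normalize by view 1's bin diagonal
    pairs = []
    for i, (x1, y1) in enumerate(centers_1):
        for j, (x2, y2) in enumerate(centers_2):
            if math.hypot(x1 - x2, y1 - y2) / bin_diagonal < distance_thres:
                pairs.append((i, j))
    return pairs


if __name__ == "__main__":
    full = (0.0, 0.0, 256.0, 256.0)
    print(len(positive_pixel_pairs(full, full)))  # 64: each pixel matches only itself
    print(positive_pixel_pairs((0.0, 0.0, 100.0, 100.0), (150.0, 150.0, 100.0, 100.0)))  # []
```

With identical square crops, horizontally or vertically adjacent cells are 1/√2 ≈ 0.707 bin diagonals apart, so a threshold of 0.7 pairs each pixel only with itself (64 pairs on an 8 x 8 map). Crops that do not overlap produce no positive pairs, which is the case where the loss falls back to the instance-level term.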
@misc{xie2020propagate,
    title={Propagate Yourself: Exploring Pixel-Level Consistency for Unsupervised Visual Representation Learning},
    author={Zhenda Xie and Yutong Lin and Zheng Zhang and Yue Cao and Stephen Lin and Han Hu},
    year={2020},
    eprint={2011.10043},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}