Accompanying code for Manifold Contrastive Learning with Variational Lie Group Operators by Kion Fallah, Alec Helbling, Kyle A. Johnsen, and Christopher J. Rozell.
The main script of interest is `src/experiment.py`. This takes one of the hydra configs (from the `configs` folder) as input and conducts a contrastive learning experiment. Modules related to the Lie group operators and variational inference are in `src/model/manifold`. The main contrastive logic for ManifoldCLR is in `src/model/contrastive/transop_header.py`.
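For example, to launch a run with one of the SimCLR configs (the full set of commands is given below):

```
python src/experiment.py --config-name simclr_cifar10
```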
* Python 3.9.13
* PyTorch 1.12.1
* matplotlib 3.5.3
* wandb 0.13.5
* scikit-learn 1.1.2
* hydra-core 1.2.0
* lightly 1.2.47
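These can be installed with pip, for example (the PyTorch build may need the wheel appropriate to your CUDA version):

```
pip install torch==1.12.1 matplotlib==3.5.3 wandb==0.13.5 scikit-learn==1.1.2 hydra-core==1.2.0 lightly==1.2.47
```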
All datasets should be placed in a folder named `datasets`. TinyImagenet can be downloaded using this script: https://gist.github.com/moskomule/2e6a9a463f50447beca4e64ab4699ac4. When downloaded, the dataset should be placed in `datasets` with the following structure:
* `datasets`
  * `tiny-imagenet-200`
    * `test`
      * ...
    * `train`
      * ...
    * `val`
      * ...
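If your datasets live elsewhere, the directory can be pointed at from the command line. The override key below is an assumption based on the `cfg.dataloader_cfg.dataset_cfg.dataset_dir` path referenced in the semi-supervised section of this README:

```
python src/experiment.py --config-name simclr_cifar10 \
    ++dataloader_cfg.dataset_cfg.dataset_dir=/path/to/datasets
```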
Set the projection head to `MLP`, `None`, `Linear`, or `Direct` (DirectCLR):
```
python src/experiment.py --config-name simclr_cifar10 \
    ++model_cfg.header_cfg.projection_header_cfg.projection_type=MLP

python src/experiment.py --config-name simclr_stl10 \
    ++model_cfg.header_cfg.projection_header_cfg.projection_type=MLP

python src/experiment.py --config-name simclr_tinyimagenet \
    ++model_cfg.header_cfg.projection_header_cfg.projection_type=MLP
```
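For instance, to run the DirectCLR variant instead of the MLP head, swap the override value (the `None` and `Linear` options follow the same pattern):

```
python src/experiment.py --config-name simclr_cifar10 \
    ++model_cfg.header_cfg.projection_header_cfg.projection_type=Direct
```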
Note that the experiments in the paper used the soft-thresholding-based VI configs (described below).
```
python src/experiment.py --config-name transop_vi_proj_cifar10
python src/experiment.py --config-name transop_vi_proj_stl10
python src/experiment.py --config-name transop_vi_proj_tin
```
```
python src/experiment.py --config-name transop_vi_cifar10
python src/experiment.py --config-name transop_vi_stl10
python src/experiment.py --config-name transop_vi_tin
```
```
python src/experiment.py --config-name transop_vi_cifar10 \
    ++model_cfg.header_cfg.transop_header_cfg.enable_block_diagonal=false \
    ++model_cfg.header_cfg.transop_header_cfg.enable_direct=True \
    ++evaluator_cfg.aug_nn_eval_cfg.enable_runner=false

python src/experiment.py --config-name transop_vi_stl10 \
    ++model_cfg.header_cfg.transop_header_cfg.enable_block_diagonal=false \
    ++model_cfg.header_cfg.transop_header_cfg.enable_direct=True \
    ++evaluator_cfg.aug_nn_eval_cfg.enable_runner=false

python src/experiment.py --config-name transop_vi_tin \
    ++model_cfg.header_cfg.transop_header_cfg.enable_block_diagonal=false \
    ++model_cfg.header_cfg.transop_header_cfg.enable_direct=True \
    ++evaluator_cfg.aug_nn_eval_cfg.enable_runner=false
```
Several configs are included for incorporating soft-thresholding to get machine-precision sparsity in the inferred coefficients for the Lie group operators. These techniques use the methods from *Variational Sparse Coding with Learned Thresholding* by Fallah and Rozell. They rely on "max ELBO sampling" (see the paper text for details), which requires large amounts of GPU VRAM. Furthermore, these methods benefit from L2 regularization on the coefficients to increase stability. Note that the CIFAR10 experiments in the paper used this soft-thresholding strategy with 20 samples. To see a strong benefit with other datasets, it may be necessary to use upwards of 100 samples (perhaps using 4 GPUs).
The configs to use are:

```
transop_vi-thresh_proj_cifar10
transop_vi-thresh_proj_stl10
transop_vi-thresh_proj_tin
```
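These are run in the same way as the other configs, for example:

```
python src/experiment.py --config-name transop_vi-thresh_proj_cifar10
```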
To get PyTorch `nn.DataParallel` to work with `int` data types when using multiple GPUs, you may need the following hot-fix. In `torch/nn/parallel/scatter_gather.py`, in the function `gather`, add the following lines directly under the line `out = outputs[0]`:

```python
if isinstance(out, int):
    return out
```
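If editing the installed PyTorch source is undesirable, a similar effect can be achieved by monkey-patching at the top of the training script. This is only a sketch, assuming the PyTorch 1.12 module layout, and it only covers replica outputs that are themselves plain `int`s (ints nested inside tuples or dicts still need the source edit above):

```python
import torch.nn.parallel.scatter_gather as scatter_gather
import torch.nn.parallel.data_parallel as data_parallel

_orig_gather = scatter_gather.gather

def _gather_allow_int(outputs, target_device, dim=0):
    # DataParallel collects one output per replica; the default gather
    # cannot stack plain Python ints, so pass the first replica's value
    # through unchanged.
    if isinstance(outputs[0], int):
        return outputs[0]
    return _orig_gather(outputs, target_device, dim)

scatter_gather.gather = _gather_allow_int
data_parallel.gather = _gather_allow_int  # data_parallel imports gather at module load
```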
The main script to run the semi-supervised experiments is `src/eval_ssl.py`. The hyperparameters used for this experiment and the path to the model weights are hard-coded at the top of the file. This script freezes the backbone of a contrastive pre-trained model and trains an MLP on top.

To run this script, change `ckpt_path` to point to the pretrained model checkpoint, change `cfg_path` to point to the config file (usually located in the `.hydra` folder with the results), and set `cfg.dataloader_cfg.dataset_cfg.dataset_dir` to the directory of your datasets. Set `feat_aug` to `Transop`, `Featmatch`, `None`, or `Mixup` for the VLGO, FeatMatch, pseudo-labeling, and MMICT augmentations, respectively. For the baseline, set `feat_aug` to `None` and set `con_weight=0`.
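As an illustration, the hard-coded settings at the top of `src/eval_ssl.py` would be edited along these lines (all paths and values below are hypothetical placeholders; the exact variable layout in the file may differ):

```python
# Hypothetical placeholder values for the settings described above.
ckpt_path = "results/transop_vi-thresh_proj_cifar10/model.pt"           # pretrained checkpoint
cfg_path = "results/transop_vi-thresh_proj_cifar10/.hydra/config.yaml"  # config saved with the run

# After the config is loaded, point it at your datasets folder:
cfg.dataloader_cfg.dataset_cfg.dataset_dir = "datasets"

feat_aug = "Transop"  # one of "Transop", "Featmatch", "None", "Mixup"
con_weight = 1.0      # set to 0 (with feat_aug = "None") for the baseline
```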
There is also a fine-tuning script, `src/eval_ssl_finetune.py`, that updates the backbone weights, but this script results in overfitting in most cases.