
Manifold Contrastive Learning

Intro Figure

Accompanying code for Manifold Contrastive Learning with Variational Lie Group Operators by Kion Fallah, Alec Helbling, Kyle A. Johnsen, and Christopher J. Rozell.

Repo Overview

The main script of interest is src/experiment.py. It takes one of the Hydra configs (from the configs folder) as input and runs a contrastive learning experiment. Modules related to the Lie group operators and variational inference are in src/model/manifold. The main contrastive logic for ManifoldCLR is in src/model/contrastive/transop_header.py.

Dependencies

Python          3.9.13
PyTorch         1.12.1
matplotlib      3.5.3
wandb           0.13.5
scikit-learn    1.1.2
hydra-core      1.2.0
lightly         1.2.47

All datasets should be placed in a folder named datasets. TinyImagenet can be downloaded using this script: https://gist.github.com/moskomule/2e6a9a463f50447beca4e64ab4699ac4. When downloaded, the dataset should be placed in datasets with the following structure:

* datasets
    * tiny-imagenet-200
        * test
            ...
        * train
            ...
        * val
            ...
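
As a quick sanity check (not part of the repo), a short Python sketch like the following can confirm the layout, assuming it is run from the repository root:

from pathlib import Path

# Verify that the TinyImagenet splits sit where the experiment code expects them.
dataset_root = Path("datasets/tiny-imagenet-200")
for split in ("train", "val", "test"):
    if not (dataset_root / split).is_dir():
        raise FileNotFoundError(f"Missing split directory: {dataset_root / split}")
print("TinyImagenet layout looks correct.")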

Running Contrastive Experiments

Set the projection head to MLP, None, Linear, or Direct (DirectCLR) with the projection_type override shown in the commands below.

Baselines

CIFAR10

python src/experiment.py --config-name simclr_cifar10 \
    ++model_cfg.header_cfg.projection_header_cfg.projection_type=MLP

STL10

python src/experiment.py --config-name simclr_stl10 \
    ++model_cfg.header_cfg.projection_header_cfg.projection_type=MLP

TinyImagenet

python src/experiment.py --config-name simclr_tinyimagenet \
    ++model_cfg.header_cfg.projection_header_cfg.projection_type=MLP

ManifoldCLR with Proj Head

CIFAR10

Note that the experiments in the paper used the soft-thresholding-based VI config (see ManifoldCLR with Soft-thresholding below).

python src/experiment.py --config-name transop_vi_proj_cifar10

STL10

python src/experiment.py --config-name transop_vi_proj_stl10

TinyImagenet

python src/experiment.py --config-name transop_vi_proj_tin

ManifoldCLR without Proj Head

CIFAR10

python src/experiment.py --config-name transop_vi_cifar10

STL10

python src/experiment.py --config-name transop_vi_stl10

TinyImagenet

python src/experiment.py --config-name transop_vi_tin

ManifoldDirectCLR

CIFAR10

python src/experiment.py --config-name transop_vi_cifar10 \
    ++model_cfg.header_cfg.transop_header_cfg.enable_block_diagonal=false \
    ++model_cfg.header_cfg.transop_header_cfg.enable_direct=True \
    ++evaluator_cfg.aug_nn_eval_cfg.enable_runner=false

STL10

python src/experiment.py --config-name transop_vi_stl10 \
    ++model_cfg.header_cfg.transop_header_cfg.enable_block_diagonal=false \
    ++model_cfg.header_cfg.transop_header_cfg.enable_direct=True \
    ++evaluator_cfg.aug_nn_eval_cfg.enable_runner=false

TinyImagenet

python src/experiment.py --config-name transop_vi_tin \
    ++model_cfg.header_cfg.transop_header_cfg.enable_block_diagonal=false \
    ++model_cfg.header_cfg.transop_header_cfg.enable_direct=True \
    ++evaluator_cfg.aug_nn_eval_cfg.enable_runner=false

ManifoldCLR with Soft-thresholding

Several configs are included that incorporate soft-thresholding to obtain machine-precision sparsity in the inferred coefficients for the Lie group operators. These techniques use the methods from Variational Sparse Coding with Learned Thresholding by Fallah and Rozell. They rely on "max ELBO sampling" (see the text for details), which requires a large amount of GPU VRAM. Furthermore, these methods benefit from L2 regularization on the coefficients to increase stability. Note that the CIFAR10 experiments in the paper used this soft-thresholding strategy with 20 samples. To see a strong benefit with other datasets, it may be necessary to use upwards of 100 samples (perhaps using 4 GPUs). A conceptual sketch of the sampling strategy is given after the config list below.

The configs to use are:

transop_vi-thresh_proj_cifar10
transop_vi-thresh_proj_stl10
transop_vi-thresh_proj_tin
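
For intuition, the following is a minimal, self-contained sketch of what max ELBO sampling with soft-thresholding looks like. It is not the repository's implementation (that lives in src/model/manifold), and it assumes, purely for illustration, a Gaussian reparameterized posterior over the coefficients and a generic per-sample ELBO function:

import torch

def soft_threshold(c, lam):
    # Shrinkage operator: entries with |c| <= lam become exactly zero.
    return torch.sign(c) * torch.clamp(c.abs() - lam, min=0.0)

def max_elbo_sample(mu, log_scale, elbo_fn, lam=0.1, num_samples=20):
    # mu, log_scale: [batch, dict_size] variational posterior parameters.
    # elbo_fn: maps samples [num_samples, batch, dict_size] to ELBO
    #          estimates [num_samples, batch] (higher is better).
    noise = torch.randn(num_samples, *mu.shape, device=mu.device)
    samples = mu + noise * log_scale.exp()     # reparameterized draws
    samples = soft_threshold(samples, lam)     # machine-precision sparsity
    elbo = elbo_fn(samples)                    # [num_samples, batch]
    best = elbo.argmax(dim=0)                  # best draw per batch element
    idx = best.view(1, -1, 1).expand(1, *samples.shape[1:])
    return samples.gather(0, idx).squeeze(0)   # [batch, dict_size]

# Toy usage with a stand-in ELBO that simply prefers small coefficients.
mu = torch.zeros(8, 16)
log_scale = torch.full((8, 16), -1.0)
coeffs = max_elbo_sample(mu, log_scale, lambda c: -(c ** 2).sum(dim=-1))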

To get PyTorch nn.DataParallel to work with int data types when using multiple GPUs, you may need the following hot-fix.

In torch/nn/parallel/scatter_gather.py, in the gather function, add the following lines directly below the line out = outputs[0]:

if isinstance(out, int):
    return out

Running Semi-supervised Experiments

The main script for running the semi-supervised experiments is src/eval_ssl.py. The hyperparameters for this experiment and the path to the model weights are hard-coded at the top of the file. The script freezes the backbone of a contrastive pre-trained model and trains an MLP on top.

To run this script, change ckpt_path to point to the pretrained model checkpoint, change cfg_path to point to the config file (usually located in the .hydra folder saved with the results), and set cfg.dataloader_cfg.dataset_cfg.dataset_dir to the directory containing your datasets. Set feat_aug to Transop, Featmatch, None, or Mixup for the VLGO, FeatMatch, pseudo-labeling, or MMICT augmentations, respectively. For the baseline, set feat_aug to None and set con_weight=0.
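
For reference, the relevant assignments at the top of src/eval_ssl.py might look roughly like the following; the exact layout of the script may differ, and the paths and values here are placeholders to adjust for your own run:

# Illustrative placeholder values only -- the real script may organize these
# settings differently; point the paths at your own pretrained run.
ckpt_path = "results/transop_vi_cifar10/last.ckpt"            # hypothetical checkpoint path
cfg_path = "results/transop_vi_cifar10/.hydra/config.yaml"    # Hydra config saved with that run
feat_aug = "Transop"     # one of "Transop", "Featmatch", "None", "Mixup"
con_weight = 1.0         # set to 0 (with feat_aug = "None") for the baseline

# After the config is loaded, point it at your dataset folder:
# cfg.dataloader_cfg.dataset_cfg.dataset_dir = "datasets"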

A fine-tuning script, src/eval_ssl_finetune.py, is also provided that updates the backbone weights, but it results in overfitting in most cases.