mlbio-epfl / turtle

[ICML 2024] Let Go of Your Labels with Unsupervised Transfer
https://brbiclab.epfl.ch/projects/turtle/

Let Go of Your Labels with Unsupervised Transfer

Artyom Gadetsky*, Yulun Jiang*, Maria Brbić



This repo contains the source code of 🐢 TURTLE, an unsupervised learning algorithm written in PyTorch. 🔥 TURTLE achieves state-of-the-art unsupervised performance on a variety of benchmark datasets. For more details, please see our paper Let Go of Your Labels with Unsupervised Transfer (ICML '24).


The question we aim to answer in our work is how to use representations from foundation models to solve a new task in a fully unsupervised manner. We introduce the problem setting of unsupervised transfer and highlight the key differences between unsupervised transfer and other types of transfer. Specifically, types of downstream transfer differ in the amount of supervision available. Given the representation spaces of foundation models, (i) supervised transfer, represented by a linear probe, trains a linear classifier on labeled examples from a downstream dataset; (ii) zero-shot transfer assumes that descriptions of the visual categories in a downstream dataset are given and uses them via a text encoder to solve the task; and (iii) unsupervised transfer assumes the least supervision, i.e., only the number of categories is given, and aims to uncover the underlying human labeling of a dataset.



TURTLE is a method that enables fully unsupervised transfer from foundation models. The key idea behind our approach is to search for the labeling of a downstream dataset that maximizes the margins of linear classifiers in the space of one or more foundation models, thereby uncovering the underlying human labeling. Compared to zero-shot and supervised transfer, unsupervised transfer with TURTLE does not need supervision in any form beyond the number of categories. Compared to deep clustering methods, TURTLE does not require task-specific representation learning, which is expensive for modern foundation models.
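The margin criterion can be illustrated with a toy example. The sketch below is not the TURTLE implementation, which relies on bilevel optimization over labelings of real representation spaces; it brute-forces every binary labeling of six 1D points and keeps the labeling whose best threshold classifier has the largest margin. The helpers `max_margin_1d` and `best_labeling` are hypothetical, and discarding labelings with an empty class is this sketch's own assumption for ruling out the trivial solution.

```python
from itertools import product


def max_margin_1d(xs, labels):
    """Hard margin of the best 1D threshold classifier for a binary labeling.

    Returns None if the labeling is degenerate (one class empty) or not
    linearly separable. (Excluding empty classes is an assumption of this toy.)
    """
    neg = [x for x, y in zip(xs, labels) if y == 0]
    pos = [x for x, y in zip(xs, labels) if y == 1]
    if not neg or not pos:
        return None  # an empty class would give an unbounded margin
    # In 1D a linear classifier is a threshold; allow either orientation.
    gap = max(min(pos) - max(neg), min(neg) - max(pos))
    return gap / 2 if gap > 0 else None


def best_labeling(xs):
    """Exhaustively search binary labelings for the one with the largest margin."""
    best, best_margin = None, float("-inf")
    for labels in product((0, 1), repeat=len(xs)):
        if labels[0] == 1:
            continue  # skip label-swapped duplicates
        margin = max_margin_1d(xs, labels)
        if margin is not None and margin > best_margin:
            best, best_margin = labels, margin
    return best, best_margin


points = [0.0, 0.1, 0.2, 1.0, 1.1, 1.2]
labels, margin = best_labeling(points)
print(labels, margin)  # (0, 0, 0, 1, 1, 1) 0.4
```

Exhaustive search grows as 2^n and only works for toy inputs; it is meant to show what "maximizing the margin over labelings" means, not how to scale it.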

Dependencies

The code is built with the following libraries:

To install cuml, you can follow the instructions on this page.

Quick Start

In our paper, we consider the 26 vision datasets studied by Radford et al. (2021) and 9 different foundation models. As a running example, we present the full pipeline for training TURTLE on the CIFAR-100 dataset.

  1. Precompute representations and save ground-truth labels for the dataset:

    python precompute_representations.py --dataset cifar100 --phis clipvitL14
    python precompute_representations.py --dataset cifar100 --phis dinov2
    python precompute_labels.py --dataset cifar100
  2. Train TURTLE with two representation spaces:

    python run_turtle.py --dataset cifar100 --phis clipvitL14 dinov2

    or with a single representation space:

    python run_turtle.py --dataset cifar100 --phis clipvitL14
    python run_turtle.py --dataset cifar100 --phis dinov2
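If you want to script the Quick Start steps, a small helper can build the same command lines. `quick_start_commands` and `run_all` are hypothetical names; the helper uses only the scripts and flags shown in the steps above and assumes it is run from the repository root.

```python
import subprocess


def quick_start_commands(dataset, phis):
    """Build the Quick Start command lines for one dataset (hypothetical helper).

    One precompute_representations.py call per representation space, one
    precompute_labels.py call, then a single run_turtle.py call over all spaces.
    """
    cmds = [
        ["python", "precompute_representations.py", "--dataset", dataset, "--phis", phi]
        for phi in phis
    ]
    cmds.append(["python", "precompute_labels.py", "--dataset", dataset])
    cmds.append(["python", "run_turtle.py", "--dataset", dataset, "--phis", *phis])
    return cmds


def run_all(cmds):
    """Run each command in order; stop on the first non-zero exit code."""
    for cmd in cmds:
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Print the commands only; call run_all(...) to actually execute them.
    for cmd in quick_start_commands("cifar100", ["clipvitL14", "dinov2"]):
        print(" ".join(cmd))
```

For example, `run_all(quick_start_commands("cifar100", ["clipvitL14", "dinov2"]))` would run the two-space pipeline; pass a single-element list for one representation space.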

Results and checkpoints are saved to ./data/results and ./data/task_checkpoints, respectively. All scripts also accept --root_dir to specify a root directory other than the default ./data.

Data Preparation

Most datasets are downloaded automatically when running precompute_representations.py and precompute_labels.py. However, some datasets must be downloaded manually. Please see dataset_preparation/data_utils.py for a guide to preparing all the datasets used in our paper.

For example, to prepare the Pets dataset, which is not directly available in torchvision.datasets, run:

python dataset_preparation/prepare_pets.py -i ./data/datasets/pets -o ./data/datasets/pets -d

to download and extract the dataset at ./data/datasets/pets.

After downloading the dataset, run the following command to precompute the representations and labels:

python precompute_representations.py --dataset ${DATASET} --phis ${REPRESENTATION}
python precompute_labels.py --dataset ${DATASET}

Datasets and representations covered in this repo:

Running TURTLE

Once the representations and labels are precomputed, to train TURTLE with a single space, run:

python run_turtle.py --dataset ${DATASET} --phis ${REPRESENTATION} 

or, to train TURTLE with multiple representation spaces, run:

python run_turtle.py --dataset ${DATASET} --phis ${REPRESENTATION1} ${REPRESENTATION2}

You can also use --inner_lr, --outer_lr, and --warm_start to specify the inner step size, the outer step size, and whether to use cold-start or warm-start bilevel optimization. Furthermore, use --cross_val to compute the generalization score of the found labeling after training. You can perform a hyperparameter sweep and use the generalization score to select the best hyperparameters without using ground-truth labels.
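A label-free hyperparameter sweep could be scripted along these lines. This is a sketch under assumptions: `sweep_commands` and `select_best` are hypothetical helpers, the learning-rate grid is illustrative rather than the values used in the paper, and this README does not specify how the generalization score is reported, so collecting the scores is left to you. `--warm_start` is omitted because its expected argument format is not described here.

```python
from itertools import product


def sweep_commands(dataset, phis, inner_lrs, outer_lrs):
    """One run_turtle.py call per (inner_lr, outer_lr) pair, each with --cross_val."""
    return [
        ["python", "run_turtle.py", "--dataset", dataset, "--phis", *phis,
         "--inner_lr", str(inner_lr), "--outer_lr", str(outer_lr), "--cross_val"]
        for inner_lr, outer_lr in product(inner_lrs, outer_lrs)
    ]


def select_best(scores):
    """Return the config key with the highest generalization score.

    `scores` maps (inner_lr, outer_lr) -> score; no ground-truth labels are used.
    """
    return max(scores, key=scores.get)


# Illustrative grid only; these are not the values used in the paper.
for cmd in sweep_commands("cifar100", ["clipvitL14"], [1e-3, 1e-2], [1e-3, 1e-2]):
    print(" ".join(cmd))
```

After the runs finish, gather each configuration's generalization score into a dict such as `{(inner_lr, outer_lr): score}` and pass it to `select_best`.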

Pre-trained Checkpoints

We also release the labelings found by TURTLE for all datasets and all model architectures used in our paper. To download pre-trained checkpoints, run:

wget https://brbiclab.epfl.ch/wp-content/uploads/2024/06/turtle_tasks.zip
unzip turtle_tasks.zip

Then, you can evaluate a pre-trained TURTLE checkpoint with a single representation space by running:

python evaluate.py --dataset cifar100 --phis clipvitL14 --task_ckpt {PATH_TO_TURTLE_TASKS}/1space/clipvitL14/cifar100.pt
python evaluate.py --dataset cifar100 --phis dinov2     --task_ckpt {PATH_TO_TURTLE_TASKS}/1space/dinov2/cifar100.pt

or evaluate with two representation spaces:

python evaluate.py --dataset cifar100 --phis clipvitL14 dinov2 --task_ckpt {PATH_TO_TURTLE_TASKS}/2space/clipvitL14_dinov2/cifar100.pt
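The checkpoint paths in these commands follow a regular layout, `{PATH_TO_TURTLE_TASKS}/<k>space/<representations joined by _>/<dataset>.pt`. A hypothetical helper that builds them is sketched below; it assumes the archive unpacks into that layout, uses `turtle_tasks` only as a placeholder for `{PATH_TO_TURTLE_TASKS}`, and assumes the same naming extends beyond the one- and two-space cases shown.

```python
from pathlib import Path


def task_ckpt_path(tasks_dir, phis, dataset):
    """Build <tasks_dir>/<k>space/<phi1>_<phi2>.../<dataset>.pt for k = len(phis)."""
    return Path(tasks_dir) / f"{len(phis)}space" / "_".join(phis) / f"{dataset}.pt"


def evaluate_command(tasks_dir, phis, dataset):
    """Command line for evaluate.py with the matching released labeling (hypothetical helper)."""
    return [
        "python", "evaluate.py",
        "--dataset", dataset,
        "--phis", *phis,
        "--task_ckpt", str(task_ckpt_path(tasks_dir, phis, dataset)),
    ]


# "turtle_tasks" is a placeholder for {PATH_TO_TURTLE_TASKS}.
print(" ".join(evaluate_command("turtle_tasks", ["clipvitL14", "dinov2"], "cifar100")))
```

The order of the representations matters: it must match the directory name (for example `clipvitL14_dinov2`).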

Baselines

We also provide implementations of the Zero-shot Transfer with CLIP, Linear Probe, and K-Means baselines in the baselines folder. For the Linear Probe and K-Means baselines, we use cuML for efficient CUDA implementations.

Linear Probe

Precompute the representations and then perform linear probe evaluation by running:

python baselines/linear_probe.py --dataset ${DATASET} --phis ${REPRESENTATION}

To select the L2 regularization strength for better performance, run:

python baselines/linear_probe.py --dataset ${DATASET} --phis ${REPRESENTATION} --validation
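The linear probe baseline relies on cuML. To illustrate what selecting the L2 regularization strength on held-out data involves, the sketch below swaps in a closed-form ridge classifier (least squares onto one-hot targets) in NumPy instead of whatever solver `baselines/linear_probe.py` uses. The data are synthetic Gaussian blobs, and the helper names, split, and grid are assumptions of this sketch, not what `--validation` does.

```python
import numpy as np


def _with_bias(X):
    # Append a constant column so the classifier has an intercept.
    return np.hstack([X, np.ones((X.shape[0], 1))])


def fit_ridge_classifier(X, y, n_classes, l2):
    """Closed-form ridge regression onto one-hot targets; the bias is not penalized."""
    Xb = _with_bias(X)
    Y = np.eye(n_classes)[y]
    reg = l2 * np.eye(Xb.shape[1])
    reg[-1, -1] = 0.0
    return np.linalg.solve(Xb.T @ Xb + reg, Xb.T @ Y)


def predict(W, X):
    # Predicted class = largest regression score.
    return np.argmax(_with_bias(X) @ W, axis=1)


def select_l2(X_tr, y_tr, X_val, y_val, n_classes, grid):
    """Fit one probe per L2 strength and keep the one with the best validation accuracy."""
    accs = {}
    for l2 in grid:
        W = fit_ridge_classifier(X_tr, y_tr, n_classes, l2)
        accs[l2] = float(np.mean(predict(W, X_val) == y_val))
    return max(accs, key=accs.get), accs


# Synthetic stand-in for precomputed representations: two Gaussian blobs.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-2.0, 1.0, (100, 5)), rng.normal(2.0, 1.0, (100, 5))])
y = np.repeat([0, 1], 100)
idx = rng.permutation(len(X))
train, val = idx[:150], idx[150:]
best_l2, accs = select_l2(X[train], y[train], X[val], y[val], 2, [1e-3, 1e-1, 10.0])
print(best_l2, accs)
```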

K-Means

Precompute the representations and run K-Means baseline:

python baselines/kmeans.py --dataset ${DATASET} --phis ${REPRESENTATION}
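The K-Means baseline uses cuML. For illustration only, here is a plain NumPy version of Lloyd's algorithm with random initialization, together with clustering accuracy under the best one-to-one cluster-to-class mapping (the matching that the Hungarian algorithm solves), a common way to score unsupervised labelings against ground truth. Whether `baselines/kmeans.py` or `evaluate.py` compute accuracy exactly this way is an assumption. The brute-force matching is only practical for small k; for many classes, `scipy.optimize.linear_sum_assignment` solves the same matching efficiently.

```python
from itertools import permutations

import numpy as np


def kmeans(X, k, n_iter=100, seed=0):
    """Lloyd's algorithm with random initialization; returns a cluster index per row."""
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Assign each point to its nearest centroid (squared Euclidean distance).
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        assign = dists.argmin(axis=1)
        # Move each centroid to the mean of its points; keep it if the cluster is empty.
        new_centroids = np.array([
            X[assign == j].mean(axis=0) if np.any(assign == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return assign


def cluster_accuracy(y_true, y_pred, k):
    """Accuracy under the best one-to-one mapping from clusters to classes.

    Brute force over all k! mappings, so only suitable for small k.
    """
    best = 0.0
    for perm in permutations(range(k)):
        mapped = np.asarray(perm)[y_pred]  # relabel cluster c as class perm[c]
        best = max(best, float(np.mean(mapped == y_true)))
    return best


# Toy data: three well-separated blobs standing in for precomputed representations.
rng = np.random.default_rng(1)
centers = np.array([[0.0, 0.0], [6.0, 0.0], [0.0, 6.0]])
X = np.vstack([c + rng.normal(size=(50, 2)) for c in centers])
y = np.repeat(np.arange(3), 50)
print(cluster_accuracy(y, kmeans(X, k=3), k=3))
```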

Zero-shot Transfer

Run CLIP zero-shot transfer:

python baselines/clip_zs.py --dataset ${DATASET} --phis ${REPRESENTATION}
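CLIP zero-shot transfer (Radford et al. 2021) classifies an image by comparing its embedding with text embeddings of class descriptions. The sketch below shows only that final step, assuming you already have image and text embeddings as NumPy arrays; `zero_shot_predict` is a hypothetical name, the 2D vectors are toy stand-ins for real CLIP features, and how `baselines/clip_zs.py` builds its prompts is not covered here.

```python
import numpy as np


def zero_shot_predict(image_emb, text_emb):
    """Pick, for each image, the class whose text embedding has the highest cosine similarity.

    image_emb: (n_images, d) image-encoder outputs.
    text_emb:  (n_classes, d) text-encoder outputs, one per class description.
    """
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    return np.argmax(img @ txt.T, axis=1)


# Toy 2D stand-ins for real CLIP embeddings.
text_emb = np.array([[1.0, 0.0], [0.0, 1.0]])
image_emb = np.array([[0.9, 0.1], [0.2, 0.8], [3.0, 1.0]])
print(zero_shot_predict(image_emb, text_emb))  # [0 1 0]
```

Normalizing both sides makes the dot product a cosine similarity, so the prediction does not depend on embedding magnitude (the third toy image has a large norm but is still assigned to class 0).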

Acknowledgements

While developing TURTLE, we greatly benefited from the following open-source repositories:

Citing

If you find our code useful, please consider citing:

@inproceedings{
    gadetsky2024let,
    title={Let Go of Your Labels with Unsupervised Transfer},
    author={Gadetsky, Artyom and Jiang, Yulun and Brbi\'c, Maria},
    booktitle={International Conference on Machine Learning},
    year={2024},
}