FutureXiang / soda


Unofficial Implementation for SODA: Bottleneck Diffusion Models for Representation Learning

This is a multi-GPU PyTorch implementation of the paper SODA: Bottleneck Diffusion Models for Representation Learning:

@article{hudson2023soda,
  title={SODA: Bottleneck Diffusion Models for Representation Learning},
  author={Hudson, Drew A and Zoran, Daniel and Malinowski, Mateusz and Lampinen, Andrew K and Jaegle, Andrew and McClelland, James L and Matthey, Loic and Hill, Felix and Lerchner, Alexander},
  journal={arXiv preprint arXiv:2311.17901},
  year={2023}
}
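For orientation: SODA trains a bottlenecked image encoder by using its output latent z to condition a denoising diffusion decoder, so z must carry the information needed to reconstruct the image. Below is a minimal sketch of that wiring, assuming FiLM-style modulation and a toy conv denoiser in place of the paper's z-modulated UNet; all module names and sizes here are illustrative, not this repo's exact code.

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

class FiLMBlock(nn.Module):
    """Conv block whose normalized activations are scaled/shifted by a conditioning
    vector (FiLM-style); the paper modulates its UNet with z in a similar spirit."""
    def __init__(self, ch, cond_dim):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)
        self.norm = nn.GroupNorm(8, ch)
        self.to_scale_shift = nn.Linear(cond_dim, 2 * ch)

    def forward(self, h, cond):
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=1)
        h = self.norm(self.conv(h))
        return torch.relu(h * (1 + scale[:, :, None, None]) + shift[:, :, None, None])

class SODASketch(nn.Module):
    """Clean view -> encoder -> z; noisy view + timestep -> z-conditioned denoiser."""
    def __init__(self, z_dim=512, ch=64):
        super().__init__()
        self.encoder = resnet18(num_classes=z_dim)   # bottleneck encoder
        self.stem = nn.Conv2d(3, ch, 3, padding=1)
        self.blocks = nn.ModuleList([FiLMBlock(ch, z_dim + 1) for _ in range(4)])
        self.head = nn.Conv2d(ch, 3, 3, padding=1)   # predicts the added noise

    def forward(self, x_clean, x_noisy, t):
        z = self.encoder(x_clean)
        # Fold the timestep into the conditioning vector for brevity;
        # a real DDPM uses a sinusoidal timestep embedding instead.
        cond = torch.cat([z, t[:, None].float()], dim=1)
        h = self.stem(x_noisy)
        for blk in self.blocks:
            h = blk(h, cond)
        return self.head(h)
```

Training then follows the standard DDPM recipe: sample a timestep, noise one view, and regress the noise with the denoiser conditioned on z from the other view.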

:exclamation: Please refer to https://github.com/dorarad/soda for the authors' official repository.

:exclamation: Note that this implementation focuses on linear-probe classification performance and largely ignores other generative downstream tasks. It could nevertheless be a good starting point for further development. If you are also interested in diffusion-based classification, please check out the DDAE repo, which is the "unconditional" baseline in the SODA paper.

:exclamation: This repo only contains configs and experiments for small- and medium-scale datasets such as CIFAR-10/100 and Tiny-ImageNet. A full re-implementation on ImageNet-1k would be extremely expensive.

Requirements

In addition to a standard PyTorch environment, please install:

```
conda install pyyaml
pip install ema-pytorch tensorboard
```
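ema-pytorch supplies the exponential-moving-average weights that diffusion models are usually evaluated with. A minimal usage sketch follows; the model, objective, and hyperparameters below are placeholders, not this repo's settings.

```python
import torch
from ema_pytorch import EMA

model = torch.nn.Linear(8, 8)                  # stand-in for the SODA networks
ema = EMA(model, beta=0.9999, update_every=10)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(100):
    loss = model(torch.randn(4, 8)).square().mean()  # dummy objective
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update()               # keep the shadow weights in sync after each step

eval_model = ema.ema_model     # evaluate / sample with the averaged weights
```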

Issues, different implementations, and TODOs

Issues

Different implementations

TODOs

Main results

| Model      | Dataset       | Resolution | Epochs | #Params (M) | K-NN acc (%) | Linear probe acc (%) |
|------------|---------------|------------|--------|-------------|--------------|----------------------|
| Res18+DDPM | CIFAR-10      | 32x32      | 800    | 11+40       | 80.4         | 80.0                 |
| Res18+DDPM | CIFAR-100     | 32x32      | 800    | 11+40       | 51.4         | 54.9                 |
| Res18+DDPM | Tiny-ImageNet | 64x64      | 800    | 11+40       | 34.8         | 38.2                 |

Usage

Use 4 GPUs to train SODA = resnet18 + DDPM on CIFAR-10/100 classification:

```
python -m torch.distributed.launch --nproc_per_node=4 train.py --config config/cifar10.yaml  --use_amp
python -m torch.distributed.launch --nproc_per_node=4 train.py --config config/cifar100.yaml --use_amp
```

Use more GPUs to train SODA = resnet18 + DDPM on Tiny-ImageNet classification:

```
python -m torch.distributed.launch --nproc_per_node=8 train.py --config config/tiny.yaml --use_amp
```
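A launch like the above implies the usual torch.distributed boilerplate inside train.py; a sketch of that generic pattern is below (it mirrors the standard recipe, not necessarily this repo's exact code).

```python
import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # injected by the launcher
args = parser.parse_args()

dist.init_process_group(backend="nccl")   # one process per GPU
torch.cuda.set_device(args.local_rank)

model = torch.nn.Linear(8, 8).cuda(args.local_rank)   # stand-in for the SODA model
model = DDP(model, device_ids=[args.local_rank])
# ... wrap the dataset in a DistributedSampler and train as usual ...
```

On recent PyTorch versions, torch.distributed.launch is deprecated; `torchrun --nproc_per_node=4 train.py ...` is the equivalent invocation (torchrun passes the rank via the LOCAL_RANK environment variable instead of the --local_rank argument).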

During training, the SODA encoder is evaluated with a K-NN classifier or linear probing every 100 epochs. Strictly speaking, checkpoints should not be evaluated on the validation set like this; here we do it only to observe training and get a better understanding of SODA.
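Both probes operate on frozen encoder features. A minimal sketch of that evaluation, assuming scikit-learn classifiers, is below; the loaders and hyperparameters are placeholders, and the repo's own evaluation code may differ.

```python
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier

@torch.no_grad()
def extract_features(encoder, loader, device="cuda"):
    """Run the frozen encoder over a loader of (image, label) batches."""
    encoder.eval()
    feats, labels = [], []
    for x, y in loader:
        feats.append(encoder(x.to(device)).cpu())
        labels.append(y)
    return torch.cat(feats).numpy(), torch.cat(labels).numpy()

# Assumed to exist already: `encoder`, `train_loader`, `test_loader`.
# f_tr, y_tr = extract_features(encoder, train_loader)
# f_te, y_te = extract_features(encoder, test_loader)
# knn = KNeighborsClassifier(n_neighbors=20).fit(f_tr, y_tr)   # K-NN probe
# lin = LogisticRegression(max_iter=1000).fit(f_tr, y_tr)      # linear probe
# print(knn.score(f_te, y_te), lin.score(f_te, y_te))
```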

To evaluate the final checkpoint after training, run:

```
python test.py --config config/cifar10.yaml  --use_amp
python test.py --config config/cifar100.yaml --use_amp
python test.py --config config/tiny.yaml     --use_amp
```