This is a multi-GPU PyTorch implementation of the paper SODA: Bottleneck Diffusion Models for Representation Learning:
```bibtex
@article{hudson2023soda,
  title={SODA: Bottleneck Diffusion Models for Representation Learning},
  author={Hudson, Drew A and Zoran, Daniel and Malinowski, Mateusz and Lampinen, Andrew K and Jaegle, Andrew and McClelland, James L and Matthey, Loic and Hill, Felix and Lerchner, Alexander},
  journal={arXiv preprint arXiv:2311.17901},
  year={2023}
}
```
:exclamation: Please refer to https://github.com/dorarad/soda for the authors' official repository.
:exclamation: Note that this implementation focuses only on linear-probe classification performance and pays little attention to the other, generative downstream tasks. It could nevertheless be a good starting point for further development. If you are also interested in diffusion-based classification, please check out the DDAE repo, which is the "unconditional" baseline in the SODA paper.
:exclamation: This repo only contains configs and experiments for small- and medium-scale datasets such as CIFAR-10/100 and Tiny-ImageNet. A full re-implementation on ImageNet-1k would be extremely expensive.
In addition to a working PyTorch environment, please install:

```bash
conda install pyyaml
pip install ema-pytorch tensorboard
```
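To confirm that the extra dependencies are visible from the active environment, a quick check like the one below may help. It is a minimal sketch and not part of this repo; the mapping from install names to import names (`pyyaml` to `yaml`, `ema-pytorch` to `ema_pytorch`) is our assumption based on how these packages are usually imported.

```python
import importlib.util

# Assumed mapping: package name used with pip/conda -> module name used in `import`.
REQUIRED = {
    "torch": "torch",
    "pyyaml": "yaml",
    "ema-pytorch": "ema_pytorch",
    "tensorboard": "tensorboard",
}


def missing_packages(required=REQUIRED):
    """Return the install names whose top-level module cannot be found."""
    return [pkg for pkg, module in required.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All dependencies found.")
```

`find_spec` only locates the module without importing it, so the check stays fast even for heavy packages such as `torch`.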
| Model | Dataset | Resolution | Epochs | #Params (M, encoder+decoder) | K-NN acc (%) | Linear probe acc (%) |
|---|---|---|---|---|---|---|
| Res18+DDPM | CIFAR-10 | 32x32 | 800 | 11+40 | 80.4 | 80.0 |
| Res18+DDPM | CIFAR-100 | 32x32 | 800 | 11+40 | 51.4 | 54.9 |
| Res18+DDPM | Tiny-ImageNet | 64x64 | 800 | 11+40 | 34.8 | 38.2 |
To train SODA (ResNet-18 encoder + DDPM decoder) for CIFAR-10/100 classification on 4 GPUs:

```bash
python -m torch.distributed.launch --nproc_per_node=4 train.py --config config/cifar10.yaml --use_amp
python -m torch.distributed.launch --nproc_per_node=4 train.py --config config/cifar100.yaml --use_amp
```
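In the `resnet18 + DDPM` pairing, the DDPM part is trained to predict the noise that was added to an image. The pure-Python sketch below shows the standard DDPM forward process (Ho et al., 2020), $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, and its simplified noise-prediction loss. It assumes the common linear beta schedule from 1e-4 to 0.02 over T = 1000 steps; we have not checked which schedule the configs in this repo use. In SODA the noise predictor is additionally conditioned on the encoder's latent code, which this sketch leaves out.

```python
import math
import random


def linear_alpha_bars(T=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule (assumed defaults)."""
    alpha_bars = []
    prod = 1.0
    for t in range(T):
        beta = beta_start + (beta_end - beta_start) * t / (T - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def q_sample(x0, t, alpha_bars, rng=random):
    """Draw x_t ~ q(x_t | x_0) and return (x_t, eps), where eps is the noise used."""
    ab = alpha_bars[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(ab) * x + math.sqrt(1.0 - ab) * e for x, e in zip(x0, eps)]
    return xt, eps


def eps_mse(eps_pred, eps):
    """Simplified DDPM objective: mean squared error between predicted and true noise."""
    return sum((p - e) ** 2 for p, e in zip(eps_pred, eps)) / len(eps)
```

During training, a network would receive `xt` and `t` (and, in SODA, the encoder latent) and be optimised so that its output minimises `eps_mse` against `eps`.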
To train SODA (ResNet-18 encoder + DDPM decoder) for Tiny-ImageNet classification, use more GPUs (8 in this example):

```bash
python -m torch.distributed.launch --nproc_per_node=8 train.py --config config/tiny.yaml --use_amp
```
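The training commands rely on `torch.distributed.launch`, which starts one process per GPU and tells each process its local rank. Depending on the PyTorch version, that rank arrives as an extra `--local_rank` or `--local-rank` argument, while `torchrun` exports a `LOCAL_RANK` environment variable instead. The parser below is hypothetical (we have not inspected how `train.py` handles its arguments); it only illustrates how the `--config` and `--use_amp` flags from this README could coexist with any of those rank conventions.

```python
import argparse
import os


def parse_args(argv=None):
    """Hypothetical argument parser for a DDP training script."""
    parser = argparse.ArgumentParser(description="SODA training (illustrative sketch)")
    parser.add_argument("--config", type=str, required=True,
                        help="path to a YAML config, e.g. config/cifar10.yaml")
    parser.add_argument("--use_amp", action="store_true",
                        help="enable automatic mixed precision")
    # Accept both spellings the launcher may pass; fall back to the LOCAL_RANK
    # environment variable (set by torchrun), then to 0 for single-process runs.
    parser.add_argument("--local_rank", "--local-rank", type=int,
                        default=int(os.environ.get("LOCAL_RANK", 0)))
    return parser.parse_args(argv)
```

Each process would then typically select its GPU with the parsed `local_rank` before initialising the process group.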
During training, the SODA encoder is evaluated with a K-NN classifier or linear probing every 100 epochs. Strictly speaking, checkpoints should not be evaluated on the validation set in this way; it is done here only to monitor training and to gain a better understanding of SODA.
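A K-NN evaluation classifies each validation image by looking up the most similar training images in the frozen encoder's feature space. The pure-Python sketch below uses cosine similarity and a plain majority vote over the top-k neighbours; many self-supervised codebases use a temperature-weighted vote instead, and we have not checked which variant or which `k` this repo uses.

```python
import math
from collections import Counter


def l2_normalize(v):
    """Scale a feature vector to unit L2 norm (zero vectors are left unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def knn_predict(train_feats, train_labels, query, k=20):
    """Majority vote over the k most cosine-similar training features.

    train_feats must already be L2-normalized, so a dot product equals cosine similarity.
    """
    q = l2_normalize(query)
    scored = sorted(
        ((sum(a * b for a, b in zip(q, f)), y) for f, y in zip(train_feats, train_labels)),
        key=lambda s: s[0],
        reverse=True,
    )
    # Counter.most_common keeps first-seen order on ties, so a tie is won by
    # the class whose nearest neighbour is closest to the query.
    votes = Counter(y for _, y in scored[:k])
    return votes.most_common(1)[0][0]


def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=20):
    """Top-1 accuracy of the K-NN classifier on frozen encoder features."""
    bank = [l2_normalize(f) for f in train_feats]
    correct = sum(
        knn_predict(bank, train_labels, q, k) == y for q, y in zip(test_feats, test_labels)
    )
    return correct / len(test_labels)
```

Because no parameters are trained, K-NN accuracy is cheap to compute at every evaluation step, which makes it a convenient signal for monitoring training.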
To evaluate the final checkpoint after training, run:

```bash
python test.py --config config/cifar10.yaml --use_amp
python test.py --config config/cifar100.yaml --use_amp
python test.py --config config/tiny.yaml --use_amp
```
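Linear probing, the metric this implementation focuses on, trains only a linear classifier on top of frozen encoder features. Assuming the evaluation follows that standard protocol (we have not checked the optimiser, learning rate, or number of epochs used by `test.py`), the idea can be sketched in pure Python as full-batch softmax regression. A real probe would run mini-batch SGD on GPU tensors; this version only illustrates the computation.

```python
import math


def train_linear_probe(feats, labels, num_classes, lr=0.5, epochs=200):
    """Fit softmax regression on frozen features with full-batch gradient descent."""
    dim, n = len(feats[0]), len(feats)
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    for _ in range(epochs):
        gW = [[0.0] * dim for _ in range(num_classes)]
        gb = [0.0] * num_classes
        for x, y in zip(feats, labels):
            logits = [sum(w * xi for w, xi in zip(W[c], x)) + b[c] for c in range(num_classes)]
            m = max(logits)  # subtract the max for a numerically stable softmax
            exps = [math.exp(z - m) for z in logits]
            s = sum(exps)
            for c in range(num_classes):
                # Gradient of cross-entropy w.r.t. logit c: softmax_c - 1[c == y].
                g = exps[c] / s - (1.0 if c == y else 0.0)
                gb[c] += g / n
                for j in range(dim):
                    gW[c][j] += g * x[j] / n
        for c in range(num_classes):
            b[c] -= lr * gb[c]
            for j in range(dim):
                W[c][j] -= lr * gW[c][j]
    return W, b


def predict(W, b, x):
    """Return the class with the highest linear score."""
    logits = [sum(w * xi for w, xi in zip(row, x)) + bc for row, bc in zip(W, b)]
    return max(range(len(logits)), key=logits.__getitem__)
```

Only `W` and `b` are learned; the encoder features stay fixed, so the probe accuracy measures how linearly separable the learned representation is.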