Normalizing flow-based generative models are widely used in applications where exact density estimation is of major importance. Recent research proposes numerous methods to improve their expressivity. However, conditioning on a context is a largely overlooked area in bijective flow research. Conventional conditioning via vector concatenation is limited to only a few flow types. More importantly, this approach cannot support a practical setup where a set of context-conditioned (specialist) models is trained on top of a fixed pretrained general-knowledge (generalist) model.
We propose the ContextFlow++ approach to overcome these limitations using additive conditioning with explicit generalist-specialist knowledge decoupling. Furthermore, we support discrete contexts with the proposed mixed-variable architecture with context encoders. In particular, our context encoder for discrete variables is a surjective flow from which the context-conditioned continuous variables are sampled. Our experiments on the rotated MNIST-R, corrupted CIFAR-10C, real-world ATM predictive maintenance, and SMAP unsupervised anomaly detection benchmarks show that the proposed ContextFlow++ offers faster, more stable training and achieves higher performance metrics.
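The additive-conditioning idea can be illustrated with a minimal, self-contained sketch (plain NumPy; names such as `ctx_table`, `coupling_forward`, and `coupling_inverse` are illustrative and not the repository's actual API): a frozen generalist coupling transform is augmented with a trainable, context-indexed additive shift, so the specialist reuses generalist knowledge while the map stays exactly invertible.

```python
import numpy as np

rng = np.random.default_rng(0)
D, C = 4, 3  # feature dimension, number of discrete contexts

# Frozen generalist parameters (pretrained, not updated by the specialist).
W_gen = rng.standard_normal((D // 2, D // 2)) * 0.1
b_gen = np.zeros(D // 2)

# Specialist parameters: one additive shift per context.
# Initialized to zero, so the specialist starts as an identity correction.
ctx_table = np.zeros((C, D // 2))

def coupling_forward(x, ctx_id):
    """Additive coupling: split x, shift the second half.

    The shift depends only on the untransformed half x1 and the context,
    so the transform is exactly invertible and log|det J| = 0.
    """
    x1, x2 = x[: D // 2], x[D // 2 :]
    shift = x1 @ W_gen + b_gen          # frozen generalist shift
    shift = shift + ctx_table[ctx_id]   # trainable specialist context term
    return np.concatenate([x1, x2 + shift])

def coupling_inverse(y, ctx_id):
    y1, y2 = y[: D // 2], y[D // 2 :]
    shift = y1 @ W_gen + b_gen
    shift = shift + ctx_table[ctx_id]
    return np.concatenate([y1, y2 - shift])

x = rng.standard_normal(D)
y = coupling_forward(x, ctx_id=1)
x_rec = coupling_inverse(y, ctx_id=1)
assert np.allclose(x, x_rec)  # exact invertibility is preserved
```

Because the specialist term is purely additive and the generalist weights stay frozen, context-conditioned models can be trained independently on top of one shared pretrained generalist, which is the setup the concatenation-based conditioning cannot support.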
## Citation
If you like our UAI24 paper or code, please cite it using the following BibTeX entry:
```bibtex
@inproceedings{contextflow,
  title={ContextFlow++: Generalist-Specialist Flow-based Generative Models with Mixed-variable Context Encoding},
  author={Denis A Gudovskiy and Tomoyuki Okuno and Yohei Nakata},
  booktitle={The 40th Conference on Uncertainty in Artificial Intelligence (UAI)},
  year={2024},
  url={https://openreview.net/forum?id=06nlLSkuuu}
}
```
## Installation
Install all packages with these commands:
```shell
conda create -n contextflow python=3.8 -y
conda activate contextflow
conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia
python -m pip install -U -r requirements.txt
# symlink your datasets directory
ln -s ~/PATHTO/data data
cd contextflow
```
## Training and Evaluation
Run the MNIST experiments:
```shell
# generalist training
python model.py --gpu 0 --dataset mnist --coupling conv --action-type train-generalist --fold 0 --clean
python model.py --gpu 0 --dataset mnist --coupling conv --action-type train-generalist --fold 0 --save-checkpoint
# specialist training: baseline vs. ContextFlow++
python model.py --gpu 0 --dataset mnist --coupling conv --action-type train-specialist --fold 0 --enc-emb eye --enc-type uniform
python model.py --gpu 0 --dataset mnist --coupling conv --action-type train-specialist --fold 0 --enc-emb eye --enc-type uniform --contextflow
```
Run the CIFAR-10 experiments:
```shell
python model.py --gpu 0 --dataset cifar10 --coupling conv --action-type train-generalist --fold 0 --clean
python model.py --gpu 0 --dataset cifar10 --coupling conv --action-type train-generalist --fold 0 --save-checkpoint
python model.py --gpu 0 --dataset cifar10 --coupling conv --action-type train-specialist --fold 0 --enc-emb onehot --enc-type vardeq
```
Run the ATM experiments:
```shell
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-generalist --fold 0 --save-checkpoint
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb onehot --enc-type uniform --contextflow
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb eye --enc-type uniform --contextflow
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb onehot --enc-type vardeq --contextflow
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb eye --enc-type argmax --contextflow
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb embed --enc-type eyesample --contextflow
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb embed --enc-type probsample --contextflow
```
Run the generalist training for the SMAP, MSL, and SMD anomaly detection benchmarks:
```shell
python model.py --gpu 0 --dataset smap --coupling trans --action-type train-generalist
python model.py --gpu 0 --dataset msl --coupling trans --action-type train-generalist
python model.py --gpu 0 --dataset smd --coupling trans --action-type train-generalist
```
## ContextFlow++ Architecture
![ContextFlow++](./images/fig-arch.svg)
## Reference Results for the ATM and SMAP Datasets
![Results](./images/fig-tables.svg)