gudovskiy / contextflow

Official PyTorch code for UAI 2024 paper "ContextFlow++: Generalist-Specialist Flow-based Generative Models with Mixed-variable Context Encoding"
https://arxiv.org/abs/2406.00578


ContextFlow++: Generalist-Specialist Flow-based Generative Models with Mixed-Variable Context Encoding

Abstract

Normalizing flow-based generative models have been widely used in applications where exact density estimation is of major importance. Recent research proposes numerous methods to improve their expressivity. However, conditioning on a context is a largely overlooked area in bijective flow research. Conventional conditioning with vector concatenation is limited to only a few flow types. More importantly, this approach cannot support a practical setup where a set of context-conditioned (specialist) models are trained with a fixed pretrained general-knowledge (generalist) model.

We propose the ContextFlow++ approach to overcome these limitations using additive conditioning with explicit generalist-specialist knowledge decoupling. Furthermore, we support discrete contexts through the proposed mixed-variable architecture with context encoders. In particular, our context encoder for discrete variables is a surjective flow from which the context-conditioned continuous variables are sampled. Our experiments on the rotated MNIST-R, corrupted CIFAR-10C, real-world ATM predictive maintenance, and SMAP unsupervised anomaly detection benchmarks show that the proposed ContextFlow++ offers faster, stable training and achieves higher performance metrics.
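The idea of additive conditioning with a frozen generalist can be illustrated with a toy example. The sketch below is not the repository's implementation: it assumes a simple elementwise affine flow in pure Python, and all names (`S_GEN`, `T_GEN`, `SPECIALISTS`, `forward`, `inverse`) and parameter values are hypothetical. It only shows how per-context offsets could be added to fixed generalist parameters while the transform stays invertible with a tractable log-determinant.

```python
import math

def combine(gen, ctx):
    """Additive conditioning: specialist offsets are added to generalist parameters."""
    return [g + c for g, c in zip(gen, ctx)]

def forward(x, s_gen, t_gen, s_ctx, t_ctx):
    """Map x to z with z_i = x_i * exp(s_i) + t_i; return (z, log|det J|)."""
    s = combine(s_gen, s_ctx)
    t = combine(t_gen, t_ctx)
    z = [xi * math.exp(si) + ti for xi, si, ti in zip(x, s, t)]
    return z, sum(s)  # the Jacobian is diagonal with entries exp(s_i)

def inverse(z, s_gen, t_gen, s_ctx, t_ctx):
    """Undo forward() using the same combined parameters."""
    s = combine(s_gen, s_ctx)
    t = combine(t_gen, t_ctx)
    return [(zi - ti) * math.exp(-si) for zi, si, ti in zip(z, s, t)]

# Hypothetical frozen generalist parameters (these would come from pretraining).
S_GEN, T_GEN = [0.1, -0.2], [0.5, 1.0]

# Hypothetical per-context specialist offsets (log-scale, shift).
# Context 0 uses zero offsets, so it reproduces the generalist exactly.
SPECIALISTS = {
    0: ([0.0, 0.0], [0.0, 0.0]),
    1: ([0.3, 0.0], [-0.5, 0.2]),
}

x = [1.0, 2.0]
for context, (s_ctx, t_ctx) in SPECIALISTS.items():
    z, log_det = forward(x, S_GEN, T_GEN, s_ctx, t_ctx)
    x_rec = inverse(z, S_GEN, T_GEN, s_ctx, t_ctx)
    print(context, z, log_det, x_rec)
```

In this toy setup only the offsets in `SPECIALISTS` would be trained for a new context, while `S_GEN` and `T_GEN` stay fixed. Zero-initialised offsets mean a specialist starts out identical to the generalist.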

BibTeX Citation

If you find our UAI 2024 paper or code useful, please cite it using the following BibTeX entry:

@inproceedings{contextflow,
title={ContextFlow++: Generalist-Specialist Flow-based Generative Models with Mixed-variable Context Encoding},
author={Denis A Gudovskiy and Tomoyuki Okuno and Yohei Nakata},
booktitle={The 40th Conference on Uncertainty in Artificial Intelligence (UAI)},
year={2024},
url={https://openreview.net/forum?id=06nlLSkuuu}
}

Installation

Install all packages with these commands:

conda create -n contextflow python=3.8 -y
conda activate contextflow
conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia
python -m pip install -U -r requirements.txt
ln -s ~/PATHTO/data data
cd contextflow

Datasets

Code Organization

Training & Evaluating Models

python model.py --gpu 0 --dataset cifar10 --coupling conv --action-type train-generalist --fold 0 --clean
python model.py --gpu 0 --dataset cifar10 --coupling conv --action-type train-generalist --fold 0 --save-checkpoint
python model.py --gpu 0 --dataset cifar10 --coupling conv --action-type train-specialist --fold 0 --enc-emb onehot --enc-type vardeq

python model.py --gpu 0 --dataset atm --coupling trans --action-type train-generalist --fold 0 --save-checkpoint
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb onehot --enc-type uniform --contextflow
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb eye --enc-type uniform --contextflow
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb onehot --enc-type vardeq --contextflow
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb eye --enc-type argmax --contextflow
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb embed --enc-type eyesample --contextflow
python model.py --gpu 0 --dataset atm --coupling trans --action-type train-specialist --fold 0 --enc-emb embed --enc-type probsample --contextflow
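The six ATM specialist runs above differ only in their `--enc-emb` and `--enc-type` values, so the command lines can be generated rather than typed by hand. The helper below is a convenience sketch and is not part of the repository: `ATM_ENCODERS` and `specialist_command` are hypothetical names, and the flags and encoder pairs are copied from the commands listed above.

```python
# (enc-emb, enc-type) pairs copied from the ATM specialist commands above.
ATM_ENCODERS = [
    ("onehot", "uniform"),
    ("eye", "uniform"),
    ("onehot", "vardeq"),
    ("eye", "argmax"),
    ("embed", "eyesample"),
    ("embed", "probsample"),
]

def specialist_command(enc_emb, enc_type, dataset="atm", coupling="trans", fold=0, gpu=0):
    """Build one train-specialist command line using only the flags shown above."""
    return (
        f"python model.py --gpu {gpu} --dataset {dataset} --coupling {coupling} "
        f"--action-type train-specialist --fold {fold} "
        f"--enc-emb {enc_emb} --enc-type {enc_type} --contextflow"
    )

commands = [specialist_command(emb, enc) for emb, enc in ATM_ENCODERS]

# Print the commands so they can be pasted into a shell or a job script.
for cmd in commands:
    print(cmd)
```

The script only prints the commands; the generalist checkpoint still has to be trained first with `--action-type train-generalist --save-checkpoint`, as in the first ATM command.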

python model.py --gpu 0 --dataset smap --coupling trans --action-type train-generalist
python model.py --gpu 0 --dataset msl --coupling trans --action-type train-generalist
python model.py --gpu 0 --dataset smd --coupling trans --action-type train-generalist



## ContextFlow++ Architecture
![ContextFlow++](./images/fig-arch.svg)

## Reference Results for ATM and SMAP Datasets
![Results](./images/fig-tables.svg)