
Variational Recurrent Adversarial Deep Domain Adaptation (VRADA)

Implementation of VRADA in TensorFlow 1.x. See their paper or blog post for details about the method. In their 2016 workshop paper they called this Variational Adversarial Deep Domain Adaptation (VADDA); it is more or less the same method, though the iterative optimization may be done slightly differently.

You can run with or without domain adaptation and with either of two RNN types. In their paper, they refer to the LSTM with domain adaptation as "R-DANN" and the VRNN with domain adaptation as "VRADA."
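To illustrate the adversarial component shared by R-DANN and VRADA (a minimal sketch, not code from this repository): a gradient-reversal layer leaves the forward pass unchanged but negates gradients on the backward pass, so the domain classifier learns to tell source from target while the feature extractor learns domain-invariant features. Assuming TensorFlow 1.x with tf.custom_gradient available:

import tensorflow as tf  # TensorFlow 1.x

@tf.custom_gradient
def flip_gradient(x):
    # Identity on the forward pass; negated gradients on the backward
    # pass, which trains the features to confuse the domain classifier.
    def grad(dy):
        return -dy
    return tf.identity(x), grad

# Hypothetical usage (names are placeholders, not this repository's API):
# domain_logits = domain_classifier(flip_gradient(rnn_features))

In practice the reversed gradient is often scaled by a weight that is annealed over training, but the identity/negation above is the core of the trick.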

To try these out, make sure you clone the repository recursively, since it contains submodules.

git clone --recursive https://github.com/floft/vrada
cd vrada
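
If you already cloned without --recursive, you can fetch the submodules afterward:

git submodule update --init --recursive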

Datasets

This method uses RNNs, so it requires time-series datasets. See README.md in datasets/ for information about generating some simple synthetic datasets or using an RF sleep-stage dataset or the MIMIC-III health-care dataset that the VRADA paper used. You select which dataset to use with a command-line argument, as in the commands below.
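
If VRADA.py parses its arguments with argparse (an assumption; check the script), the available dataset flags can be listed with:

python3 VRADA.py --help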

Usage

Training Locally

For example, to run domain adaptation locally using a VRNN on the synthetically generated "trivial" dataset:

python3 VRADA.py --logdir logs --modeldir models --debug --vrnn-da --trivial-line

Note: the --debug flag tells the script to start a new log and model directory for each run (incrementing the folder number each time) rather than continuing from where the previous run left off.

Training on a High-Performance Cluster

Alternatively, you can train on a cluster with Slurm (in my case, on Kamiak) after editing kamiak_config.sh:

sbatch kamiak_train.srun --vrnn-da --trivial-line
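
While the job runs, generic Slurm commands (not specific to this repository) can be used to check on it:

squeue -u "$USER"   # list your queued and running jobs
scancel JOBID       # cancel a job, using the job ID shown by squeue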

Then, on your local computer, monitor the progress in TensorBoard:

./kamiak_tflogs.sh
tensorboard --logdir vrada-logs

TensorBoard keeps only about 10 image samples per tag by default. If you want to see images at more than 10 time steps, raise that limit:

tensorboard --logdir vrada-logs --samples_per_plugin images=100