google-research / adamatch

Apache License 2.0

AdaMatch

Code for the paper: "AdaMatch: A Unified Approach to Semi-Supervised Learning and Domain Adaptation" by David Berthelot, Rebecca Roelofs, Kihyuk Sohn, Nicholas Carlini, and Alex Kurakin.

This is not an officially supported Google product.

[Figure: AdaMatch diagram]

Setup

sudo apt install python3-dev python3-virtualenv python3-tk imagemagick
virtualenv -p python3 --system-site-packages ~/jax3
. ~/jax3/bin/activate

# Install dependencies (replace with your installed CUDA version)
CUDA_VERSION=11.2
pip install --upgrade -r requirements.txt
pip install -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g`

Required environment variables

export ML_DATA="path to where you want the datasets saved"
export PYTHONPATH=$PYTHONPATH:.

Potentially useful environment variables

# Use this setting if you don't want JAX to preallocate the whole GPU memory
export XLA_PYTHON_CLIENT_PREALLOCATE=false
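
The same setting can also be applied from Python, for example in a notebook where exporting shell variables is inconvenient. This is a minimal sketch: the helper name `disable_jax_preallocation` is hypothetical and not part of this repository. The variable should be set before JAX initializes its GPU backend, so setting it before the first `import jax` is the safe choice.

```python
import os


def disable_jax_preallocation(env=os.environ):
    """Ask JAX to allocate GPU memory on demand instead of up front.

    Hypothetical helper, not part of this repository. Call it before the
    first `import jax` so the setting is in place when the backend starts.
    """
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    return env["XLA_PYTHON_CLIENT_PREALLOCATE"]


if __name__ == "__main__":
    disable_jax_preallocation()
    # import jax  # import JAX only after the variable has been set
```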

Data preparation

# Download datasets and save them as TF records
CUDA_VISIBLE_DEVICES= python scripts/create_datasets.py

# Alternatively, if your machine has many cores, this parallel version is much faster.
bash ./runs/create_datasets.sh

Training

Options

Specific to Domain Adaptation (DA) and Semi-Supervised Domain Adaptation (SSDA)

Note: for SSDA, the target takes a different form, for example quickdraw(10,seed=1) as in the SSDA commands below.

Specific to Semi-Supervised Learning (SSL)

Fully supervised

# Baseline: source and target must be the same. Additionally, one can specify extra test sets.
python fully_supervised/baseline.py --dataset=domainnet32 --source=clipart --target=clipart\
    --logdir experiments/2021/02.12-32 --augment=CTA\(sm,sm\)\
    --test_extra=clipart,infograph,quickdraw,real,sketch,painting
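
Because the fully supervised baseline requires the source and target to match, running it for every subdomain means repeating the command with one value changed. The sketch below generates those commands; the function name `baseline_command` is hypothetical, and the subdomain list is simply the `--test_extra` list from the example above. `shlex.join` quotes the parentheses in the `--augment` value, which is equivalent to the backslash escapes used in the hand-written commands.

```python
import shlex

# Subdomain names taken from the --test_extra list in the example above.
SUBDOMAINS = ("clipart", "infograph", "quickdraw", "real", "sketch", "painting")


def baseline_command(domain, logdir="experiments/2021/02.12-32"):
    """Build the fully supervised baseline command line for one subdomain.

    Hypothetical helper for illustration; it only assembles a string and
    does not run anything.
    """
    args = [
        "python", "fully_supervised/baseline.py",
        "--dataset=domainnet32",
        f"--source={domain}",
        f"--target={domain}",  # source and target must be the same
        "--logdir", logdir,
        "--augment=CTA(sm,sm)",
        "--test_extra=" + ",".join(SUBDOMAINS),
    ]
    # shlex.join quotes arguments containing shell-special characters
    # such as parentheses, so the result can be pasted into a shell.
    return shlex.join(args)


if __name__ == "__main__":
    for domain in SUBDOMAINS:
        print(baseline_command(domain))
```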

Domain adaptation

# Baseline: applies no technique to the unlabeled data; it only runs the unlabeled and labeled data through the network together as a single batch.
python domain_adaptation/baseline.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
python domain_adaptation/baseline.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=1 --augment=\(sm,smc\)

# FixMatch
python domain_adaptation/fixmatch_da.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
python domain_adaptation/fixmatch_da.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=1 --augment=\(sm,smc\)

# AdaMatch
python domain_adaptation/adamatch.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
python domain_adaptation/adamatch.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=1 --augment=\(sm,smc\)

# NoisyStudent
## Teacher
python domain_adaptation/noisy_student.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --pseudo_label_th=0.9 --augment=CTA\(sm,sm\)\
    --id=0
## Student
python domain_adaptation/noisy_student.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --pseudo_label_th=0.9 --augment=CTA\(sm,sm\) --id=1\
    --pseudo_label_file=experiments/2021/03.1-32/DA/domainnet32/clipart/quickdraw/CTA\(sm,sm\)/NoisyStudent/archwrn28-2_batch64_lr0.03_lr_decay0.25_wd0.001/0/predictions.npy

# MCD
python domain_adaptation/mcd.py --dataset=domainnet64 --source=clipart --target=quickdraw\
    --uratio=1 --arch=wrn28-2 --augment='CTA(sm,sm)' --train_mimg=8 --logdir experiments/2021/02.12-32 --lr_decay 0.25
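
The NoisyStudent student command above points `--pseudo_label_file` at a long path under the teacher's log directory. The sketch below rebuilds that path from its parts. The layout is inferred from that single example path only, and reading the final `0` directory as the teacher's `--id` is an assumption; the function name `teacher_predictions_path` is hypothetical and not part of this repository.

```python
from pathlib import PurePosixPath


def teacher_predictions_path(logdir, dataset, source, target, augment, hparams, run_id):
    """Compose the path of a teacher's predictions.npy file.

    Hypothetical helper: the directory order mirrors the one example path
    in the NoisyStudent student command and may not hold for other setups.
    """
    return PurePosixPath(
        logdir, "DA", dataset, source, target, augment,
        "NoisyStudent", hparams, str(run_id), "predictions.npy",
    )


if __name__ == "__main__":
    print(teacher_predictions_path(
        "experiments/2021/03.1-32", "domainnet32", "clipart", "quickdraw",
        "CTA(sm,sm)", "archwrn28-2_batch64_lr0.03_lr_decay0.25_wd0.001", 0,
    ))
```

When pasting the printed path into a shell command, escape or quote the parentheses as the commands above do.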

Multi-source domain adaptation

Prefix a subdomain with no_ to use every domain except that one as the source.

python domain_adaptation/adamatch.py --dataset=domainnet32 --source=no_quickdraw --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
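
To make the `no_` prefix concrete, the sketch below expands a `--source` value such as `no_quickdraw` into the subdomains it stands for. It is an illustration only: `expand_source` is a hypothetical helper, the subdomain list is taken from the `--test_extra` example earlier in this README, and the repository's own parsing may differ.

```python
# Subdomain names taken from the --test_extra list in the fully supervised example.
DOMAINNET_SUBDOMAINS = ("clipart", "infograph", "quickdraw", "real", "sketch", "painting")


def expand_source(source, subdomains=DOMAINNET_SUBDOMAINS):
    """Expand a --source value into the list of subdomains it refers to.

    Hypothetical helper illustrating the `no_` convention.
    """
    if source.startswith("no_"):
        excluded = source[len("no_"):]
        if excluded not in subdomains:
            raise ValueError(f"unknown subdomain: {excluded!r}")
        # Keep every subdomain except the excluded one, in the original order.
        return [name for name in subdomains if name != excluded]
    if source not in subdomains:
        raise ValueError(f"unknown subdomain: {source!r}")
    return [source]


if __name__ == "__main__":
    print(expand_source("no_quickdraw"))
    # ['clipart', 'infograph', 'real', 'sketch', 'painting']
```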

Semi-Supervised Domain Adaptation

# FixMatch
python semi_supervised_domain_adaptation/fixmatch_da.py --dataset=domainnet32 --source=clipart\
    --target=quickdraw\(10,seed=1\)\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)

# AdaMatch
python semi_supervised_domain_adaptation/adamatch.py --dataset=domainnet32 --source=clipart\
    --target=quickdraw\(10,seed=1\)\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
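
The SSDA target value `quickdraw(10,seed=1)` packs a name and arguments into one string. The sketch below splits such a value into its parts without interpreting them. `parse_target` is a hypothetical parser written for this README, not the repository's implementation, and it assigns no meaning to the arguments.

```python
import re

# A name followed by a parenthesized, comma-separated argument list.
_SPEC = re.compile(r"^(?P<name>\w+)\((?P<args>[^()]*)\)$")


def parse_target(spec):
    """Split 'quickdraw(10,seed=1)' into (name, positional, keyword) parts.

    Hypothetical parser. A plain name such as 'quickdraw' is returned with
    empty argument collections. All values are kept as strings.
    """
    match = _SPEC.match(spec)
    if match is None:
        return spec, [], {}
    positional, keyword = [], {}
    for item in (part.strip() for part in match["args"].split(",")):
        if not item:
            continue
        if "=" in item:
            key, value = item.split("=", 1)
            keyword[key.strip()] = value.strip()
        else:
            positional.append(item)
    return match["name"], positional, keyword


if __name__ == "__main__":
    print(parse_target("quickdraw(10,seed=1)"))
    # ('quickdraw', ['10'], {'seed': '1'})
```

On the command line the parentheses are escaped (`\(`, `\)`) or quoted, as in the commands above, so the program receives the plain form shown here.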

Semi-Supervised Learning

# FixMatch
python semi_supervised/fixmatch_da.py --dataset=domainnet32_quickdraw\(10,seed=1\)\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)

# AdaMatch
python semi_supervised/adamatch.py --dataset=domainnet32_quickdraw\(10,seed=1\)\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)

Tensorboard

tensorboard --logdir experiments