This code implements the following paper:
Just Train Twice: Improving Group Robustness without Training Group Information
Create an environment with the following commands:
virtualenv venv -p python3
source venv/bin/activate
pip install -r requirements.txt
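After `source venv/bin/activate`, one way to confirm that Python is actually running inside the new environment is to compare `sys.prefix` with `sys.base_prefix`. This is a minimal standard-library sketch. The `real_prefix` check covers older virtualenv releases. The function name `in_virtualenv` is illustrative and is not part of the repo.

```python
import sys


def in_virtualenv() -> bool:
    """Return True if the running interpreter belongs to a virtual environment."""
    # Legacy virtualenv (< 20) sets sys.real_prefix; venv and modern
    # virtualenv instead make sys.prefix differ from sys.base_prefix.
    if getattr(sys, "real_prefix", None) is not None:
        return True
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    if in_virtualenv():
        print("inside a virtual environment")
    else:
        print("not in a virtual environment; run: source venv/bin/activate")
```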
Waterbirds: Download Waterbirds from here and put it in jtt/cub. In that directory, our code expects data/waterbird_complete95_forest2water2/ with metadata.csv inside.
CelebA: Download CelebA from here and put it in jtt/celebA.
MultiNLI: Follow the instructions here to download this dataset and put it in jtt/multinli.
CivilComments: Download this dataset from here and put it in jtt/jigsaw. In that directory, our code expects a folder data with the downloaded dataset.
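Before training, you can check that the datasets are where the code expects them. The sketch below covers only the paths stated above, relative to the jtt/ repository root. CelebA and MultiNLI are checked only at the directory level because their expected contents are not listed here. `missing_dataset_paths` is a hypothetical helper, not part of the repo.

```python
from pathlib import Path
from typing import Dict, List

# Expected locations relative to the jtt/ repository root, per the instructions above.
EXPECTED_PATHS: Dict[str, List[str]] = {
    "Waterbirds": ["cub/data/waterbird_complete95_forest2water2/metadata.csv"],
    "CelebA": ["celebA"],
    "MultiNLI": ["multinli"],
    "CivilComments": ["jigsaw/data"],
}


def missing_dataset_paths(repo_root: str) -> Dict[str, List[str]]:
    """Map each dataset name to the expected paths that do not exist yet."""
    root = Path(repo_root)
    missing: Dict[str, List[str]] = {}
    for name, rel_paths in EXPECTED_PATHS.items():
        absent = [p for p in rel_paths if not (root / p).exists()]
        if absent:
            missing[name] = absent
    return missing


if __name__ == "__main__":
    for name, paths in missing_dataset_paths(".").items():
        print(f"{name}: missing {', '.join(paths)}")
```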
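To run JTT, first train the initial ERM model:
python generate_downstream.py --exp_name $EXPERIMENT_NAME --dataset $DATASET --method ERM
Useful optional args: --n_epochs $EPOCHS --lr $LR --weight_decay $WD. Other args, e.g. batch size, can be changed in generate_downstream.py. $DATASET is one of CUB, CelebA, MultiNLI, jigsaw.
Then run the generated ERM script with bash. It is written under results/$DATASET/$EXPERIMENT_NAME.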
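The Waterbirds sample commands suggest how the generated run folder is named. There, `--n_epochs 300 --lr 1e-5 --weight_decay 1.0 --method ERM` leads to `results/CUB/CUB_sample_exp/ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0/job.sh`. The sketch below reproduces that pattern so you can predict the script path and the value of $ERM_FOLDER_NAME. The naming rule is inferred from the sample folder names, not read from generate_downstream.py. `run_folder_name` and `job_script_path` are hypothetical helpers.

```python
from pathlib import Path


def run_folder_name(method: str, upweight: int, n_epochs: int,
                    lr: float, weight_decay: float) -> str:
    """Folder name in the style of the sample runs (inferred, not from the repo)."""
    # Formatting through float() turns 1e-5 into "1e-05" and 0 into "0.0",
    # matching names such as ERM_upweight_0_epochs_5_lr_2e-05_weight_decay_0.0.
    return (f"{method}_upweight_{upweight}_epochs_{n_epochs}"
            f"_lr_{float(lr)}_weight_decay_{float(weight_decay)}")


def job_script_path(dataset: str, exp_name: str, folder: str) -> Path:
    """Expected location of the generated job.sh for a run folder."""
    return Path("results") / dataset / exp_name / folder / "job.sh"


if __name__ == "__main__":
    folder = run_folder_name("ERM", 0, 300, 1e-5, 1.0)
    print(job_script_path("CUB", "CUB_sample_exp", folder))
```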
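Once ERM has finished training, run:
python process_training.py --exp_name $EXPERIMENT_NAME --dataset $DATASET --folder_name $ERM_FOLDER_NAME --lr $LR --weight_decay $WD --deploy
Here $ERM_FOLDER_NAME is the name of the ERM run folder from the previous step (e.g. ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0). The sample commands below also pass --final_epoch $FINAL_EPOCH, which appears in the output path. Then run, with bash, the generated scripts that have JTT in their name.
To monitor performance, run:
python analysis.py --exp_name $PATH_TO_JTT_RUNS --dataset $DATASET
where $PATH_TO_JTT_RUNS will look like $EXPERIMENT_NAME+"/train_downstream_"+$ERM_FOLDER_NAME+"/final_epoch"+$FINAL_EPOCH.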
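$PATH_TO_JTT_RUNS is a plain string concatenation of three values. The sketch below builds it and the matching analysis.py argument list. For the CUB sample command, it yields the path passed to analysis.py, which the sample writes with a trailing slash. `jtt_runs_path` and `analysis_command` are hypothetical helper names.

```python
import shlex
from typing import List


def jtt_runs_path(exp_name: str, erm_folder_name: str, final_epoch: int) -> str:
    """Build $PATH_TO_JTT_RUNS from its three parts."""
    # $EXPERIMENT_NAME + "/train_downstream_" + $ERM_FOLDER_NAME + "/final_epoch" + $FINAL_EPOCH
    return f"{exp_name}/train_downstream_{erm_folder_name}/final_epoch{final_epoch}"


def analysis_command(dataset: str, exp_name: str,
                     erm_folder_name: str, final_epoch: int) -> List[str]:
    """Argument list for the analysis.py call, usable with subprocess.run."""
    path = jtt_runs_path(exp_name, erm_folder_name, final_epoch)
    return ["python", "analysis.py", "--exp_name", path, "--dataset", dataset]


if __name__ == "__main__":
    cmd = analysis_command("CUB", "CUB_sample_exp",
                           "ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0", 60)
    print(shlex.join(cmd))
```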
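To run a single method on its own, e.g. the ERM baseline, generate its script with:
python generate_downstream.py --exp_name $EXPERIMENT_NAME --dataset $DATASET --method $METHOD
Useful optional args: --n_epochs $EPOCHS --lr $LR --weight_decay $WD. $DATASET is one of CUB, CelebA, MultiNLI, jigsaw.
Then run the generated script for the method with bash. It is written under results/$DATASET/$EXPERIMENT_NAME.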
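Both flows end by running a generated job.sh with bash. The sketch below collects the generated scripts under one experiment folder and can optionally run them. The prefix filter (e.g. "JTT") checks the name of the folder that holds job.sh, following the layout in the sample commands. `find_job_scripts` and `run_job_scripts` are hypothetical helpers, not part of the repo. `dry_run` defaults to True so nothing executes by accident.

```python
import subprocess
from pathlib import Path
from typing import List, Optional


def find_job_scripts(results_dir: str, folder_prefix: Optional[str] = None) -> List[Path]:
    """Every job.sh below results_dir, optionally only those whose parent
    folder name starts with folder_prefix (e.g. "JTT")."""
    scripts = sorted(Path(results_dir).rglob("job.sh"))
    if folder_prefix is not None:
        scripts = [s for s in scripts if s.parent.name.startswith(folder_prefix)]
    return scripts


def run_job_scripts(scripts: List[Path], dry_run: bool = True) -> None:
    """Print each bash command and, unless dry_run is set, execute it in order."""
    for script in scripts:
        print(f"bash {script}")
        if not dry_run:
            # check=True raises CalledProcessError at the first failing script.
            subprocess.run(["bash", str(script)], check=True)


if __name__ == "__main__":
    run_job_scripts(find_job_scripts("results/CUB/CUB_sample_exp", folder_prefix="JTT"))
```

The scripts run one after another. On a cluster, you may prefer to submit them to your scheduler instead.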
Adding other datasets: edit process_training.py to include the required args for your dataset and implement a way of getting the spurious features from the dataset.
Sample commands for running JTT on Waterbirds (CUB):
python generate_downstream.py --exp_name CUB_sample_exp --dataset CUB --n_epochs 300 --lr 1e-5 --weight_decay 1.0 --method ERM
bash results/CUB/CUB_sample_exp/ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0/job.sh
python process_training.py --exp_name CUB_sample_exp --dataset CUB --folder_name ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0 --lr 1e-05 --weight_decay 1.0 --final_epoch 60 --deploy
bash results/CUB/CUB_sample_exp/train_downstream_ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0/final_epoch60/JTT_upweight_100_epochs_300_lr_1e-05_weight_decay_1.0/job.sh
python analysis.py --exp_name CUB_sample_exp/train_downstream_ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0/final_epoch60/ --dataset CUB
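In the CUB sample, process_training.py uses epoch 60 of the ERM run (`--final_epoch 60`), and the generated JTT folder is named with upweight_100. In the paper, JTT collects the training examples that the first (ERM) model misclassifies, the error set. It then builds a second training set in which each error-set example appears λ_up times and every other example appears once. The sketch below implements that upsampling step on index lists, assuming a per-example correctness list from the ERM checkpoint. It is not the repo's implementation, and reading upweight_100 as λ_up = 100 is an assumption.

```python
from typing import List, Sequence


def jtt_upsampled_indices(erm_correct: Sequence[bool], upweight: int) -> List[int]:
    """Training indices for the second JTT stage.

    erm_correct[i] says whether the ERM model classified example i correctly.
    Misclassified examples (the error set) appear `upweight` times; all
    other examples appear once.
    """
    if upweight < 1:
        raise ValueError("upweight must be >= 1")
    indices: List[int] = []
    for i, correct in enumerate(erm_correct):
        indices.extend([i] * (1 if correct else upweight))
    return indices


if __name__ == "__main__":
    print(jtt_upsampled_indices([True, False, True], upweight=3))
```

Such an index list could, for instance, be wrapped with `torch.utils.data.Subset` to train the second model on the upsampled set.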
Sample commands for running JTT on CelebA:
python generate_downstream.py --exp_name CelebA_sample_exp --dataset CelebA --n_epochs 50 --lr 1e-5 --weight_decay 0.1 --method ERM
bash results/CelebA/CelebA_sample_exp/ERM_upweight_0_epochs_50_lr_1e-05_weight_decay_0.1/job.sh
python process_training.py --exp_name CelebA_sample_exp --dataset CelebA --folder_name ERM_upweight_0_epochs_50_lr_1e-05_weight_decay_0.1 --lr 1e-05 --weight_decay 0.1 --final_epoch 1 --deploy
bash results/CelebA/CelebA_sample_exp/train_downstream_ERM_upweight_0_epochs_50_lr_1e-05_weight_decay_0.1/final_epoch1/JTT_upweight_50_epochs_50_lr_1e-05_weight_decay_0.1/job.sh
python analysis.py --exp_name CelebA_sample_exp/train_downstream_ERM_upweight_0_epochs_50_lr_1e-05_weight_decay_0.1/final_epoch1/ --dataset CelebA
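The last command of each sample sequence monitors the JTT runs with analysis.py. The paper's headline metric is worst-group accuracy: the lowest accuracy over groups defined by the label and the spurious attribute, e.g. (bird type, background) in Waterbirds. The sketch below computes that metric, assuming per-example predictions, labels, and group keys are already available. It does not reproduce analysis.py or its output format.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence, Tuple


def group_accuracies(preds: Sequence[int], labels: Sequence[int],
                     groups: Sequence[Hashable]) -> Dict[Hashable, float]:
    """Accuracy of the predictions within each group."""
    if not (len(preds) == len(labels) == len(groups)):
        raise ValueError("preds, labels and groups must have the same length")
    correct: Dict[Hashable, int] = defaultdict(int)
    total: Dict[Hashable, int] = defaultdict(int)
    for pred, label, group in zip(preds, labels, groups):
        total[group] += 1
        correct[group] += int(pred == label)
    return {g: correct[g] / total[g] for g in total}


def worst_group_accuracy(preds: Sequence[int], labels: Sequence[int],
                         groups: Sequence[Hashable]) -> Tuple[Hashable, float]:
    """(group, accuracy) for the group with the lowest accuracy."""
    accs = group_accuracies(preds, labels, groups)
    worst = min(accs, key=accs.__getitem__)
    return worst, accs[worst]


if __name__ == "__main__":
    print(worst_group_accuracy([1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1]))
```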
Sample commands for running JTT on MultiNLI:
python generate_downstream.py --exp_name MultiNLI_sample_exp --dataset MultiNLI --n_epochs 5 --lr 2e-5 --weight_decay 0 --method ERM
bash results/MultiNLI/MultiNLI_sample_exp/ERM_upweight_0_epochs_5_lr_2e-05_weight_decay_0.0/job.sh
python process_training.py --exp_name MultiNLI_sample_exp --dataset MultiNLI --folder_name ERM_upweight_0_epochs_5_lr_2e-05_weight_decay_0.0 --lr 1e-05 --weight_decay 0.1 --final_epoch 2 --deploy
bash results/MultiNLI/MultiNLI_sample_exp/train_downstream_ERM_upweight_0_epochs_5_lr_2e-05_weight_decay_0.0/final_epoch2/JTT_upweight_4_epochs_5_lr_2e-05_weight_decay_0.1/job.sh
python analysis.py --exp_name MultiNLI_sample_exp/train_downstream_ERM_upweight_0_epochs_5_lr_2e-05_weight_decay_0.0/final_epoch2/ --dataset MultiNLI
Sample commands for running JTT on CivilComments (jigsaw):
python generate_downstream.py --exp_name jigsaw_sample_exp --dataset jigsaw --n_epochs 3 --lr 2e-5 --weight_decay 0 --method ERM --batch_size 24
bash results/jigsaw/jigsaw_sample_exp/ERM_upweight_0_epochs_3_lr_2e-05_weight_decay_0.0/job.sh
python process_training.py --exp_name jigsaw_sample_exp --dataset jigsaw --folder_name ERM_upweight_0_epochs_3_lr_2e-05_weight_decay_0.0 --lr 1e-05 --weight_decay 0.01 --final_epoch 2 --batch_size 16 --deploy
bash results/jigsaw/jigsaw_sample_exp/train_downstream_ERM_upweight_0_epochs_3_lr_2e-05_weight_decay_0.0/final_epoch2/JTT_upweight_6_epochs_3_lr_1e-05_weight_decay_0.01/job.sh
python analysis.py --exp_name jigsaw_sample_exp/train_downstream_ERM_upweight_0_epochs_3_lr_2e-05_weight_decay_0.0/final_epoch2/ --dataset jigsaw