anniesch / jtt

Code for "Just Train Twice: Improving Group Robustness without Training Group Information"

Just Train Twice: Improving Group Robustness without Training Group Information

This code implements the following paper:

Just Train Twice: Improving Group Robustness without Training Group Information

Environment

Create an environment with the following commands:

virtualenv venv -p python3
source venv/bin/activate
pip install -r requirements.txt

Downloading Datasets

Running our Method

Monitoring Performance

Running ERM, Joint DRO, or Group DRO

Adding other datasets

Add the following:

Sample Commands for running JTT on Waterbirds

python generate_downstream.py --exp_name CUB_sample_exp --dataset CUB --n_epochs 300 --lr 1e-5 --weight_decay 1.0 --method ERM

bash results/CUB/CUB_sample_exp/ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0/job.sh

python process_training.py --exp_name CUB_sample_exp --dataset CUB --folder_name ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0 --lr 1e-05 --weight_decay 1.0 --final_epoch 60 --deploy

bash results/CUB/CUB_sample_exp/train_downstream_ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0/final_epoch60/JTT_upweight_100_epochs_300_lr_1e-05_weight_decay_1.0/job.sh

python analysis.py --exp_name CUB_sample_exp/train_downstream_ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0/final_epoch60/ --dataset CUB
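
As described in the paper, JTT first trains a standard ERM model for a limited number of epochs (the --final_epoch 60 value above), then trains a second model in which the training points the first model misclassified are upweighted (the 100 in JTT_upweight_100). The following is a minimal sketch of that second-stage selection on toy data. The function name, its arguments, and the choice to upweight by repeating indices are illustrative assumptions, not the repository's actual implementation.

```python
def upsampled_indices(predictions, labels, upweight):
    """Return (error_set, indices) for a JTT-style second stage.

    Hypothetical helper: `predictions` are first-stage model outputs on the
    training set, `labels` are the true training labels, and `upweight` is
    the number of times each misclassified example is repeated.
    """
    # Error set: training points the first-stage model got wrong.
    error_set = [i for i, (p, y) in enumerate(zip(predictions, labels)) if p != y]
    errors = set(error_set)

    # Repeat each misclassified index `upweight` times; keep the rest once.
    indices = []
    for i in range(len(labels)):
        indices.extend([i] * (upweight if i in errors else 1))
    return error_set, indices


# Toy example: examples 1 and 3 are misclassified by the first-stage model.
toy_predictions = [0, 1, 1, 0]
toy_labels = [0, 0, 1, 1]
toy_errors, toy_indices = upsampled_indices(toy_predictions, toy_labels, upweight=3)
print(toy_errors)   # [1, 3]
print(toy_indices)  # [0, 1, 1, 1, 2, 3, 3, 3]
```

Note that group labels are not used anywhere in this selection step, which is consistent with the paper's title.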

Sample Commands for running JTT on CelebA

python generate_downstream.py --exp_name CelebA_sample_exp --dataset CelebA --n_epochs 50 --lr 1e-5 --weight_decay 0.1 --method ERM

bash results/CelebA/CelebA_sample_exp/ERM_upweight_0_epochs_50_lr_1e-05_weight_decay_0.1/job.sh

python process_training.py --exp_name CelebA_sample_exp --dataset CelebA --folder_name ERM_upweight_0_epochs_50_lr_1e-05_weight_decay_0.1 --lr 1e-05 --weight_decay 0.1 --final_epoch 1 --deploy

bash results/CelebA/CelebA_sample_exp/train_downstream_ERM_upweight_0_epochs_50_lr_1e-05_weight_decay_0.1/final_epoch1/JTT_upweight_50_epochs_50_lr_1e-05_weight_decay_0.1/job.sh

python analysis.py --exp_name CelebA_sample_exp/train_downstream_ERM_upweight_0_epochs_50_lr_1e-05_weight_decay_0.1/final_epoch1/ --dataset CelebA

Sample Commands for running JTT on MultiNLI

python generate_downstream.py --exp_name MultiNLI_sample_exp --dataset MultiNLI --n_epochs 5 --lr 2e-5 --weight_decay 0 --method ERM

bash results/MultiNLI/MultiNLI_sample_exp/ERM_upweight_0_epochs_5_lr_2e-05_weight_decay_0.0/job.sh

python process_training.py --exp_name MultiNLI_sample_exp --dataset MultiNLI --folder_name ERM_upweight_0_epochs_5_lr_2e-05_weight_decay_0.0 --lr 1e-05 --weight_decay 0.1 --final_epoch 2 --deploy

bash results/MultiNLI/MultiNLI_sample_exp/train_downstream_ERM_upweight_0_epochs_5_lr_2e-05_weight_decay_0.0/final_epoch2/JTT_upweight_4_epochs_5_lr_1e-05_weight_decay_0.1/job.sh

python analysis.py --exp_name MultiNLI_sample_exp/train_downstream_ERM_upweight_0_epochs_5_lr_2e-05_weight_decay_0.0/final_epoch2/ --dataset MultiNLI

Sample Commands for running JTT on CivilComments-WILDS

python generate_downstream.py --exp_name jigsaw_sample_exp --dataset jigsaw --n_epochs 3 --lr 2e-5 --weight_decay 0 --method ERM --batch_size 24

bash results/jigsaw/jigsaw_sample_exp/ERM_upweight_0_epochs_3_lr_2e-05_weight_decay_0.0/job.sh

python process_training.py --exp_name jigsaw_sample_exp --dataset jigsaw --folder_name ERM_upweight_0_epochs_3_lr_2e-05_weight_decay_0.0 --lr 1e-05 --weight_decay 0.01 --final_epoch 2 --batch_size 16 --deploy

bash results/jigsaw/jigsaw_sample_exp/train_downstream_ERM_upweight_0_epochs_3_lr_2e-05_weight_decay_0.0/final_epoch2/JTT_upweight_6_epochs_3_lr_1e-05_weight_decay_0.01/job.sh

python analysis.py --exp_name jigsaw_sample_exp/train_downstream_ERM_upweight_0_epochs_3_lr_2e-05_weight_decay_0.0/final_epoch2/ --dataset jigsaw
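
The result paths in the sample commands appear to follow a fixed naming pattern built from the method, upweight factor, epoch count, learning rate, and weight decay. The sketch below reconstructs that pattern so the paths can be generated rather than typed by hand. Both helper names are hypothetical and the pattern is inferred only from the sample paths in this README, so check the generated path against the actual contents of results/ before relying on it.

```python
from pathlib import PurePosixPath


def run_folder(method, upweight, n_epochs, lr, weight_decay):
    """Build a run folder name in the pattern seen in the sample commands."""
    # float() renders 0 as "0.0" and 1e-5 as "1e-05", as in the sample paths.
    return (
        f"{method}_upweight_{upweight}_epochs_{n_epochs}"
        f"_lr_{float(lr)}_weight_decay_{float(weight_decay)}"
    )


def second_stage_job(dataset, exp_name, erm_folder, final_epoch, jtt_folder):
    """Path of the second-stage job script, as laid out in the sample commands."""
    return str(
        PurePosixPath(
            "results",
            dataset,
            exp_name,
            f"train_downstream_{erm_folder}",
            f"final_epoch{final_epoch}",
            jtt_folder,
            "job.sh",
        )
    )


# Values taken from the CivilComments-WILDS sample commands.
erm = run_folder("ERM", 0, 3, 2e-5, 0)
jtt = run_folder("JTT", 6, 3, 1e-5, 0.01)
print(second_stage_job("jigsaw", "jigsaw_sample_exp", erm, 2, jtt))
```

The first-stage folder uses the hyperparameters passed to generate_downstream.py, while the second-stage folder uses the --lr and --weight_decay values passed to process_training.py.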