PolinaKirichenko / deep_feature_reweighting

BSD 2-Clause "Simplified" License

Base Model Checkpoints #2

Open · Haoxiang-Wang opened this issue 1 year ago

Haoxiang-Wang commented 1 year ago

Dear authors,

Could you provide the checkpoints (i.e., saved weights) of the base models used in your paper? I ran your commands on CelebA and Waterbirds (with 5 random seeds), and the performance of the base models, and of DFR on top of them, is slightly worse than reported in your paper. I would therefore like to request your trained base models for an exact reproduction and further comparison. I would highly appreciate it if you could provide a download link to a Dropbox/Box/Google Drive folder containing your trained models. Thanks!

@andrewgordonwilson @PolinaKirichenko @izmailovpavel

sanyalsunny111 commented 1 year ago

Hey Authors, @PolinaKirichenko

Can you please provide us with the checkpoints to match the results shown in the paper? By checkpoints, I mean both the pretrained ones and the final ones with DFR.

izmailovpavel commented 1 year ago

Hey @Haoxiang-Wang, @sanyalsunny111!

I re-trained base models with 5 random seeds for each of Waterbirds and CelebA and uploaded the checkpoints to this Google Drive folder.

The results are the following:

Note that I used the newer repo here: spurious_feature_learning.

The DFR commands:

python3 dfr_evaluate_spurious.py --data_dir=/datasets/CelebA/ --data_transform=AugWaterbirdsCelebATransform --dataset=SpuriousCorrelationDataset --model=imagenet_resnet50_pretrained --ckpt_path=logs/celeba/erm_seed1/final_checkpoint.pt --result_path=celeba_erm_seed1_dfr.pkl --save_linear_model

python3 dfr_evaluate_spurious.py --data_dir=/datasets/waterbirds/ --data_transform=AugWaterbirdsCelebATransform --dataset=SpuriousCorrelationDataset --model=imagenet_resnet50_pretrained --ckpt_path=logs/waterbirds/erm_seed1/final_checkpoint.pt --result_path=wb_erm_seed1_dfr.pkl --save_linear_model

In the Google Drive folder, each dataset has 5 subfolders, and each of those contains the base model checkpoint, the training logs, the command used to train it, and the last-layer checkpoint trained by DFR.
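Running DFR on every seed can be scripted. Below is a minimal Python sketch that rebuilds the two commands above for each seed folder. It assumes the folders are named erm_seed1 through erm_seed5 (only erm_seed1 appears in this thread) and that the checkpoints sit under logs/<dataset>/ as in those commands. The helper names are hypothetical and not part of either repo.

```python
import shlex
import subprocess

# dataset name -> (--data_dir value, prefix used in --result_path),
# copied from the two DFR commands above
DATASETS = {
    "celeba": ("/datasets/CelebA/", "celeba"),
    "waterbirds": ("/datasets/waterbirds/", "wb"),
}


def build_dfr_command(dataset: str, seed: int) -> list:
    """Return the argv list for one DFR evaluation run (hypothetical helper).

    Assumes checkpoint folders follow the erm_seed<k> naming used above.
    """
    data_dir, prefix = DATASETS[dataset]
    return [
        "python3", "dfr_evaluate_spurious.py",
        f"--data_dir={data_dir}",
        "--data_transform=AugWaterbirdsCelebATransform",
        "--dataset=SpuriousCorrelationDataset",
        "--model=imagenet_resnet50_pretrained",
        f"--ckpt_path=logs/{dataset}/erm_seed{seed}/final_checkpoint.pt",
        f"--result_path={prefix}_erm_seed{seed}_dfr.pkl",
        "--save_linear_model",
    ]


def run_all(seeds=range(1, 6), dry_run=True):
    """Print (dry_run=True) or execute every dataset/seed combination."""
    for dataset in DATASETS:
        for seed in seeds:
            cmd = build_dfr_command(dataset, seed)
            if dry_run:
                print(shlex.join(cmd))
            else:
                # check=True stops at the first run that exits with an error
                subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all()
```

The default dry run only prints the commands, so the paths can be checked against the actual folder names before calling run_all(dry_run=False).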

Please let me know if you have issues with these checkpoints.

sanyalsunny111 commented 1 year ago

@Haoxiang-Wang Thank you very much

sanyalsunny111 commented 1 year ago

@Haoxiang-Wang, I am a bit confused about the evaluation.

I used a CelebA checkpoint you provided and ran this script for evaluation:

python3 dfr_evaluate_spurious.py --data_dir=./data/celebA_v1.0/ --ckpt_path=dfr-ckpts/celeba/erm_seed1/final_checkpoint.pt --result_path=celeba_erm_seed1_dfr.pkl

I see multiple accuracies. Can you please confirm which result I should check for the worst-group accuracy?

[Screenshot of the evaluation output]
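For context on what the worst-group number measures: the data is split into groups by the combination of class label and spurious attribute (in Waterbirds, bird type × background, giving 4 groups), accuracy is computed within each group, and the worst-group accuracy is the smallest of these. The sketch below illustrates that definition on toy data; it is not the repo's evaluation code, and the function names are hypothetical.

```python
from collections import defaultdict


def group_accuracies(preds, labels, groups):
    """Return {group: accuracy}, computed separately within each group."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for pred, label, group in zip(preds, labels, groups):
        total[group] += 1
        correct[group] += int(pred == label)
    return {g: correct[g] / total[g] for g in total}


def worst_group_accuracy(preds, labels, groups):
    """Worst-group accuracy: the minimum of the per-group accuracies."""
    return min(group_accuracies(preds, labels, groups).values())


# Toy example with three groups; group 2 has one error out of two samples.
preds = [1, 1, 0, 0, 1, 0]
labels = [1, 1, 0, 0, 0, 0]
groups = [0, 0, 1, 1, 2, 2]
print(group_accuracies(preds, labels, groups))  # {0: 1.0, 1: 1.0, 2: 0.5}
print(worst_group_accuracy(preds, labels, groups))  # 0.5
```

In this toy case the average accuracy is 5/6 (about 83%) while the worst-group accuracy is 50%, which is why the two numbers in the output can differ so much.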

izmailovpavel commented 1 year ago

Hi @sanyalsunny111, if you want the results for DFR_Val, look at the results under "DFR on Validation" and then at test_worst_acc, which is 87.7% in your screenshot. Note again that the results in my previous post were obtained with the updated repo here, and the commands are for that repo. You should be able to get similar results with this repo too, though.
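Picking this number out of the saved result file can also be scripted. The key names inside the pickle are not listed in this thread (only the "DFR on Validation" heading and test_worst_acc are), so the minimal Python sketch below assumes only that the file holds nested dicts and searches them for test_worst_acc. It then averages a chosen entry over per-seed files, which is one way to make a 5-seed comparison like the one in the original request. All helper names are hypothetical.

```python
import pickle
import statistics


def load_results(pkl_path):
    """Load a result file written via --result_path."""
    with open(pkl_path, "rb") as f:
        return pickle.load(f)


def iter_metric(results, key="test_worst_acc", path=()):
    """Yield (path, value) for every entry named `key` in nested dicts.

    Assumption: the pickle is a nest of dicts. The section names are not
    shown in this thread, so the whole structure is searched.
    """
    if isinstance(results, dict):
        for k, v in results.items():
            if k == key:
                yield path + (k,), v
            else:
                yield from iter_metric(v, key, path + (k,))


def get_path(results, path):
    """Follow a key path such as one yielded by iter_metric."""
    for k in path:
        results = results[k]
    return results


def summarize(pkl_paths, path):
    """Mean and sample std of one metric across per-seed result files.

    statistics.stdev needs at least two files.
    """
    values = [get_path(load_results(p), path) for p in pkl_paths]
    return statistics.mean(values), statistics.stdev(values)


# Usage (file names follow the --result_path pattern in the commands above):
#   res = load_results("celeba_erm_seed1_dfr.pkl")
#   for path, value in iter_metric(res):
#       print(path, value)  # pick the path belonging to "DFR on Validation"
#   files = [f"celeba_erm_seed{s}_dfr.pkl" for s in range(1, 6)]
#   print(summarize(files, chosen_path))
```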