Closed by rugrag 2 years ago
Hi, a couple things: First, make sure that you're using the best hyperparams for the ERM baseline, as detailed in the appendix of the paper, which include a learning rate of 1e-3 and weight decay 1e-4. Second, for CUB, we report an adjusted average accuracy, as done in the Group DRO paper, where we calculate the average test accuracy over each group and then report a weighted average, with weights corresponding to the relative proportion of each group in the training dataset. Note we do not do this for the other datasets. Let me know if there are still issues!
Hi, thanks for your answer. I tried all the different (learning rate, weight decay) pairs detailed in the supplementary material. I ran the ERM experiment using the bash script that I reported. When I check the generated log.txt file, I see test accuracies that are much lower than expected, for example:
Epoch [50]:
...
Test:
Average incurred loss: 0.479
Average sample loss: 0.478
Average acc: 0.732
waterbird_complete95 = 0, forest2water2 = 0 [n = 2255]: loss = 0.103 exp loss = 0.097 adjusted loss = 0.097 adv prob = 0.250000 acc = 0.997
waterbird_complete95 = 0, forest2water2 = 1 [n = 2255]: loss = 0.655 exp loss = 0.682 adjusted loss = 0.682 adv prob = 0.250000 acc = 0.594
waterbird_complete95 = 1, forest2water2 = 0 [n = 642]: loss = 1.289 exp loss = 1.307 adjusted loss = 1.307 adv prob = 0.250000 acc = 0.148
waterbird_complete95 = 1, forest2water2 = 1 [n = 642]: loss = 0.371 exp loss = 0.358 adjusted loss = 0.358 adv prob = 0.250000 acc = 0.872
Current lr: 0.000010
I checked the values at different epochs but never reached the average test accuracy of 97.3%. The logged average accuracy (0.732) is, as expected, a weighted average of the per-group accuracies, i.e.
(2255 × 0.997 + 2255 × 0.594 + 642 × 0.148 + 642 × 0.872) / (2255 + 2255 + 642 + 642) = 0.732
I don't know if I misunderstood something about how to read the results.
Hi @rugrag, thanks for updating the issue. It looks like the lr for that run you just pasted is 1e-5, can you post your results for your run with the paper's ERM hyperparameters (lr 1e-3 and weight decay 1e-4)? Also, it seems like you might have a changed version of the code, as results on the test set should not be logged to the log.txt file to avoid peeking. Can you make sure that you have the most up-to-date version of the code? Thanks!
I ran again with lr 1e-3 and weight decay 1e-4 but got similar results: ~87% average accuracy on the test set and ~68% worst-group accuracy. In the code I only changed the logging function in order to visualize accuracies at the end of every epoch, so I haven't changed the logic or the architecture of the model.
I have uploaded the generated log file. Please have a look and let me know whether I am missing something.
Also, on CelebA I got lower results for the ERM baseline: 95.6% average accuracy on the test set (as in the paper) but ~37% worst-group accuracy.
Hi @rugrag, thanks for updating! The generated log.txt looks normal to me, and actually the numbers look about as expected as well. Apologies for not being clear before, but for the average accuracy on CUB, we report a weighted average where the weights correspond to the relative proportion of each group in the training & val dataset, not the test dataset. Explicitly, this is: (3498 × acc(landbird, land) + 184 × acc(landbird, water) + 56 × acc(waterbird, land) + 1057 × acc(waterbird, water)) / 4795, where acc(·) is the test accuracy on that group. So the numbers in the log.txt give an average acc of ~97%, as expected, and although your worst group accuracy is a little lower than our reported value, it seems reasonable and comparable to what other papers have reported.
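To make the difference between the two weightings concrete, here is a minimal Python sketch that applies both to the per-group test accuracies from the lr 1e-5 log pasted earlier in this thread. The helper name `weighted_avg_acc` is hypothetical (it is not a function from the repository), and reading `waterbird_complete95` as the bird class and `forest2water2` as the background is an assumption based on the group counts quoted above.

```python
def weighted_avg_acc(accs, counts):
    """Average per-group accuracies, weighting each group by its sample count."""
    total = sum(counts[g] for g in accs)
    return sum(counts[g] * accs[g] for g in accs) / total

# Keys are (waterbird_complete95, forest2water2); the mapping to
# (landbird/waterbird, land/water) is an assumption, see above.
test_acc = {(0, 0): 0.997, (0, 1): 0.594, (1, 0): 0.148, (1, 1): 0.872}

# Group sizes in the test set, taken from the pasted log.
test_counts = {(0, 0): 2255, (0, 1): 2255, (1, 0): 642, (1, 1): 642}

# Group sizes used as weights for the adjusted average, as quoted in this thread.
train_counts = {(0, 0): 3498, (0, 1): 184, (1, 0): 56, (1, 1): 1057}

test_weighted = weighted_avg_acc(test_acc, test_counts)
adjusted = weighted_avg_acc(test_acc, train_counts)

print(f"test-set weights:     {test_weighted:.3f}")  # 0.732
print(f"training-set weights: {adjusted:.3f}")       # 0.944
```

For that particular run, the same per-group accuracies give about 0.73 under test-set weights and about 0.94 under training-set weights, because the two large training groups are also the ones with the highest accuracy.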
For your CelebA results, if possible can you upload the generated log.txt? Note that since the worst group accuracy tends to be very low and may fluctuate between epochs, it's possible that there's some variance in the worst group accuracy that's obtained, subject to random seeds, etc. While we obtained a higher worst group accuracy when running, your value is close to what is reported in the Group DRO paper.
OK, thanks, now I get it. So, basically, what you reported in Table 1 as Avg Acc. is the weighted average of the accuracies on the test set, with weights corresponding to the group proportions in the training set.
Does Worst-group Acc. represent the accuracy of the (class, environment) pair with the fewest samples in the training set, e.g. (Blond_Hair=1, Male=1) in CelebA and (waterbird_complete95=1, forest2water2=0) in CUB? Or am I getting it wrong?
Here is the log file from CelebA.
Yes exactly, for CUB, the Avg Acc reported is the weighted average of the accuracies on the test set with weights corresponding to proportions in training set.
Worst-group Acc. corresponds to the accuracy of the group (class, spurious feature) with the lowest accuracy. It does not have to be the group with the fewest number of samples in the training set.
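A minimal sketch of that definition, using the per-group test accuracies from the CUB log pasted earlier in the thread and the training group counts quoted above. The helper name `worst_group` is hypothetical and not a function from the repository.

```python
def worst_group(accs):
    """Return (group, accuracy) for the group with the lowest accuracy."""
    group = min(accs, key=accs.get)
    return group, accs[group]

# Keys are (waterbird_complete95, forest2water2), values are test accuracies.
test_acc = {(0, 0): 0.997, (0, 1): 0.594, (1, 0): 0.148, (1, 1): 0.872}

# Training group sizes; used here only for comparison, not in the definition.
train_counts = {(0, 0): 3498, (0, 1): 184, (1, 0): 56, (1, 1): 1057}

group, acc = worst_group(test_acc)
smallest_group = min(train_counts, key=train_counts.get)

print(group, acc)               # (1, 0) 0.148
print(group == smallest_group)  # True for this log, but not guaranteed in general
```

In this log the lowest-accuracy group happens to coincide with the smallest training group, but the selection looks only at accuracy, so a different run could pick a different group.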
For CelebA, that run looks like it was using an lr of 1e-5 and weight decay of 0.1. The best hyperparams for ERM for CelebA (stated in Appendix A) are lr 1e-4 with weight decay 1e-4. Can you try rerunning with those hyperparams? Thanks!
Hi @anniesch, many thanks for such great work!
I have one question about the definition of worst-group: it corresponds to the group (class, spurious feature) with the lowest accuracy. That is to say, it can be (waterbird_complete95 = 0, forest2water2 = 0) for Waterbirds, even though this group has the largest number of training instances. Am I getting this wrong?
Thanks!
Hi @xieyxclack, yes that is right, the worst-group accuracy for a run corresponds to the group with the lowest accuracy, even if that is not a minority group. Hope that helps, let me know if you have any other questions!
@anniesch Got it. Thanks for your help :)
Hi there, I tried to reproduce the baseline results on CUB as reported in the paper: average accuracy 97.3%, worst-group accuracy 72.6%. I ran the following commands:
~$ python3 generate_downstream.py --exp_name CUB_sample_exp_v2 --dataset CUB --n_epochs 300 --lr 1e-5 --weight_decay 1.0 --method ERM
~$ bash results/CUB/CUB_sample_exp/ERM_upweight_0_epochs_300_lr_1e-05_weight_decay_1.0/job.sh
The highest average accuracy on the test set that I get is ~84%. Am I doing something wrong? Or did I perhaps misread the results in the table?
Thank you for your clarification