anniesch / jtt

Code for "Just Train Twice: Improving Group Robustness without Training Group Information"

Worst group on CivilComments-WILDS #4

Closed: rugrag closed this issue 2 years ago

rugrag commented 3 years ago

Hi there,

It is not 100% clear to me what you mean by "worst group" in CivilComments-WILDS: in this case there are 2 classes and 8 environments, hence 16 (overlapping) groups. In the code, however, the group label seems to be just a copy of the class label. In Section 5.1 you state: "our group DRO minimizes worst-group loss over 4 groups (y, a), where the spurious attribute a is a binary indicator of whether any demographic identity is mentioned and the label y is toxic or non-toxic". Could you please clarify?

Thanks

anniesch commented 3 years ago

Hi @rugrag, I think some of the confusion may come from the fact that we set the --confounder_name flag inside generate_downstream.py rather than in the sample commands in the README. The group label is not just a copy of the class label: when the confounder name is set to identity_any in generate_downstream.py, we get 4 groups in total (identity_any x toxicity label), and these are the groups that group DRO trains on.

Separately, for evaluation we report accuracy over 16 overlapping groups (identity x toxicity label), following the WILDS paper; these are different from the groups that group DRO uses during training. We can't train group DRO directly on the 16 overlapping groups, since it isn't built to handle overlapping groups, so we train on the 4 non-overlapping groups, similar to what the WILDS paper does. Hopefully this helps clarify things!
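For concreteness, here is a minimal sketch of how a 4-group label of this kind can be formed from the binary toxicity label and the identity_any attribute. The column names and the commented-out file name are assumptions for illustration, not the repo's actual code:

```python
import numpy as np
import pandas as pd

def four_group_labels(df: pd.DataFrame) -> pd.Series:
    """Map (toxicity, identity_any) pairs to 4 non-overlapping group ids.

    Assumes df has columns 'toxicity' (label y) and 'identity_any'
    (spurious attribute a); both names are illustrative.
    """
    y = (df["toxicity"] >= 0.5).astype(int)       # label: toxic vs. non-toxic
    a = (df["identity_any"] >= 0.5).astype(int)   # attribute: any identity mentioned
    return 2 * y + a                              # group ids 0..3, one per (y, a) pair

# Hypothetical usage on a metadata frame:
# metadata = pd.read_csv("all_data_with_identities.csv")  # file name is an assumption
# print(np.bincount(four_group_labels(metadata), minlength=4))
```

Group DRO then minimizes the worst training loss over these 4 (y, a) groups, while the 16 overlapping identity x label groups are only used at evaluation time.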

rugrag commented 3 years ago

Thanks for your clarification on group DRO :+1: I still have doubts about what JTT is evaluated on (Avg Acc. and Worst-Group Acc.). Running your scripts, I get the following decomposition for CivilComments:

Training Data...
    toxicity = 0: n = 238523
    toxicity = 1: n = 30515
Validation Data...
    toxicity = 0: n = 40125
    toxicity = 1: n = 5055
Test Data...
    toxicity = 0: n = 118558
    toxicity = 1: n = 15224

which seems to suggest that the worst group is simply class 1 (regardless of the attribute a): is this correct? For Waterbirds, by contrast, there are 4 groups and the worst group is (waterbird_complete95 = 1, forest2water2 = 0):

Training Data...
    waterbird_complete95 = 0, forest2water2 = 0: n = 3498
    waterbird_complete95 = 0, forest2water2 = 1: n = 184
    waterbird_complete95 = 1, forest2water2 = 0: n = 56
    waterbird_complete95 = 1, forest2water2 = 1: n = 1057
Validation Data...
    waterbird_complete95 = 0, forest2water2 = 0: n = 467
    waterbird_complete95 = 0, forest2water2 = 1: n = 466
    waterbird_complete95 = 1, forest2water2 = 0: n = 133
    waterbird_complete95 = 1, forest2water2 = 1: n = 133
Test Data...
    waterbird_complete95 = 0, forest2water2 = 0: n = 2255
    waterbird_complete95 = 0, forest2water2 = 1: n = 2255
    waterbird_complete95 = 1, forest2water2 = 0: n = 642
    waterbird_complete95 = 1, forest2water2 = 1: n = 642

Could you please clarify this?

anniesch commented 2 years ago

Hi @rugrag, sorry for the delay in responding! We actually used a separate script to calculate the worst-group accuracy for CivilComments, which works directly from the per-epoch CSV files. The civil_comments_analysis.py file should do that analysis based on those files, but please let me know if it does not work or does not output the right thing!
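As a rough illustration of that evaluation (a sketch, not the actual civil_comments_analysis.py logic), worst-group accuracy over the 16 overlapping identity x label groups can be computed from per-example predictions roughly as follows; the column names, identity list, and file name are assumptions:

```python
import pandas as pd

# Eight demographic identity attributes used in the WILDS CivilComments evaluation
# (treat the exact column names as an assumption).
IDENTITIES = ["male", "female", "LGBTQ", "christian", "muslim",
              "other_religions", "black", "white"]

def worst_group_accuracy(df: pd.DataFrame):
    """Worst accuracy over the 16 overlapping (identity, label) groups.

    Assumes one row per test example with binary columns 'y_true', 'y_pred',
    and one 0/1 column per identity attribute.
    """
    accs = {}
    for identity in IDENTITIES:
        for label in (0, 1):
            group = df[(df[identity] == 1) & (df["y_true"] == label)]
            if len(group):
                accs[(identity, label)] = float((group["y_pred"] == group["y_true"]).mean())
    return min(accs.values()), accs

# Hypothetical usage on a per-epoch predictions CSV:
# preds = pd.read_csv("output_epoch_10.csv")
# wga, per_group = worst_group_accuracy(preds)
```

Under this evaluation the worst group is whichever (identity, label) pair has the lowest accuracy, which need not coincide with the smallest group in the class-only decomposition above.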