autogluon / autogluon-fair

Apache License 2.0
1 star · 4 forks · source link

Add as_oof init param #18

Open Innixma opened 1 year ago

Innixma commented 1 year ago

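Example script comparing a `FairPredictor` created with the proposed `as_oof=True` init param against one created with the default settings: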


# Load and train a baseline classifier

from autogluon.tabular import TabularDataset, TabularPredictor
from autogluon.fair import FairPredictor 
from autogluon.fair.utils import group_metrics as gm
# Train/test splits hosted on the AutoGluon S3 bucket ("Inc" is presumably the
# Adult Income dataset — confirm). 'class' is the label column below; 'sex' is
# the protected attribute passed to FairPredictor later. Note: test_data is not
# used anywhere in this example.
train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')
test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')
# Baseline classifier trained with default TabularPredictor settings; it knows
# nothing about fairness. FairPredictor wraps it afterwards.
predictor = TabularPredictor(label='class').fit(train_data=train_data)
# Per-model validation scores (table below). In a notebook/REPL the returned
# table is displayed; in a plain script it is discarded unless printed.
predictor.leaderboard()
                  model  score_val  pred_time_val   fit_time  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0   WeightedEnsemble_L2     0.8912       0.099142   8.798374                0.004605           1.105777            2       True         14
1               XGBoost     0.8872       0.022041   1.895144                0.022041           1.895144            1       True         11
2         LightGBMLarge     0.8856       0.022847   0.951653                0.022847           0.951653            1       True         13
3              CatBoost     0.8824       0.013061   4.616093                0.013061           4.616093            1       True          7
4              LightGBM     0.8824       0.016489   0.460558                0.016489           0.460558            1       True          4
5            LightGBMXT     0.8792       0.021510   1.237471                0.021510           1.237471            1       True          3
6      RandomForestEntr     0.8624       0.076648   2.598845                0.076648           2.598845            1       True          6
7       NeuralNetFastAI     0.8616       0.041444  31.895101                0.041444          31.895101            1       True         10
8      RandomForestGini     0.8600       0.114874   1.845957                0.114874           1.845957            1       True          5
9        NeuralNetTorch     0.8588       0.021620  34.150226                0.021620          34.150226            1       True         12
10       ExtraTreesGini     0.8500       0.083104   1.116725                0.083104           1.116725            1       True          8
11       ExtraTreesEntr     0.8456       0.082115   1.132972                0.082115           1.132972            1       True          9
12       KNeighborsUnif     0.7752       0.011650   3.147771                0.011650           3.147771            1       True          1
13       KNeighborsDist     0.7660       0.013648   0.018692                0.013648           0.018692            1       True          2
# not overfit
fpredictor_new = FairPredictor(predictor, train_data, 'sex', as_oof=True)
fpredictor_new.fit(gm.accuracy, gm.demographic_parity, 0.02)
fpredictor_new.plot_frontier()

myplot1

# overfit
fpredictor_old = FairPredictor(predictor, train_data, 'sex')
fpredictor_old.fit(gm.accuracy, gm.demographic_parity, 0.02)
fpredictor_old.plot_frontier()

myplot2