ing-bank / probatus

Validation (like Recursive Feature Elimination for SHAP) of (multiclass) classifiers & regressors and data used to develop them.
https://ing-bank.github.io/probatus
MIT License
132 stars, 40 forks

Penalty on shap calculation for higher variance #218

Closed markdregan closed 1 year ago

markdregan commented 1 year ago

PR addresses #216

Overall objective: When computing feature importance, add a penalty to features that have high variance in their underlying SHAP values. In theory, this encourages the selection of features whose contributions are more consistent across CV folds.
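A minimal sketch of the idea follows. It assumes the penalty is applied per feature as mean(|SHAP|) minus a penalty factor times std(|SHAP|), with the SHAP values pooled across CV folds. The function name `penalized_shap_importance` and the `penalty_factor` parameter are hypothetical, and the exact formula and parameter names in the PR may differ.

```python
import numpy as np


def penalized_shap_importance(shap_values, penalty_factor=1.0):
    """Hypothetical helper: mean |SHAP| per feature minus a penalty on its spread.

    shap_values: array of shape (n_samples, n_features), assumed to be SHAP
    values pooled from the validation sets of all CV folds.
    penalty_factor: 0.0 recovers the usual mean |SHAP| importance.
    """
    abs_shap = np.abs(np.asarray(shap_values, dtype=float))
    mean_abs = abs_shap.mean(axis=0)  # standard SHAP feature importance
    std_abs = abs_shap.std(axis=0)  # how erratic each feature's contributions are
    return mean_abs - penalty_factor * std_abs


# Feature 0 has large but erratic contributions. Feature 1 is smaller but steady.
shap_values = np.array([
    [2.0, 0.9],
    [0.0, 1.1],
    [2.0, 0.9],
    [0.0, 1.1],
])
# Without the penalty, both features score 1.0.
print(penalized_shap_importance(shap_values, penalty_factor=0.0))
# With penalty_factor=1.0, feature 0 drops to 0.0 and feature 1 keeps about 0.9.
print(penalized_shap_importance(shap_values, penalty_factor=1.0))
```

Both features have the same mean |SHAP|, so the unpenalized ranking cannot separate them. The penalty favours the feature whose contribution is stable.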

API design:

Work tasks:

Reviewers: LMK any changes / improvements that can be made.

ReinierKoops commented 1 year ago

Btw, a lot more pre-commit hooks have been added recently to improve code quality. I see this indirectly affects code you touched. I'd like to know whether you find this bothersome.

markdregan commented 1 year ago

I ran a simulation comparing results with and without the variance penalty. My goal is to stress-test whether this change should be merged.

The code to run the simulation is saved in this gist.
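The gist itself is not reproduced here. The sketch below is hypothetical. It shows how synthetic datasets matching the table columns (n_samples, n_features=100, n_informative) could be generated with scikit-learn's `make_classification`. The sampling ranges are assumptions read off the tables, not taken from the gist, and the `random_dataset` helper is illustrative only.

```python
import numpy as np
from sklearn.datasets import make_classification


def random_dataset(rng):
    """Hypothetical: draw one synthetic binary classification problem."""
    n_samples = int(rng.integers(500, 5000))
    # n_informative >= 2 keeps make_classification valid with its default
    # n_classes=2 and n_clusters_per_class=2.
    n_informative = int(rng.integers(2, 40))
    X, y = make_classification(
        n_samples=n_samples,
        n_features=100,
        n_informative=n_informative,
        random_state=int(rng.integers(0, 2**31 - 1)),
    )
    return X, y, n_informative


rng = np.random.default_rng(0)
for _ in range(3):
    X, y, n_informative = random_dataset(rng)
    print(X.shape, n_informative)
    # The comparison itself is omitted here. It would fit ShapRFECV with and
    # without the variance penalty, then record best score, std and num_features.
```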

ShapRFECV, approximate=True
   best_score_without_penalty  std_without_penalty  num_features_without_penalty  best_score_with_penalty  std_with_penalty  num_features_with_penalty  n_samples  n_features  n_informative
0                       0.883                0.021                            45                    0.885             0.021                         35       1193         100             25  >> better score & same std
1                       0.965                0.005                            25                    0.966             0.006                         20       2319         100              4  >> better score & worse std
2                       0.857                0.036                            40                    0.859             0.033                         40        787         100             29  >> better score & better std
3                       0.942                0.008                            95                    0.945             0.010                         75       3749         100             32  >> better score & worse std
4                       0.902                0.015                            90                    0.903             0.019                         95       1104         100             32  >> better score & worse std
5                       0.948                0.010                           100                    0.949             0.007                         95       3531         100             27  >> better score & better std
6                       0.949                0.004                            95                    0.948             0.003                        100       4176         100             20  >> worse score & better std
7                       0.958                0.004                            85                    0.955             0.003                        100       3985         100             35  >> worse score & better std
8                       0.944                0.008                            95                    0.942             0.007                        100       3962         100             22  >> worse score & better std
9                       0.953                0.004                           100                    0.955             0.005                         95       4522         100             25  >> better score & worse std

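The better/worse/same labels in this table come from comparing the penalized and unpenalized columns row by row, at the three decimals shown. The sketch below tallies them for ShapRFECV with approximate=True, using values copied from the table. The `compare` helper is hypothetical and not part of the gist.

```python
from collections import Counter

# best_score and std columns from the ShapRFECV, approximate=True table.
score_without = [0.883, 0.965, 0.857, 0.942, 0.902, 0.948, 0.949, 0.958, 0.944, 0.953]
score_with = [0.885, 0.966, 0.859, 0.945, 0.903, 0.949, 0.948, 0.955, 0.942, 0.955]
std_without = [0.021, 0.005, 0.036, 0.008, 0.015, 0.010, 0.004, 0.004, 0.008, 0.004]
std_with = [0.021, 0.006, 0.033, 0.010, 0.019, 0.007, 0.003, 0.003, 0.007, 0.005]


def compare(without, with_penalty, higher_is_better):
    """Label each run 'better', 'worse' or 'same' for the penalized variant."""
    labels = []
    for a, b in zip(without, with_penalty):
        if round(b, 3) == round(a, 3):
            labels.append("same")
        elif (b > a) == higher_is_better:
            labels.append("better")
        else:
            labels.append("worse")
    return Counter(labels)


print(compare(score_without, score_with, higher_is_better=True))  # higher score is better
print(compare(std_without, std_with, higher_is_better=False))  # lower std is better
```

The other two tables can be tallied the same way. Because the printed values are rounded to three decimals, "same" only means equal at that precision.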
Some observations:

  • Technique shows a better score 70% of the time and a worse score 30% of the time.
  • Technique shows a better std 50% of the time and a worse std 40% of the time (one tie).

Running EarlyStoppingShapRFECV, I get the results below:

EarlyStoppingShapRFECV, approximate=True
   best_score_without_penalty  std_without_penalty  num_features_without_penalty  best_score_with_penalty  std_with_penalty  num_features_with_penalty  n_samples  n_features  n_informative
0                       0.953                0.006                            90                    0.951             0.007                         95       3469         100             17  >> worse score & worse std
1                       0.918                0.014                            70                    0.919             0.025                         45       1445         100             28  >> better score & worse std
2                       0.922                0.017                            75                    0.918             0.017                        100       1455         100             18  >> worse score & same std
3                       0.934                0.008                            95                    0.933             0.010                         95       3346         100             39  >> worse score & worse std
4                       0.956                0.013                            35                    0.954             0.015                         85       2530         100             12  >> worse score & worse std
5                       0.959                0.013                           100                    0.964             0.010                         95       4187         100             37  >> better score & better std
6                       0.940                0.014                           100                    0.940             0.014                        100       1703         100             35  >> same score & same std
7                       0.898                0.024                            85                    0.906             0.023                         85       1280         100             33  >> better score & better std
8                       0.923                0.017                            20                    0.929             0.011                         15       3403         100              9  >> better score & better std
9                       0.935                0.017                            95                    0.944             0.013                         80       2208         100             21  >> better score & better std

Observations:

  • Technique shows a better score 50% of the time and a worse score 40% of the time (one tie).
  • Technique shows a better std 40% of the time and a worse std 40% of the time (two ties).

ShapRFECV, approximate=False
   best_score_without_penalty  std_without_penalty  num_features_without_penalty  best_score_with_penalty  std_with_penalty  num_features_with_penalty  n_samples  n_features  n_informative
0                       0.906                0.033                            25                    0.905             0.030                         35        685         100             23  >> worse score & better std
1                       0.952                0.010                            40                    0.954             0.011                         35       2449         100             29  >> better score & worse std
2                       0.909                0.008                             5                    0.909             0.005                         10       3316         100              3  >> same score & better std
3                       0.959                0.004                            40                    0.959             0.005                         55       3771         100             32  >> same score & worse std
4                       0.957                0.009                            20                    0.959             0.009                         20       2317         100             13  >> better score & same std
5                       0.875                0.017                            15                    0.874             0.020                         15       1043         100             15  >> worse score & worse std
6                       0.955                0.005                            60                    0.955             0.007                         40       4323         100             30  >> same score & worse std
7                       0.941                0.016                            40                    0.941             0.016                         40       3454         100             32  >> same score & same std
8                       0.943                0.010                            35                    0.942             0.010                         50       4299         100             29  >> worse score & same std
9                       0.961                0.016                             5                    0.961             0.010                         10       1906         100              4  >> same score & better std

Observations:

  • Technique shows a better score 20% of the time, a worse score 30% of the time, and the same score 50% of the time.
  • Technique shows a better std 30% of the time, a worse std 40% of the time, and the same std 30% of the time.
  • With approximate=False, the number of selected features (num_features) is more closely aligned with the true number of informative features (n_informative).

markdregan commented 1 year ago

@ReinierKoops - Honestly, I didn't notice any pre-commit hooks being triggered at all. Personally, I like checks that enforce code consistency, because they help me, as the submitter, align with the repo owners' code style preferences before the PR review happens.

ReinierKoops commented 1 year ago

Great work. Please have a look at the output of the GitHub Actions. In the meantime, I'll try to review it today or tomorrow. I'd suggest having a look at the linked contributing instructions, which show how to contribute (install pre-commit and run it locally together with pytest to speed up the workflow).

markdregan commented 1 year ago

Thanks for the pointers. I worked through all the errors related to my changes. Two remaining errors seem unrelated to my PR:

FAILED tests/feature_elimination/test_feature_elimination.py::test_shap_rfe_randomized_search - ValueError: min_samples_split == 1, must be >= 2.
FAILED tests/feature_elimination/test_feature_elimination.py::test_shap_rfe_randomized_search_cols_to_keep - ValueError: min_samples_split == 1, must be >= 2.

ReinierKoops commented 1 year ago

Would you be willing to turn the gist into a short tutorial? You can place it here

ReinierKoops commented 1 year ago

The code looks great! It would also be nice to have a (small) test that compares the two approaches, or that covers only the newly added one. This can be added here.
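A minimal, hypothetical sketch of such a test follows. It is written against a stand-in helper rather than probatus's real API. The helper name, its signature, and the assumed formula (mean |SHAP| minus factor × std |SHAP|) are all assumptions. The test checks that without the penalty the ranking follows mean |SHAP| alone, and that the penalty demotes a feature whose SHAP values are large but unstable.

```python
import numpy as np


def penalized_importance(shap_values, penalty_factor):
    """Hypothetical stand-in: mean |SHAP| minus penalty_factor * std |SHAP|."""
    abs_shap = np.abs(shap_values)
    return abs_shap.mean(axis=0) - penalty_factor * abs_shap.std(axis=0)


def test_variance_penalty_changes_ranking():
    # Feature 0 has a larger mean |SHAP| (1.5) but is very noisy (std 1.5).
    noisy = np.tile([0.0, 3.0], 500)
    # Feature 1 has a smaller mean |SHAP| (1.0) but is steady (std 0.05).
    steady = np.tile([0.95, 1.05], 500)
    shap_values = np.column_stack([noisy, steady])

    no_penalty = penalized_importance(shap_values, penalty_factor=0.0)
    with_penalty = penalized_importance(shap_values, penalty_factor=1.0)

    # Without the penalty, the noisy feature ranks first.
    assert np.argmax(no_penalty) == 0
    # With the penalty, the stable feature ranks first.
    assert np.argmax(with_penalty) == 1
    # A non-negative penalty can only lower a feature's importance.
    assert np.all(with_penalty <= no_penalty)
```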

ReinierKoops commented 1 year ago

Will review sometime today or tomorrow. Nice work, thanks!

markdregan commented 1 year ago

Made those changes @ReinierKoops.

ReinierKoops commented 1 year ago

Maybe it's running on a different version of scikit-learn? Could you output all the parameters of the algorithm you use in the test, including the implicit (default) ones?
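One way to do this is sketched below with a scikit-learn estimator. RandomForestClassifier is only an example, and the estimator used in the probatus test may differ. `sklearn.__version__` reports the installed version. `get_params()` returns every constructor parameter, including those left at their defaults. Comparing this output across machines can reveal defaults or validation behaviour that changed between releases.

```python
import sklearn
from sklearn.ensemble import RandomForestClassifier

# Installed scikit-learn version. Defaults and parameter validation can differ between releases.
print("scikit-learn", sklearn.__version__)

clf = RandomForestClassifier(n_estimators=50, random_state=0)

# get_params() lists explicit *and* implicit (default) parameters.
params = clf.get_params()
for name in sorted(params):
    print(f"{name} = {params[name]!r}")
```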

markdregan commented 1 year ago

Old version of scikit-learn was the issue. Tests passing my side now.

ReinierKoops commented 1 year ago

Awesome, happy it’s confirmed where the problem lies!

ReinierKoops commented 1 year ago

Thanks again, your PRs are much appreciated!

markdregan commented 1 year ago

Very welcome. Thank you and team.