Adds the ability to perform model selection on causal models using cross validation.
The general idea is to utilize scikit-learn's established framework for cross-validated model selection with metrics/scorers, and extend it to causallib's use.
This PR has three main additions:
1. **Causal metrics and scorers.** Mainly, taking some of the graphical metrics described in Shimoni et al. 2019, which are available in this package's `evaluation` module, and quantifying them so they can be used for automated scoring (for example, the post-weighting ROC AUC should be as close as possible to the chance AUC of 0.5). These are coded as metrics and are also provided with a scikit-learn scorer API (available through the `get_scorers()` function); see the sketch after this list.
2. **Hyperparameter search models.** A causallib version of `GridSearchCV` and `RandomizedSearchCV` that is fully compatible with causallib's API, plus a way to dynamically wrap any other scikit-learn-compatible search model and make it causallib-compatible with the `causalize_searcher()` function (sketched after the full example below).
3. **K-fold objects for cross validation.** Specifically, objects that create stratified folds based on either the treatment alone or on both the treatment and the outcome (applicable when the outcome is categorical); also sketched after the full example below.
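For instance, the scorers can also be used directly, outside a hyperparameter search. The snippet below is a rough sketch only: the `causallib.metrics` location of `get_scorers()`, its model-family argument, and the `scorer(model, X, a, y)` call signature are assumptions based on the description above, not confirmed API.

```python
from sklearn.linear_model import LogisticRegression
from causallib.datasets import load_nhefs
from causallib.estimation import IPW
from causallib.metrics import get_scorers  # assumed module location

data = load_nhefs()
ipw = IPW(LogisticRegression(max_iter=1000)).fit(data.X, data.a)

# Assumption: `get_scorers()` maps names to scorers for a given model family,
# and each scorer extends sklearn's scorer(estimator, X, y) convention with
# the treatment assignment.
scorers = get_scorers("weight")
score = scorers["weighted_roc_auc_error"](ipw, data.X, data.a, data.y)
# The metric quantifies how far the post-weighting ROC AUC is from chance (0.5).
print(score)
```

Because the scorers follow scikit-learn's scorer convention, any searcher that accepts a `scoring` argument can consume them by name, as in the full example below.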
Tying it all together, we can auto-tune a causal model as follows:
```python
from causallib.model_selection import GridSearchCV
from causallib.model_selection import TreatmentStratifiedKFold
from causallib.datasets import load_nhefs
from causallib.estimation import IPW
from sklearn.linear_model import LogisticRegression

data = load_nhefs()
ipw = IPW(LogisticRegression())
cv = TreatmentStratifiedKFold(n_splits=10)

param_grid = dict(
    clip_min=[0.01, 0.1],  # IPW hyperparameters
    learner__C=[0.001, 0.01, 0.1],  # LogisticRegression hyperparameters
)
grid_model = GridSearchCV(
    ipw,
    param_grid=param_grid,
    scoring="weighted_roc_auc_error",  # A scorer for IPW
    cv=cv,
)

grid_model.fit(data.X, data.a, data.y)
potential_outcomes = grid_model.estimate_population_outcome(data.X, data.a, data.y)
```
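Beyond the built-in `GridSearchCV` and `RandomizedSearchCV`, other searchers can be wrapped. Below is a rough sketch, assuming `causalize_searcher()` takes a scikit-learn-compatible searcher class and returns a causallib-compatible one; the choice of `HalvingGridSearchCV` is arbitrary.

```python
from sklearn.experimental import enable_halving_search_cv  # noqa: F401
from sklearn.model_selection import HalvingGridSearchCV
from causallib.model_selection import causalize_searcher

# Assumption: causalize_searcher() accepts a searcher class and returns a
# causallib-compatible class. `ipw`, `param_grid`, and `cv` are reused from
# the example above.
CausalHalvingGridSearchCV = causalize_searcher(HalvingGridSearchCV)
halving_model = CausalHalvingGridSearchCV(
    ipw,
    param_grid=param_grid,
    scoring="weighted_roc_auc_error",
    cv=cv,
)
halving_model.fit(data.X, data.a, data.y)
```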
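Similarly, folds can be stratified jointly on the treatment and a categorical outcome. A schematic sketch follows; the class name `TreatmentOutcomeStratifiedKFold` and its `split(X, a, y)` signature are assumptions based on the description above, and synthetic binary data is used since the NHEFS outcome is continuous.

```python
import numpy as np
import pandas as pd
from causallib.model_selection import TreatmentOutcomeStratifiedKFold  # assumed name

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(200, 3)))
a = pd.Series(rng.binomial(1, 0.5, size=200))  # binary treatment
y = pd.Series(rng.binomial(1, 0.3, size=200))  # categorical outcome

joint_cv = TreatmentOutcomeStratifiedKFold(n_splits=5)
# Assumption: split() takes the treatment and outcome alongside X.
for train_idx, test_idx in joint_cv.split(X, a, y):
    pass  # fit/evaluate per fold, or pass `joint_cv` as `cv=` to a searcher
```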
Thanks to @mmdanziger for reviewing