BiomedSciAI / causallib

A Python package for modular causal inference analysis and model evaluations
Apache License 2.0

Causal model selection #45

Closed by ehudkr 2 years ago

ehudkr commented 2 years ago

Adds the ability to perform model selection on causal models using cross-validation. The general idea is to reuse scikit-learn's well-established model-selection framework (cross-validation driven by metrics and scorers) and extend it to causallib's models.

This PR has three main additions:

  1. Causal metrics and scorers. Mainly, this takes some of the graphical evaluation metrics described in Shimoni et al. 2019, which are available in this package's evaluation module, and quantifies them so they can be used for automated scoring (for example, the post-weighting ROC AUC should be as close as possible to the chance-level AUC of 0.5). These are implemented as metrics and are also exposed through scikit-learn's scorer API (available via the get_scorers() function).

  2. Hyperparameter search models. causallib versions of GridSearchCV and RandomizedSearchCV that are fully compatible with causallib's API, as well as a causalize_searcher() function that dynamically wraps any other scikit-learn-compatible search model and makes it causallib-compatible.

  3. K-fold objects for cross-validation. Specifically, objects that create stratified folds based either on the treatment alone or on both the treatment and the outcome (the latter applies when the outcome is categorical).
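
To make the metric in item 1 concrete, here is a minimal NumPy sketch of the idea behind a post-weighting ROC AUC error. It is not causallib's implementation: weighted_auc and weighted_roc_auc_error_sketch are hypothetical names, and the sketch assumes propensity scores have already been estimated and lie strictly between 0 and 1. After inverse-probability weighting, the propensity scores should no longer separate treated from untreated units, so the weighted AUC should sit near 0.5 and the error near 0.

```python
import numpy as np


def weighted_auc(scores, labels, weights):
    """Weighted ROC AUC via the pairwise (Mann-Whitney) definition.

    Each (treated, control) pair contributes w_i * w_j when the treated unit
    scores higher, half of that on a tie, and nothing otherwise.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels).astype(bool)
    weights = np.asarray(weights, dtype=float)
    s_pos, w_pos = scores[labels], weights[labels]
    s_neg, w_neg = scores[~labels], weights[~labels]
    # Rows are treated units, columns are controls.
    diff = s_pos[:, None] - s_neg[None, :]
    wins = (diff > 0) + 0.5 * (diff == 0)
    pair_w = w_pos[:, None] * w_neg[None, :]
    return float((wins * pair_w).sum() / pair_w.sum())


def weighted_roc_auc_error_sketch(propensity, treatment):
    """Hypothetical helper: |AUC - 0.5| of the propensity scores after IPW."""
    propensity = np.asarray(propensity, dtype=float)
    treatment = np.asarray(treatment).astype(int)
    # IPW weights: 1/e(x) for treated units, 1/(1 - e(x)) for controls.
    weights = np.where(treatment == 1, 1.0 / propensity, 1.0 / (1.0 - propensity))
    return abs(weighted_auc(propensity, treatment, weights) - 0.5)


# Toy data: a calibrated propensity model with two score levels.
propensity = [0.2] * 5 + [0.8] * 5
treatment = [1, 0, 0, 0, 0] + [1, 1, 1, 1, 0]
print(weighted_auc(propensity, treatment, np.ones(10)))  # 0.8 before weighting
print(weighted_roc_auc_error_sketch(propensity, treatment))  # close to 0 after weighting
```

The unweighted AUC of 0.8 shows the groups differ in their covariates; the weighted error near 0 is the kind of signal a scorer can turn into a number to optimize.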
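
The treatment-and-outcome stratification in item 3 amounts to stratifying on the joint (treatment, outcome) label. Below is a minimal NumPy sketch of that idea, assuming a categorical outcome; joint_stratified_folds and iter_splits are hypothetical helpers, not the API of TreatmentStratifiedKFold or any other causallib splitter.

```python
import numpy as np


def joint_stratified_folds(a, y, n_splits=5, seed=0):
    """Hypothetical helper: assign a fold id to each sample, stratified on (a, y).

    Within every (treatment, outcome) stratum, samples are shuffled and dealt
    round-robin across folds, so each fold keeps a similar mix of the strata.
    """
    a = np.asarray(a)
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    fold_of = np.empty(len(a), dtype=int)
    offset = 0
    # Enumerate every observed (treatment, outcome) combination.
    for a_val, y_val in sorted(set(zip(a.tolist(), y.tolist()))):
        idx = np.flatnonzero((a == a_val) & (y == y_val))
        rng.shuffle(idx)
        # Continue the round-robin where the previous stratum stopped,
        # which keeps the fold sizes balanced overall.
        fold_of[idx] = (np.arange(len(idx)) + offset) % n_splits
        offset += len(idx)
    return fold_of


def iter_splits(fold_of, n_splits):
    """Yield (train_idx, test_idx) pairs, in the shape scikit-learn splitters yield."""
    for k in range(n_splits):
        yield np.flatnonzero(fold_of != k), np.flatnonzero(fold_of == k)


a = np.array([0] * 10 + [1] * 10)
y = np.array([0, 1] * 10)
folds = joint_stratified_folds(a, y, n_splits=5)
for train_idx, test_idx in iter_splits(folds, 5):
    # Each test fold holds one sample of each (a, y) combination.
    print(len(train_idx), len(test_idx), sorted(zip(a[test_idx].tolist(), y[test_idx].tolist())))
```

With scikit-learn available, a similar effect can be had by passing such a joint label as the y argument of sklearn.model_selection.StratifiedKFold.split.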

Tying it all together, we can auto-tune causal models as follows:

from causallib.model_selection import GridSearchCV
from causallib.model_selection import TreatmentStratifiedKFold
from causallib.datasets import load_nhefs
from causallib.estimation import IPW
from sklearn.linear_model import LogisticRegression

data = load_nhefs()
ipw = IPW(LogisticRegression())
cv = TreatmentStratifiedKFold(n_splits=10)
param_grid = dict(
    clip_min=[0.01, 0.1],  # IPW hyperparameters
    learner__C=[0.001, 0.01, 0.1],  # LogisticRegression hyperparameters
)

grid_model = GridSearchCV(
    ipw,
    param_grid=param_grid,
    scoring="weighted_roc_auc_error",  # A scorer for IPW
    cv=cv,
)
grid_model.fit(data.X, data.a, data.y)
potential_outcomes = grid_model.estimate_population_outcome(data.X, data.a, data.y)
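
For a sense of how much work the example performs, the plain-Python sketch below mimics two scikit-learn conventions the param_grid relies on: the grid is expanded into the Cartesian product of its values, and a double-underscore name such as learner__C is routed to the C parameter of the IPW model's learner. expand_grid and route are hypothetical helpers for illustration only, not causallib or scikit-learn functions.

```python
from itertools import product

# Same grid as in the example above.
param_grid = dict(
    clip_min=[0.01, 0.1],
    learner__C=[0.001, 0.01, 0.1],
)


def expand_grid(grid):
    """Cartesian product of a single-dict grid, with keys sorted as ParameterGrid does."""
    keys = sorted(grid)
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


def route(param_name):
    """Split 'owner__param' on the first '__'; top-level params get owner None."""
    owner, sep, name = param_name.partition("__")
    return (owner, name) if sep else (None, owner)


candidates = expand_grid(param_grid)
n_splits = 10  # matches TreatmentStratifiedKFold(n_splits=10) above
print(len(candidates))             # 6 candidate settings
print(len(candidates) * n_splits)  # 60 cross-validation fits
print(route("learner__C"))         # ('learner', 'C')
print(route("clip_min"))           # (None, 'clip_min')
```

Note that scikit-learn's GridSearchCV, by default (refit=True), also refits the best candidate once on the full data after the cross-validation fits.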

Thanks to @mmdanziger for reviewing