py-why / causaltune

AutoML for causal inference.
Apache License 2.0

FLAML-compatible econML estimator classes & metrics #3

Closed: TimoFlesch closed this issue 2 years ago

TimoFlesch commented 2 years ago

FLAML lets the user specify which metric to use and what kind of estimators to optimise:

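```python
from flaml import AutoML

automl = AutoML()
settings = {
    "time_budget": 10,  # total running time in seconds
    "metric": "accuracy",
    # 'RGF' is a custom learner and must be registered via automl.add_learner first
    "estimator_list": ["RGF", "lgbm", "rf", "xgboost"],  # list of ML learners
    "task": "classification",  # task type
}
automl.fit(X_train=X_train, y_train=y_train, **settings)
```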

To run FLAML's AutoML on econML models (i.e. to select among DML, metalearners, etc.), we need to supply it with a custom metric and a list of custom estimators.
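A minimal sketch of what such a custom metric could look like, assuming a binary treatment, a known propensity score, and a policy that treats a unit whenever its estimated effect is positive. `erupt` and `erupt_metric` are hypothetical names. The metric follows the argument order FLAML documents for callable metrics (validation data, estimator, labels, training data, optional weights) and returns the loss to minimise plus a dict of values to log. Packing the treatment into the first column of `X_val`, fixing the propensity at 0.5, and calling an EconML-style `effect(X)` on the estimator are illustrative assumptions, not how FLAML or EconML actually pass treatments.

```python
import numpy as np


def erupt(y, treatment, recommended, propensity):
    """Self-normalised ERUPT for a binary treatment.

    Averages the outcomes of units whose observed treatment matches the
    policy's recommendation, weighted by the inverse probability of the
    treatment they actually received.
    """
    y = np.asarray(y, dtype=float)
    treatment = np.asarray(treatment)
    recommended = np.asarray(recommended)
    propensity = np.asarray(propensity, dtype=float)  # P(T=1 | X)
    # probability of the treatment each unit actually received
    p_observed = np.where(treatment == 1, propensity, 1.0 - propensity)
    weights = (treatment == recommended) / p_observed
    if weights.sum() == 0:
        return float("nan")  # no unit followed the recommended policy
    return float(np.sum(weights * y) / np.sum(weights))


def erupt_metric(X_val, y_val, estimator, labels, X_train, y_train,
                 weight_val=None, weight_train=None, *args, **kwargs):
    """FLAML-style custom metric returning (loss to minimise, metrics to log).

    Hypothetical conventions: column 0 of X_val is the binary treatment,
    treatment was randomised with P(T=1) = 0.5, and `estimator` exposes an
    EconML-style `effect(X)` method for the remaining feature columns.
    """
    X_val = np.asarray(X_val, dtype=float)
    treatment, features = X_val[:, 0].astype(int), X_val[:, 1:]
    # recommend treatment wherever the estimated effect is positive
    recommended = (np.asarray(estimator.effect(features)) > 0).astype(int)
    score = erupt(y_val, treatment, recommended, np.full(len(y_val), 0.5))
    return -score, {"erupt": score}  # FLAML minimises, so negate ERUPT
```

The self-normalised form (dividing by the sum of weights rather than by the number of units) trades a little bias for lower variance when few units happen to match the recommended treatment.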

Describe the solution you'd like

The ERUPT metric and the econML estimators, each implemented in a format compatible with FLAML's AutoML. Usage would look something like this:

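```python
from flaml import AutoML
from econml.dml import LinearDML
from econml.metalearners import TLearner

automl = AutoML()
# Proposed usage: the EconML classes would first need FLAML-compatible wrappers
automl.add_learner(learner_name="lindml", learner_class=LinearDML)
automl.add_learner(learner_name="tlearner", learner_class=TLearner)
settings = {
    "time_budget": 10,  # total running time in seconds
    "metric": "ERUPT",  # proposed custom metric, not built into FLAML
    "estimator_list": ["lindml", "tlearner"],  # must match the names given to add_learner
    "task": "causalinference",  # proposed task type, not an existing FLAML task
}
automl.fit(X_train=X_train, y_train=y_train, **settings)
```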

Note: the `add_learner` method is already part of FLAML's `AutoML` class. Ideally, we would later write a wrapper that instantiates `AutoML` and registers all of these learners under the hood.
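A sketch of such a wrapper, with `make_causal_automl` as a hypothetical name. It only automates the `add_learner` calls; the classes passed in would still have to be FLAML-compatible estimator wrappers around the EconML models. The `automl_cls` parameter is an illustrative addition so the helper can be exercised without FLAML installed.

```python
def make_causal_automl(learners, automl_cls=None):
    """Instantiate an AutoML object and register every causal learner on it.

    `learners` maps the names later used in `estimator_list` to learner
    classes. By default flaml.AutoML is used; any class exposing the same
    add_learner(learner_name=..., learner_class=...) method can be swapped in.
    """
    if automl_cls is None:
        # imported lazily so the helper itself does not require FLAML
        from flaml import AutoML
        automl_cls = AutoML
    automl = automl_cls()
    for name, learner_class in learners.items():
        automl.add_learner(learner_name=name, learner_class=learner_class)
    return automl
```

The dictionary keys can then double as the search space, e.g. `automl.fit(..., estimator_list=list(learners))`, so the registered names and the estimator list cannot drift apart.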

TimoFlesch commented 2 years ago

I'll start with the base classes and report back once they are implemented :)

TimoFlesch commented 2 years ago

Quick update: we decided to work with the low-level API (the `.tune` method) first, as the AutoML class mentioned above requires scikit-learn-style estimators along with several other fairly major changes to the methods provided by dowhy/econml. See https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function
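A minimal sketch of the `.tune` route under illustrative assumptions: synthetic randomised data with true effect tau(x) = x, and two toy NumPy learners (`s_learner`, `t_learner`) standing in for the dowhy/econml estimators. `evaluate_config` follows the contract of `flaml.tune.run` for user-defined functions (take a config dict, return a dict of metrics); the data, learner names, and search space are hypothetical. Because P(T=1) = 0.5 everywhere, the inverse-propensity weights are constant and ERUPT reduces to the mean outcome of units whose treatment matches the recommendation.

```python
import numpy as np

rng = np.random.default_rng(0)


def simulate(n):
    """Hypothetical randomised experiment with true effect tau(x) = x."""
    x = rng.uniform(-1.0, 1.0, size=n)
    t = rng.integers(0, 2, size=n)  # P(T=1) = 0.5
    y = x + t * x + rng.normal(0.0, 0.1, size=n)
    return x, t, y


X_TR, T_TR, Y_TR = simulate(2000)  # training split
X_VA, T_VA, Y_VA = simulate(4000)  # validation split for ERUPT


def fit_linear(x, y):
    """Least-squares fit of y = a + b * x; returns the array [a, b]."""
    design = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef


def fit_cate(name, x, t, y):
    """Return a function mapping features to estimated treatment effects."""
    if name == "s_learner":
        # single regression on [1, x, t]: implies a constant effect
        design = np.column_stack([np.ones_like(x), x, t])
        coef, *_ = np.linalg.lstsq(design, y, rcond=None)
        return lambda x_new: np.full(len(x_new), coef[2])
    if name == "t_learner":
        # one regression per arm; the effect is the difference of predictions
        a1, b1 = fit_linear(x[t == 1], y[t == 1])
        a0, b0 = fit_linear(x[t == 0], y[t == 0])
        return lambda x_new: (a1 - a0) + (b1 - b0) * x_new
    raise ValueError(f"unknown estimator {name!r}")


def evaluate_config(config):
    """User-defined evaluation function: config dict -> dict of metrics."""
    effect = fit_cate(config["estimator"], X_TR, T_TR, Y_TR)
    recommended = (effect(X_VA) > 0).astype(int)
    matched = T_VA == recommended
    # constant propensity 0.5: self-normalised ERUPT = mean outcome of matched units
    return {"erupt": float(Y_VA[matched].mean())}


def run_search(time_budget_s=10):
    """Let flaml.tune choose the learner with the highest ERUPT (needs FLAML)."""
    from flaml import tune

    analysis = tune.run(
        evaluate_config,
        config={"estimator": tune.choice(["s_learner", "t_learner"])},
        metric="erupt",
        mode="max",
        num_samples=4,
        time_budget_s=time_budget_s,
    )
    return analysis.best_config
```

With FLAML installed, `run_search()` should settle on `t_learner`, since only the per-arm regressions can recover an effect that changes sign with x; swapping the toy learners for real estimator wrappers would leave `evaluate_config`'s contract unchanged.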

TimoFlesch commented 2 years ago

Working on this in this branch. Outstanding issues:

Should be done by next week