neurodata / treeple

Scikit-learn compatible decision trees beyond those offered in scikit-learn
https://treeple.ai

Add meta-estimator for causal trees #42

Open adam2392 opened 1 year ago

adam2392 commented 1 year ago

Is your feature request related to a problem? Please describe.

Given any tree model, we can fit a causal tree by expanding the fit API to allow a treatment group to be passed in.

Describe the solution you'd like

Something similar in functionality to https://github.com/microsoft/EconML/blob/main/econml/dml/causal_forest.py, but with a much simpler implementation.


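from sklearn.base import BaseEstimator, clone

class CausalTree(BaseEstimator):
    def __init__(self, tree_model):  # further parameters left unspecified
        # tree model instantiated by the caller
        self.tree_model = tree_model

    def fit(self, X, y, T):
        # clone before each fit so that one fit does not overwrite another
        self.model_t_ = clone(self.tree_model).fit(X, T)
        self.model_ty_ = clone(self.tree_model).fit(T, y)  # T must be 2D here
        self.model_y_ = clone(self.tree_model).fit(X, y)
        # TODO: combine them to get the estimates for P(y | do(X))
        return self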

Additional context

Once this works, we could open a PR to econml. I think we should be careful and make sure all necessary functionality of causal trees is supported.

Note that instantiating a CausalTree will be very similar to instantiating an sklearn Pipeline: the tree_model should be instantiated outside of the CausalTree and passed in. This just makes for a simpler API.

https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html
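To make the Pipeline analogy concrete, here is a minimal sketch of the constructor pattern only. `CausalTreeSketch` is a hypothetical stand-in (not an existing treeple or econml class); the point is that inheriting from `BaseEstimator` already gives nested parameter access and cloning when the tree is passed in pre-instantiated.

```python
from sklearn.base import BaseEstimator, clone
from sklearn.tree import DecisionTreeRegressor


class CausalTreeSketch(BaseEstimator):
    """Hypothetical meta-estimator; only the constructor pattern is shown."""

    def __init__(self, tree_model):
        self.tree_model = tree_model


# The tree is instantiated by the caller, as with Pipeline steps.
tree = DecisionTreeRegressor(max_depth=3, random_state=0)
causal_tree = CausalTreeSketch(tree_model=tree)

# Nested parameters are exposed with the usual double-underscore syntax.
params = causal_tree.get_params(deep=True)
print(params["tree_model__max_depth"])

# They can also be updated through the meta-estimator (e.g. by GridSearchCV).
causal_tree.set_params(tree_model__max_depth=5)
print(causal_tree.tree_model.max_depth)

# clone() copies the wrapped tree too, so fits never share state.
cloned = clone(causal_tree)
```

With this pattern the meta-estimator does not need to mirror every tree hyperparameter in its own signature.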

adam2392 commented 1 year ago

Ideally, this causal tree class supports all the functionality that econml currently provides: https://github.com/microsoft/EconML/tree/main/econml

Tbh, I'm not entirely sure about the differences between the DML tree, GRF, and causal tree implementations in econml.

adam2392 commented 1 year ago

IIUC, GRF are just regular trees with:

I don't think implementing this is a high priority, since the DML trees (fit one model for the propensity, then one model for the outcome) seem like the better bet anyway, right?

WDYT @sampan501
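For reference, the "fit a propensity model, then an outcome model" idea can be sketched with plain sklearn trees. This is only an illustration under assumptions: the data are simulated, the treatment is continuous with a constant true effect of 2.0, and the final stage is a simple residual-on-residual regression rather than a tree. It is not econml's implementation.

```python
import numpy as np
from sklearn.model_selection import cross_val_predict
from sklearn.tree import DecisionTreeRegressor

# Simulated data (assumption): treatment depends on X, true effect is 2.0.
rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 3))
T = 0.5 * X[:, 0] + rng.normal(size=n)
y = 2.0 * T + X[:, 1] + rng.normal(scale=0.5, size=n)

nuisance = DecisionTreeRegressor(min_samples_leaf=50, random_state=0)

# Stage 1: cross-fitted predictions of E[T | X] and E[y | X].
T_hat = cross_val_predict(nuisance, X, T, cv=5)
y_hat = cross_val_predict(nuisance, X, y, cv=5)

# Stage 2: regress the outcome residual on the treatment residual.
T_res = T - T_hat
y_res = y - y_hat
ate_hat = np.sum(T_res * y_res) / np.sum(T_res ** 2)

print(f"estimated average effect: {ate_hat:.2f}")
```

A meta-estimator version would accept the nuisance tree as a constructor argument and swap the final step for a tree fit on the residuals to obtain effects that vary with X.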

sampan501 commented 1 year ago

Agreed, seems low priority to me

adam2392 commented 1 year ago

Do you know if the other types of trees need to use more exotic criteria? E.g. in econml, the 'het' and 'mse' criteria are moment-equation solvers. I guess for continuity we should ideally simplify their implementation and have it ourselves as well, but can the other normal sklearn criteria be used?

I don't see why not, but maybe I missed something.

sampan501 commented 1 year ago

Not that I know of, but we can add a parameter when building the object

adam2392 commented 1 year ago

Okay, that's good to know. The parameter would just be the normal criterion keyword argument, which can be 'gini', 'poisson', etc.
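As a quick illustration of what "the normal criterion keyword argument" looks like today, the sketch below passes `criterion` to the stock sklearn trees on simulated data (the data are an assumption for the example). A causal meta-estimator could rely on the same keyword of whichever tree it wraps.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y_class = (X[:, 0] > 0).astype(int)               # binary labels
y_count = rng.poisson(lam=np.exp(0.5 * X[:, 1]))  # non-negative counts

# 'gini' is a classification criterion, 'poisson' a regression criterion.
clf = DecisionTreeClassifier(criterion="gini", max_depth=2, random_state=0)
reg = DecisionTreeRegressor(criterion="poisson", max_depth=2, random_state=0)

clf.fit(X, y_class)
reg.fit(X, y_count)

print(clf.criterion, reg.criterion)
```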

adam2392 commented 1 year ago

Re the 'het' and 'mse' criteria in econml (note this mse is not the same as the MSE currently in sklearn, so we should probably call it something else...), we'll have to replicate the functionality here: https://github.com/py-why/EconML/blob/main/econml/grf/_criterion.pyx. Doesn't look too bad.

adam2392 commented 1 year ago

Some API issues to figure out.

In sklearn, we have:

Those are probably the main API we want to "override". In causal land, we want something like:

Econml also exposes a class API for getting confidence intervals. I think this is not necessary and overcomplicates the classes. We should just provide a function to get confidence intervals for the predicted causal effects.

Based on the answers above for causal, we might want to expose API for getting specific types of effects like they do in econml, such as CATE, ATE, marginal_CATE, etc.

Some possible solutions

Just have to verify that at least this works with sklearn's testing function: parametrize_with_checks. If so, then at least we can be assured that the trees are pretty compatible with the rest of the sklearn codebase.
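A minimal sketch of that check, following the pattern from the sklearn developer docs. `DecisionTreeRegressor` is used here only as a placeholder for whichever causal tree class ends up being written, and pytest must be installed since `parametrize_with_checks` builds a pytest parametrization.

```python
from sklearn.tree import DecisionTreeRegressor
from sklearn.utils.estimator_checks import parametrize_with_checks


# Placeholder estimator: swap in the causal tree class once it exists.
@parametrize_with_checks([DecisionTreeRegressor()])
def test_sklearn_compatible_estimator(estimator, check):
    # Each (estimator, check) pair becomes its own pytest test case.
    check(estimator)
```

Running pytest on a file containing this function reports each sklearn API check separately, which makes it easy to see which conventions a new estimator breaks. Note that a fit signature with an extra treatment argument may fail some checks, so this is where any API incompatibility would show up first.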

Questions to Resolve

Open questions are:

  1. How does econml use W, the "controls"? It is honestly not super clear to me what the difference between W and X is, or how it impacts fitting/predicting.
  2. What does econml currently use predict for? What does it use score for?

adam2392 commented 1 year ago

Along the lines of causal trees, adding some notes to be aware of:

I've finished perusing and understanding the code in EconML about GRF and the general GRF paper. At an implementation level, this means we have two distinctly "different" kinds of trees that would approach the problems of:

The first is using GRF, which diverges from the sklearn semantics because this is explicitly a "gradient-based" tree. We are solving local gradients at every split node using the GRFCriterion.

The second is using the DML/DR multiple-fitting approaches, with an honest regression tree as the basic ingredient.
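For readers unfamiliar with the term, "honest" here means that the samples used to choose the splits are not the samples used to estimate the leaf values. A rough sketch with a stock sklearn tree follows; the data, the 50/50 split, and the helper name `honest_predict` are assumptions for illustration, not treeple's or econml's implementation.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Simulated data (assumption): a step function in the first feature.
rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(1000, 2))
y = np.where(X[:, 0] > 0, 1.0, -1.0) + rng.normal(scale=0.3, size=1000)

# One half learns the tree structure, the other half estimates leaf values.
X_struct, X_est, y_struct, y_est = train_test_split(
    X, y, test_size=0.5, random_state=0
)

tree = DecisionTreeRegressor(max_depth=3, min_samples_leaf=20, random_state=0)
tree.fit(X_struct, y_struct)

# Re-estimate each leaf's value from the held-out estimation half.
leaf_ids = tree.apply(X_est)
leaf_means = {leaf: y_est[leaf_ids == leaf].mean() for leaf in np.unique(leaf_ids)}


def honest_predict(X_new):
    """Predict with held-out leaf means (hypothetical helper)."""
    leaves = tree.apply(X_new)
    fallback = tree.predict(X_new)  # used if a leaf got no estimation samples
    return np.array([leaf_means.get(l, f) for l, f in zip(leaves, fallback)])


pred = honest_predict(np.array([[0.5, 0.0], [-0.5, 0.0]]))
print(pred)
```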

The GRF implementation is quite a bit more complex, and if it were up to us I would be inclined to leave it out, but to achieve feature parity with econml we should have a refactored version. The DML/DR approaches are considerably more straightforward, as they can be implemented as meta-estimators, as we have discussed in this issue and #52.

adam2392 commented 1 year ago

Besides the GRF paper, https://arxiv.org/pdf/1510.04342.pdf is a good paper.