Open adam2392 opened 1 year ago
Ideally this causal tree class allows all the functionality that econml currently supports:
https://github.com/microsoft/EconML/tree/main/econml. Tbh I'm not entirely sure on the differences between dml tree, grf and the causal tree implementation in econml.
IIUC, GRF are just regular trees with:
I don't think we need to treat this as a high priority, since the DML trees (fit a model for the propensity, then fit a model for the outcome) seem like the better bet anyway, right?
WDYT @sampan501
Agreed seems low priority to me
Do you know if the other types of trees need to use more exotic criteria? E.g. in econml, the 'het' and 'mse' criteria are these moment equation solvers. I guess for continuity, we should ideally simplify their implementation and have it ourselves as well, but can the other normal sklearn criteria be used?
I don't see why not? but maybe I missed something.
Not that I know of, but we can add a parameter when building the object
Okay that's good to know. The parameter would just be the normal `criterion` keyword argument, which can be 'gini', 'poisson', etc.
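For reference, the stock sklearn trees already take their split criterion by name, so a wrapper could simply forward that string. A minimal sketch using only existing sklearn estimators (the toy data is made up for illustration):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y_class = (X[:, 0] > 0).astype(int)         # binary labels, suitable for 'gini'
y_count = rng.poisson(lam=np.exp(X[:, 1]))  # non-negative counts, required by 'poisson'

# The criterion is just a string keyword on the stock estimators.
clf = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=0)
clf.fit(X, y_class)

reg = DecisionTreeRegressor(criterion="poisson", max_depth=3, random_state=0)
reg.fit(X, y_count)
```

The econml 'het' and 'mse' criteria are not covered by this; they would need their own Cython criterion classes.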
Re 'het' and the 'mse' (note this mse is not the same as the MSE currently in sklearn, so we should probably call it something else...) criterion in econml, we'll have to replicate the functionality here: https://github.com/py-why/EconML/blob/main/econml/grf/_criterion.pyx. Doesn't look too bad.
Some API issues to figure out.
In sklearn, we have:
Those are probably the main API we want to "override". In causal land, we want something like:
Econml also exposes a class API for getting confidence intervals. I think this is not necessary and overcomplicates the classes. We should just provide a function to get confidence intervals for the predicted causal effects.
Based on the answers above for causal, we might want to expose API for getting specific types of effects like they do in econml, such as CATE, ATE, marginal_CATE, etc.
`T`, `Z`, which are checked for depending on the meta-estimator. An error is raised if they are required but not present, e.g. in DML.
Just have to verify that at least this works with sklearn's testing function: `parametrize_with_checks`. If so, then at least we can be assured that the trees are pretty compatible with the rest of the sklearn codebase.
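A rough sketch of what that compatibility test could look like. `parametrize_with_checks` is the real sklearn helper (it needs pytest installed); the stock `DecisionTreeRegressor` is only a stand-in for the causal tree estimators, which don't exist yet:

```python
from sklearn.tree import DecisionTreeRegressor
from sklearn.utils.estimator_checks import parametrize_with_checks


# Swap the stand-in for the causal tree estimators once they exist.
# pytest expands this into one test case per (estimator, check) pair.
@parametrize_with_checks([DecisionTreeRegressor()])
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)
```

Running `pytest` on the file containing this function executes the whole check suite. Whether an estimator whose `fit` requires extra arguments such as `T` passes these checks unmodified is exactly the thing to verify.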
Open questions are:
`W`, which are "controls". It is honestly not super clear to me what the difference between `W` and `X` is and, moreover, how this impacts the fitting/predicting.
What does it use `predict` for? What does it use `score` for?
Along the lines of causal trees, adding some notes to be aware of:
I've finished perusing and understanding the code in EconML about GRF and the general GRF paper. At an implementation level, this means we have two distinctly "different" kinds of trees that would approach the problems of:
The first is using GRF, which diverges from the sklearn semantics because this is explicitly a "gradient-based" tree. We are solving local gradients at every split node using the GRFCriterion.
The second is using the DML/DR multiple fitting approaches with a honest regression tree as its basic ingredient.
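To make "honest regression tree" concrete, here is a minimal sketch under my own assumptions: one half of the data picks the splits, the other half re-estimates the leaf values. The helper names `fit_honest_tree` and `predict_honest` are hypothetical, and this is not how econml implements honesty internally.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor


def fit_honest_tree(X, y, random_state=0, **tree_kwargs):
    """Hypothetical helper: learn splits on one half, leaf values on the other."""
    X_split, X_est, y_split, y_est = train_test_split(
        X, y, test_size=0.5, random_state=random_state
    )
    # The tree structure only ever sees the "splitting" half.
    tree = DecisionTreeRegressor(random_state=random_state, **tree_kwargs)
    tree.fit(X_split, y_split)

    # Leaf values are re-estimated from the held-out "estimation" half.
    leaf_ids = tree.apply(X_est)
    leaf_values = {
        int(leaf): float(y_est[leaf_ids == leaf].mean())
        for leaf in np.unique(leaf_ids)
    }
    return tree, leaf_values


def predict_honest(tree, leaf_values, X):
    """Predict with honest leaf values; fall back to the tree's own value
    for leaves that no estimation sample reached."""
    leaves = tree.apply(X)
    fallback = tree.predict(X)
    return np.array(
        [leaf_values.get(int(leaf), fb) for leaf, fb in zip(leaves, fallback)]
    )


# Toy data, made up for illustration.
rng = np.random.default_rng(0)
X = rng.normal(size=(400, 3))
y = X[:, 0] + rng.normal(scale=0.1, size=400)

tree, leaf_values = fit_honest_tree(X, y, max_depth=3)
y_pred = predict_honest(tree, leaf_values, X)
```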
The GRF implementation is quite a bit more complex, and if it were up to us I would be inclined to leave it out, but to achieve feature parity with econml we should have a refactored version. The DML/DR approaches are more straightforward, as they can be implemented as meta-estimators, as we have discussed in this issue and #52.
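A rough sketch of the DML route as a two-stage procedure with a plain regression tree as the final stage. This is only an illustration under assumptions: `fit_dml_tree` is a hypothetical name, the nuisance models (random forest for the outcome, logistic regression for the propensity) are arbitrary choices, the treatment is assumed binary, and the final stage uses the usual residual-on-residual (R-learner style) weighted regression rather than anything taken from econml.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from sklearn.tree import DecisionTreeRegressor


def fit_dml_tree(X, y, T, final_tree=None, cv=5):
    """Hypothetical helper: cross-fit nuisances, then fit a tree on residuals."""
    # Stage 1: out-of-fold predictions for the outcome and the propensity.
    outcome_model = RandomForestRegressor(
        n_estimators=50, min_samples_leaf=5, random_state=0
    )
    y_hat = cross_val_predict(outcome_model, X, y, cv=cv)
    t_hat = cross_val_predict(
        LogisticRegression(), X, T, cv=cv, method="predict_proba"
    )[:, 1]
    # Keep propensities away from 0/1 so the treatment residual is never zero.
    t_hat = np.clip(t_hat, 0.05, 0.95)

    y_res = y - y_hat
    t_res = T - t_hat

    # Stage 2: minimizing sum((y_res - tau(X) * t_res) ** 2) is the same as a
    # weighted regression of y_res / t_res on X with weights t_res ** 2.
    if final_tree is None:
        final_tree = DecisionTreeRegressor(
            max_depth=3, min_samples_leaf=20, random_state=0
        )
    final_tree.fit(X, y_res / t_res, sample_weight=t_res**2)
    return final_tree


# Toy data with a known heterogeneous effect: 1 if X0 <= 0, else 3.
rng = np.random.default_rng(0)
n = 1000
X = rng.normal(size=(n, 3))
T = rng.binomial(1, 0.5, size=n)
tau = 1.0 + 2.0 * (X[:, 0] > 0)
y = X[:, 1] + tau * T + rng.normal(scale=0.5, size=n)

cate_tree = fit_dml_tree(X, y, T)
cate_hat = cate_tree.predict(X)  # estimated conditional treatment effects
```

A honest tree could be dropped in as the final stage; the sketch uses the stock regressor to keep it short.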
Besides the GRF paper, https://arxiv.org/pdf/1510.04342.pdf is a good paper.
Is your feature request related to a problem? Please describe.
Given any tree model, we can fit a causal tree by expanding the `fit` API to allow a treatment group to be passed in.
Describe the solution you'd like
Something similar in functionality to https://github.com/microsoft/EconML/blob/main/econml/dml/causal_forest.py, but with a way simpler implementation.
Additional context
Once this works, we could PR to econml. I think we should be wary and make sure all necessary functionality of causal trees is supported.
Note the instantiation process for a causaltree will be very similar to that of a "sklearn Pipeline", where the tree_model should be instantiated outside of the CausalTree. Just makes a simpler API.
https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html
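To make the Pipeline-like instantiation concrete, here is a minimal sketch of the API shape only. The `CausalTree` class below is hypothetical, and its internals are a placeholder: it fits one clone of the tree per treatment arm and predicts the difference (a simple T-learner), which is an assumption for illustration and not the proposed estimation method. It also assumes a binary treatment coded 0/1.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone
from sklearn.tree import DecisionTreeRegressor


class CausalTree(BaseEstimator):
    """Hypothetical wrapper: the tree is built outside and passed in,
    the same way steps are passed to a sklearn Pipeline."""

    def __init__(self, tree_model):
        self.tree_model = tree_model

    def fit(self, X, y, T):
        # `fit` is expanded to take the treatment assignment T.
        X, y, T = np.asarray(X), np.asarray(y), np.asarray(T)
        treated = T == 1
        # Placeholder strategy (T-learner): one cloned tree per arm.
        self.model_treated_ = clone(self.tree_model).fit(X[treated], y[treated])
        self.model_control_ = clone(self.tree_model).fit(X[~treated], y[~treated])
        return self

    def predict(self, X):
        # Estimated treatment effect: treated prediction minus control prediction.
        X = np.asarray(X)
        return self.model_treated_.predict(X) - self.model_control_.predict(X)


# Toy data with a constant treatment effect of 2 (made up for illustration).
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
T = rng.integers(0, 2, size=500)
y = X[:, 0] + 2.0 * T + rng.normal(scale=0.1, size=500)

tree_model = DecisionTreeRegressor(max_depth=3, random_state=0)
causal_tree = CausalTree(tree_model=tree_model).fit(X, y, T=T)
effects = causal_tree.predict(X)
```

Because the tree is a constructor argument, `get_params` and `set_params` expose its hyperparameters as `tree_model__max_depth` and so on, the same way Pipeline does for its steps.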