py-why / EconML

ALICE (Automated Learning and Intelligence for Causation and Economics) is a Microsoft Research project aimed at applying Artificial Intelligence concepts to economic decision making. One of its goals is to build a toolkit that combines state-of-the-art machine learning techniques with econometrics in order to bring automation to complex causal inference problems. To date, the ALICE Python SDK (econml) implements orthogonal machine learning algorithms such as the double machine learning work of Chernozhukov et al. This toolkit is designed to measure the causal effect of some treatment variable(s) t on an outcome variable y, controlling for a set of features x.
https://www.microsoft.com/en-us/research/project/alice/

OrthoForest speed and implementation details #743

Open adam2392 opened 1 year ago

adam2392 commented 1 year ago

Hi,

I'm wondering why OrthoForest seems to be implemented in pure Python:

https://github.com/py-why/EconML/blob/d1baba1e676d5fa8c90b33bf21d806964ae12ec9/econml/orf/_causal_tree.py#L53

vs leveraging the existing Cythonized trees?

vsyrgkanis commented 1 year ago

I know, I wish it could benefit from Cython. But OrthoForest decides its splits by fitting a scikit-learn model (e.g. a Lasso) at every splitting decision, so at every split point we need to call Lasso. The moment you call a function that goes outside of pure Cython, all the benefits of Cython vanish: all the delay ends up in the back and forth between Cython and Python and in waiting for locks to be released.
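To make the cost concrete, here is a toy sketch (not EconML's actual code; `score_split` and the candidate grid are invented for illustration) of what scoring candidate splits looks like when every candidate requires Python-level Lasso fits on both children:

```python
# Illustrative sketch only: each candidate split triggers two scikit-learn
# Lasso fits. In a tree with thousands of candidates, each .fit() call
# crosses the Python/Cython boundary, which is why Cythonizing the tree
# loop alone buys little.
import numpy as np
from sklearn.linear_model import Lasso

def score_split(X, y, feature, threshold, alpha=0.1):
    """Summed residual sum of squares of per-child Lasso fits."""
    mask = X[:, feature] <= threshold
    total = 0.0
    for child in (mask, ~mask):
        model = Lasso(alpha=alpha).fit(X[child], y[child])
        resid = y[child] - model.predict(X[child])
        total += float(resid @ resid)
    return total

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X[:, 0] + 0.1 * rng.normal(size=200)

# Every (feature, threshold) candidate costs two full Lasso fits:
best = min(((f, t) for f in range(5) for t in (-0.5, 0.0, 0.5)),
           key=lambda ft: score_split(X, y, *ft))
```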

I tried it once to see how fast it could go, but it wasn't faster than the pure Python implementation, which uses other tricks to speed things up.

A more detailed/subtle answer:

One could think of alternative variants of the ortho forest, where the trees are built by a-priori residualization and "local residualization" is used only in the final stage. So, a combination of what we do in CausalForestDML for tree construction and what we do in OrthoForest for final CATE model estimation.

This variant could have the benefits of both worlds, and it would definitely be much faster in Cython. In fact, it could just reuse the existing Cython implementation in grf for tree construction and only change what happens in the final CATE prediction stage.
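As a rough sketch of that variant (all names and modeling choices below are illustrative, not EconML's implementation): cross-fit the nuisances once up front, then hand the residualized data to an off-the-shelf, Cythonized forest for tree construction. The pseudo-outcome/sample-weight trick approximates the residual-on-residual split criterion; the localized final stage from OrthoForest would then replace the forest's own leaf averages.

```python
# Hedged sketch of "a-priori residualization": nuisances are cross-fitted
# once, and a standard (fast) forest builds the trees on residualized data.
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LassoCV
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(1)
n = 500
X = rng.normal(size=(n, 3))
T = X[:, 0] + rng.normal(size=n)        # confounded treatment
theta = 1.0 + (X[:, 1] > 0)             # true CATE: 1 or 2
Y = theta * T + X[:, 0] + rng.normal(size=n)

# Stage 1: a-priori (cross-fitted) residualization of Y and T.
Y_res = Y - cross_val_predict(LassoCV(), X, Y, cv=5)
T_res = T - cross_val_predict(LassoCV(), X, T, cv=5)

# Stage 2: tree construction via an off-the-shelf forest. Regressing
# Y_res * T_res / T_res**2 on X with sample weights T_res**2 makes the
# weighted leaf mean equal the local residual-on-residual slope.
eps = 1e-6  # guards against division by near-zero treatment residuals
pseudo = Y_res * T_res / (T_res ** 2 + eps)
forest = RandomForestRegressor(n_estimators=100, min_samples_leaf=20,
                               random_state=0)
forest.fit(X, pseudo, sample_weight=T_res ** 2 + eps)
cate_hat = forest.predict(X)
```

In the full variant, the leaf averages produced by `forest.predict` would be replaced by the localized nuisance estimation from OrthoForest; here they already give a rough CATE estimate.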

adam2392 commented 1 year ago

I know, I wish it could benefit from Cython. But OrthoForest decides its splits by fitting a scikit-learn model (e.g. a Lasso) at every splitting decision, so at every split point we need to call Lasso. The moment you call a function that goes outside of pure Cython, all the benefits of Cython vanish: all the delay ends up in the back and forth between Cython and Python and in waiting for locks to be released.

Ah yeah, looking at the code further, I see. It's unfortunate this is not properly described/documented, or housed in an experimental submodule. Tree methods unfortunately see little use unless they are implemented in Cython/C++ (similar to how NNs are pretty much unused without a GPU, except for prediction only).

I tried it once to see how fast it could go, but it wasn't faster than the pure Python implementation, which uses other tricks to speed things up.

Yeah, I think the nuisance estimation and everything else would need to be compiled code for this to work.

One could think of alternative variants of the ortho forest, where the trees are built by a-priori residualization and "local residualization" is used only in the final stage. So, a combination of what we do in CausalForestDML for tree construction and what we do in OrthoForest for final CATE model estimation.

Do you mind expanding on this? Do you mean fitting a propensity model (e.g. LogisticRegression) and an outcome regression model and then running some forest afterwards?

We have a research team that is heavily invested in tree methods and might be able to implement this and possibly PR this to econml afterwards.

vsyrgkanis commented 1 year ago

Yes, this is what I mean! We fit a regression and a propensity model and then define either the "residuals" for DMLOrthoForest or the doubly robust targets for DROrthoForest.

Then we fit CausalTrees for DMLOrthoForest based on the residual-on-residual moment (this is exactly how the trees in CausalForestDML are constructed).

And for DROrthoForest we fit regression trees predicting the doubly robust target. This is exactly how the trees in the ForestDRLearner method are constructed.

Finally, once we have the tree structures, we can estimate the CATEs using "localized" nuisance estimation as in OrthoForest. This requires fitting a separate nuisance model for every target sample that you want to predict on. This is still slow, but this is unavoidable to get the statistical benefits of the method.
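A hedged sketch of what that per-target localized step might look like, using forest leaf co-membership as the kernel (the helper names and the use of a plain regression forest for the tree structure are simplifications, not EconML's actual implementation):

```python
# Illustrative sketch: for one target point x0, weight training samples by
# how often they share a leaf with x0 across the forest, refit the
# nuisances under those weights, and estimate the CATE from the locally
# residualized data. This per-target refit is the slow-but-unavoidable part.
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Ridge

def kernel_weights(forest, X_train, x0):
    """Fraction of trees in which each training row shares x0's leaf."""
    leaves_train = forest.apply(X_train)           # shape (n, n_trees)
    leaves_x0 = forest.apply(x0.reshape(1, -1))    # shape (1, n_trees)
    return (leaves_train == leaves_x0).mean(axis=1)

def localized_cate(forest, X, T, Y, x0):
    w = kernel_weights(forest, X, x0)
    keep = w > 0
    # Localized residualization: weighted nuisance fits around x0.
    m_y = Ridge().fit(X[keep], Y[keep], sample_weight=w[keep])
    m_t = Ridge().fit(X[keep], T[keep], sample_weight=w[keep])
    y_res = Y[keep] - m_y.predict(X[keep])
    t_res = T[keep] - m_t.predict(X[keep])
    # Weighted residual-on-residual slope = local CATE estimate.
    return float((w[keep] * y_res * t_res).sum()
                 / (w[keep] * t_res ** 2).sum())

rng = np.random.default_rng(2)
X = rng.normal(size=(400, 2))
T = X[:, 0] + rng.normal(size=400)                   # confounded treatment
Y = 2.0 * T + X[:, 0] + 0.1 * rng.normal(size=400)   # constant CATE = 2
forest = RandomForestRegressor(n_estimators=50, min_samples_leaf=40,
                               random_state=0).fit(X, Y)
est = localized_cate(forest, X, T, Y, np.zeros(2))
```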

So this would be a matter of combining the "tree construction" that is happening in CausalForestDML and ForestDRLearner, with the final stage estimation that is happening in the corresponding OrthoForest methods.

If that's still of interest, we can have a zoom call at some point.

adam2392 commented 1 year ago

Hey @vsyrgkanis thanks for expanding. Yeah a zoom call would be great!

Leaving my comments here for discussion. So IIUC, whether it is DML or DR, we want:

The result is a mapping from (X, T) (covariates and treatments) to (Y) (outcomes). For an orthogonal forest that does not do the local nuisance estimation at tree-construction time, we could have an optional step at predict time:

Regarding the second bullet point, I'm currently confused about what is going on in the econml documentation versus what I understand from the theory/papers. How are the residuals different in DML vs DR? Maybe I missed this, but is there a paper explaining the difference? From my understanding, DML estimators are "doubly robust", so what is a "DR" residual?

In the third stage for ortho forests, even if we have the tree structure, how can we substitute in the localized nuisance function?

vsyrgkanis commented 1 year ago

Might be easier in a quick call. The ortho forest paper is here if you want to check out the localized nuisance procedure: https://arxiv.org/abs/1806.03467

DML is not doubly robust, just Neyman orthogonal.

The nuisance quantities that go into the DML method are different from the ones that go into the DR method.

Check out section 6 of the paper, which describes the differences!
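For a rough side-by-side (illustrative models on simulated data, without the cross-fitting a real estimator would use; not EconML's API): DML residualizes Y and T against X and uses the residual-on-residual moment, while DR builds an AIPW pseudo-outcome from per-arm outcome models and a propensity model.

```python
# Sketch contrasting the two nuisance constructions for a binary treatment.
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression

rng = np.random.default_rng(3)
n = 2000
X = rng.normal(size=(n, 2))
p = 1.0 / (1.0 + np.exp(-X[:, 0]))              # true propensity
T = rng.binomial(1, p)                          # binary treatment
Y = 2.0 * T + X[:, 0] + rng.normal(size=n)      # constant effect = 2

p_hat = LogisticRegression().fit(X, T).predict_proba(X)[:, 1]

# DML nuisances: E[Y|X] and E[T|X]; effect via residual-on-residual.
y_res = Y - LinearRegression().fit(X, Y).predict(X)
t_res = T - p_hat
theta_dml = float((y_res * t_res).sum() / (t_res ** 2).sum())

# DR nuisances: per-arm outcome models mu_t(X) plus the propensity,
# combined into the AIPW pseudo-outcome, whose (conditional) mean
# estimates the treatment effect.
mu1 = LinearRegression().fit(X[T == 1], Y[T == 1]).predict(X)
mu0 = LinearRegression().fit(X[T == 0], Y[T == 0]).predict(X)
e = np.clip(p_hat, 0.01, 0.99)
y_dr = mu1 - mu0 + T * (Y - mu1) / e - (1 - T) * (Y - mu0) / (1 - e)
theta_dr = float(y_dr.mean())
```

Both recover the effect here, but they require different nuisance quantities: DML needs a conditional-mean model for the treatment, while DR needs per-arm outcome models and a propensity model.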

adam2392 commented 1 year ago

Sure a quick call would help! Are you on the pywhy discord?

DML is not doubly robust, just Neyman orthogonal.

I see. I think I misread the docs, because I thought the DML implementation here was doubly robust too.