neurodata / treeple

Scikit-learn compatible decision trees beyond those offered in scikit-learn
https://treeple.ai

Implement GRF criterion with refactored internals that subclass RegressionCriterion #55

Open adam2392 opened 1 year ago

adam2392 commented 1 year ago

Summary

We need to add GRF capabilities in order to be feature-complete with respect to EconML. Implementing honesty purely at the Python level is problematic, because we want to create splits that ensure the leaves are populated with held-out outcome values across the different sets of treatments. Note that Python-level honesty is the strategy we currently take in HonestTreeClassifier.

  1. The splitter should in some way make sure there are enough treatments, T, on both sides of a split (if we're estimating treatment effects, then we need enough "variability" in the input data on each side of the split).
  2. When setting leaf values, we want to make sure the "held-out" dataset in the honesty setting has data that actually reaches said leaf; otherwise the leaf values cannot be estimated.

Overall this complicates matters as honesty implemented in Python would need to have some constraints built in (e.g. limiting the max-depth of the tree).

Cite: https://github.com/py-why/EconML/blob/main/econml/grf/_criterion.pyx

Possible high-level approaches

If we want every leaf in the tree to be reached by the held-out dataset, then we can just implement a tree-pruning algorithm. It is a bit more computationally expensive, but it simplifies our life because the splitter does not need to be aware of a "held-out" validation dataset.
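A minimal sketch of the pruning idea, assuming the tree is stored as parallel `children_left` / `children_right` arrays with `-1` marking a leaf (the layout scikit-learn uses in `tree_`). The function name `pruned_leaves` and the `honest_counts` array (number of held-out samples passing through each node) are hypothetical and only for illustration:

```python
TREE_LEAF = -1  # sentinel for "no child", as in scikit-learn's tree arrays


def pruned_leaves(children_left, children_right, honest_counts, node=0):
    """Return the node ids that act as leaves after pruning.

    An internal node is collapsed into a leaf whenever one of its
    children receives no held-out samples.
    """
    left, right = children_left[node], children_right[node]
    if left == TREE_LEAF:
        return [node]
    if honest_counts[left] == 0 or honest_counts[right] == 0:
        return [node]  # collapse: treat this internal node as a leaf
    return (
        pruned_leaves(children_left, children_right, honest_counts, left)
        + pruned_leaves(children_left, children_right, honest_counts, right)
    )


# Toy tree: 0 -> (1, 2), 1 -> (3, 4); nodes 2, 3, 4 are leaves.
children_left = [1, 3, -1, -1, -1]
children_right = [2, 4, -1, -1, -1]
# Held-out samples per node: leaf 4 is never reached.
honest_counts = [10, 6, 4, 6, 0]

leaves = pruned_leaves(children_left, children_right, honest_counts)
print(leaves)  # [1, 2]
```

Node 1 is collapsed because its right child (node 4) receives no held-out data, so every remaining leaf is guaranteed to have held-out samples.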

However, we still want the splitter to create splits that leave enough samples in every treatment category on each side of the split. We have added some generalization to the existing splitter code to account for this, such as creating functions that define "pre-split stopping conditions" and "post-split stopping conditions" that subclasses can override.

Alternatively, we just copy-paste some of the existing code into a "CausalSplitter(BaseSplitter)", which adds some extra functionality for maintaining balancedness of treatment groups in split and leaf nodes.
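To make the two options concrete, here is a pure-Python sketch of overridable stopping hooks. The real splitters are Cython; the class names, method names (`pre_split_stop`, `post_split_reject`) and the `min_per_arm` parameter are all hypothetical and not taken from the treeple or scikit-learn codebases:

```python
import numpy as np


class BaseSplitterSketch:
    """Hypothetical stand-in for a splitter with overridable stopping hooks."""

    def __init__(self, min_samples_leaf=1):
        self.min_samples_leaf = min_samples_leaf

    def pre_split_stop(self, node_idx):
        # Stop before searching if the node cannot yield two valid children.
        return len(node_idx) < 2 * self.min_samples_leaf

    def post_split_reject(self, left_idx, right_idx):
        # Reject a candidate split if either child is too small.
        return min(len(left_idx), len(right_idx)) < self.min_samples_leaf


class CausalSplitterSketch(BaseSplitterSketch):
    """Adds a treatment-balance condition on top of the base hooks."""

    def __init__(self, treatment, min_samples_leaf=1, min_per_arm=1):
        super().__init__(min_samples_leaf)
        self.treatment = np.asarray(treatment)
        self.min_per_arm = min_per_arm
        self.arms = np.unique(self.treatment)

    def _arms_covered(self, idx):
        # Every treatment arm must appear at least `min_per_arm` times.
        t = self.treatment[idx]
        return all(np.sum(t == arm) >= self.min_per_arm for arm in self.arms)

    def post_split_reject(self, left_idx, right_idx):
        if super().post_split_reject(left_idx, right_idx):
            return True
        return not (self._arms_covered(left_idx) and self._arms_covered(right_idx))


T = np.array([0, 0, 1, 1, 0, 1])
splitter = CausalSplitterSketch(T)
# Left child holds only controls, so this split is rejected.
rejected = splitter.post_split_reject(np.array([0, 1]), np.array([2, 3, 4, 5]))
# Both children contain treated and control samples, so this one is kept.
kept = not splitter.post_split_reject(np.array([0, 2]), np.array([1, 3, 4, 5]))
```

With hooks like these, the causal logic lives in a subclass rather than in a copy of the splitter code.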

adam2392 commented 10 months ago

Some thoughts here in increasing order of complexity:

Criterion

The criterion we are trying to implement is a generalization that stems from the GRF paper:

# the general form, where m(.) is a possibly non-linear functional,
# J and A are data-dependent quantities, called pointJ and alpha respectively
# theta(x) is the parameter we want to estimate at x (e.g. the conditional treatment effect)
E[ m(J, A; theta(x)) | X=x ] = 0

# The linear GRF equation then is:
E[ J * theta(x) - A | X=x] = 0

Some examples worked out are the following (which we should include in the documentation eventually):

# Regression, where m(.) is the squared-error function; J = Identity matrix, A is Y
E[ m(J, A; theta(x)) | X=x ]  = E[(theta(x) - Y)^2 | X=x]

# Quantile regressions, where m(.) is the quantile function
E[ m(J, A; theta(x)) | X=x ]  = E[q * I(Y_i > theta) - (1-q) * I(Y_i <= theta) | X=x]

# CATE
E[ m(J, A; theta(x)) | X=x ]  = E[J * theta(x) - Y_i @ W_i | X=x] = E[ (Y_i @ W_i - theta(x) @ W_i - c(x)) @ W_i^T | X=x]

# the last term, c(x), is the intercept
# W_i^T @ W_i is the Jacobian (i.e. pointJ is the cross-product of T and T)
=  E[ W_i^T @ Y_i @ W_i - W_i^T @ theta(x) @ W_i - W_i^T c(x) | X=x]

# the IV-CATE is even more complex

The difference between the LinearGRFCriterion and regular scikit-learn Criterion is that i) the criterion is significantly more complex to compute and ii) more importantly, y now consists of a column-stacked set of multivariate variables. Namely:

What are the shapes of y and how is it handled?

# this would be the normal `y`, i.e. the target variable (e.g. 0 or 1 for classification)
y = y

# this would be keeping the target variable and adding columns indicating 0 or 1 for the treatment.
y = np.hstack((y, T))

# also adding now the instrumental variable
y = np.hstack((y, T, Z))

In practice, though, what they do is make y of shape (n_samples, n_outputs_y + n_outputs_A + n_outputs_A * n_outputs_A). In EconML, n_outputs_y == n_outputs_. Then n_relevant_outputs_ is the number of columns that we actually care about (i.e. non-nuisance parameters).

(To see the above, search _get_n_outputs_decomposition in EconML codebase)

How are alpha and pointJ defined per tree model we have?

Next up, we also have to pre-define the alpha and pointJ (i.e. the Jacobian evaluated at the y point) values:

In the causal and causal-IV forests, the above shapes are:

(To see the above, search _get_alpha_and_pointJ in EconML codebase)

Now, we see we can construct our y as we discussed above before it is passed into Cython Criterion: y = np.hstack((y, alpha, pointJ)), which results in the corresponding shape (n_samples, n_outputs_y + n_outputs_A + n_outputs_A * n_outputs_A), which is what we alluded to earlier.
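A small NumPy sketch of this stacking for the causal (non-IV) case, assuming alpha is `y * T` and pointJ is the row-wise cross-product of T with itself, as in the CATE example earlier in this comment. The helper `cross_product` is defined locally, and appending a constant column to T for the intercept is an assumption of this sketch:

```python
import numpy as np


def cross_product(A, B):
    """Row-wise outer product, flattened to shape (n, A.shape[1] * B.shape[1])."""
    return np.einsum("ij,ik->ijk", A, B).reshape(A.shape[0], -1)


rng = np.random.default_rng(0)
n_samples, n_treatments = 8, 2
y = rng.normal(size=(n_samples, 1))
T = rng.integers(0, 2, size=(n_samples, n_treatments)).astype(float)

# Assumption: the intercept is handled by appending a constant "treatment".
T_aug = np.hstack((T, np.ones((n_samples, 1))))

alpha = y * T_aug                     # (n_samples, n_outputs_A)
pointJ = cross_product(T_aug, T_aug)  # (n_samples, n_outputs_A * n_outputs_A)

y_stacked = np.hstack((y, alpha, pointJ))
print(y_stacked.shape)  # (8, 13), i.e. 1 + 3 + 3 * 3 columns
```

Each row of `pointJ` is a flattened `n_outputs_A x n_outputs_A` matrix, so the criterion can recover the per-sample Jacobian with a reshape.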

How is the criterion actually computed now, where we want to minimize the variance or the mean-squared error?

In order to compute this efficiently, we need to compute a proxy label that is used to compute impurity in left/right child node.

These define pointJ(node) and theta(node), which are defined as:
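One reading that is consistent with the linear GRF equation above, E[J * theta(x) - A | X=x] = 0, is to average pointJ and alpha over the samples in a node and solve for theta. The sketch below does that, and also forms a per-sample proxy label rho following the pseudo-outcome construction of the GRF paper. The function name `node_estimates` and the exact form of rho are assumptions of this sketch, not a transcription of the EconML Cython code:

```python
import numpy as np


def node_estimates(alpha, pointJ):
    """alpha: (n, p); pointJ: (n, p * p), each row a flattened p x p matrix."""
    n, p = alpha.shape
    J_node = pointJ.mean(axis=0).reshape(p, p)  # pointJ(node)
    A_node = alpha.mean(axis=0)
    # pinv guards against a singular J(node), e.g. a missing treatment arm.
    J_inv = np.linalg.pinv(J_node)
    theta_node = J_inv @ A_node                 # theta(node)
    # Per-sample moment residual at theta(node), mapped through J(node)^-1.
    moment = alpha - pointJ.reshape(n, p, p) @ theta_node
    rho = moment @ J_inv.T
    return J_node, theta_node, rho


# Noise-free toy data: y = 2 * T + 1, with an appended intercept column.
T = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0])
T_aug = np.column_stack((T, np.ones_like(T)))
y = 2.0 * T + 1.0

alpha = y[:, None] * T_aug
pointJ = np.einsum("ij,ik->ijk", T_aug, T_aug).reshape(len(T), -1)

J_node, theta_node, rho = node_estimates(alpha, pointJ)
```

On this toy data `theta_node` recovers the slope and intercept, and rho averages to zero over the node. Because rho depends on J(node) and theta(node), it changes whenever the node changes.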

Now, as with all Tree-Criterion objects, we need to define how to compute the impurity_improvement and the proxy_impurity_improvement:

With this in mind, we want to re-implement a simplified version of LinearGRFCriterion in our codebase. Note: rho[i, k] is node-dependent, so it must be recomputed for every sample for each node.

Constraints on Causal tree models

  1. Causal IV forest requires Z.shape[1] == T.shape[1], i.e. the IVs must be exactly identified.

adam2392 commented 10 months ago

Splitter

When it comes to the splitter, the code is a bit simpler, as most of the computation is passed off to the Criterion object. However, this splitter is fundamentally different from that in scikit-learn because it leverages two criterion objects that contain the training and validation datasets (i.e. structure building and leaf-setting for honesty).

This is used to ensure that the treatment (i.e. T) and outcome (i.e. y) variables are balanced across both child nodes and that there is enough heterogeneity to predict the CATE (i.e. E[y | X=x, T=1] - E[y | X=x, T=0]). However, this requires using the validation data to also help track stopping conditions for the splitter.

additional hyperparameters:

  1. balanced_ness: how many samples go left vs. right
  2. minimum_eigenvalue_tolerance_of_leaf: how low an eigenvalue can be for a node
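A rough sketch of how these two hyperparameters could gate a candidate split. The function name, the parameter names `balancedness_tol` and `min_eig`, and their defaults are illustrative assumptions. The inputs `J_left` and `J_right` are the averaged per-node Jacobians, assumed symmetric (which holds when pointJ is the cross-product of T with itself):

```python
import numpy as np


def split_is_acceptable(n_left, n_right, J_left, J_right,
                        balancedness_tol=0.45, min_eig=1e-6):
    """Hypothetical post-split check on balance and Jacobian conditioning."""
    frac_left = n_left / (n_left + n_right)
    balanced = abs(frac_left - 0.5) <= balancedness_tol
    # eigvalsh assumes a symmetric matrix; a near-zero minimum eigenvalue
    # means the child lacks variation in treatment.
    eig_ok = all(np.linalg.eigvalsh(J).min() >= min_eig for J in (J_left, J_right))
    return balanced and eig_ok


# Child with both arms present: rows of [T, 1] are [0, 1] and [1, 1].
J_mixed = np.array([[0.5, 0.5], [0.5, 1.0]])
# Child where every sample is treated: every row of [T, 1] is [1, 1].
J_degenerate = np.array([[1.0, 1.0], [1.0, 1.0]])

ok = split_is_acceptable(5, 5, J_mixed, J_mixed)
bad = split_is_acceptable(5, 5, J_mixed, J_degenerate)
```

The all-treated child has a singular Jacobian, so the eigenvalue check rejects the split even though the sample counts are perfectly balanced.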

additional metrics to track:

This is simple enough to implement if we DON'T require the validation dataset criterion to also be passed in.

On one hand, we don't want to pass in the validation dataset; otherwise it is very weird and departs from our API. On the other hand, without the validation dataset that helps us set leaf nodes, we won't know whether the leaf nodes are heavily degenerate, meaning they do not have a lot of heterogeneity. This matters because we predict the CATE by dropping a sample through the tree and then computing the tree's CATE estimate by comparing y[t==1] - y[t==0], which would not work if, for example, all the samples in the leaf node correspond to t == 1. This presents a challenge because if we include the validation criterion object, we have to add this weird extra logic.
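The degenerate-leaf problem is easy to see in a toy version of the leaf estimate. The helper `leaf_cate` below is hypothetical and assumes a single binary treatment:

```python
import numpy as np


def leaf_cate(y_leaf, t_leaf):
    """Difference in mean outcome between treated and control samples in a leaf."""
    y_leaf = np.asarray(y_leaf, dtype=float)
    t_leaf = np.asarray(t_leaf)
    treated, control = y_leaf[t_leaf == 1], y_leaf[t_leaf == 0]
    if treated.size == 0 or control.size == 0:
        return np.nan  # degenerate leaf: one treatment arm is missing
    return treated.mean() - control.mean()


print(leaf_cate([3, 1, 4, 2], [1, 0, 1, 0]))  # 2.0
print(leaf_cate([3, 4], [1, 1]))              # nan
```

Whether held-out data leaves a leaf in the second state cannot be known from the training data alone, which is why the validation criterion is tempting to pass in.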

It should be fine, but if there's a more elegant solution that, say, we can abstract away in our sklearn fork, then that could be worth exploring.