grf-labs / grf

Generalized Random Forests
https://grf-labs.github.io/grf/
GNU General Public License v3.0

Provide estimate of the Monte Carlo variance of a trained forest #288

swager closed this issue 5 years ago

swager commented 6 years ago

Once we've trained a forest, it's important to know whether we grew enough trees, or if more trees would improve the accuracy of the estimates. This issue is about providing users with an idea of how much they should expect test set error to improve if they could grow arbitrarily many trees.

One simple way to do this is to use a jackknife on trees: We look at what the forest would predict if we remove one tree at a time from the forest, and use any instability here to estimate the variance of the forest.

Let \hat{\theta}(x) denote the prediction made by a forest at x using B trees (theta could be a conditional mean in the regression case, or a CATE in the causal forest case); and let \hat{\theta}^{(-b)}(x) be the forest estimate produced with all but the b-th tree (i.e., using the B - 1 remaining trees). The jackknife estimate of Monte Carlo variance for the forest is then:

\hat{V}_MC(x) = (B-1) / B \sum_{b = 1}^B (\hat{\theta}^{(-b)}(x) - \hat{\theta}(x))^2.

In other words, the claim is that the forest estimates \hat{\theta}(x) suffer from excess noise \hat{V}_MC(x) due to the fact that we are using a finite number of trees. Ideally, one would want to grow enough trees such that the average of \hat{V}_MC(X) would be much smaller than the sample variance of the \hat{\theta}(X) (as this guarantees that most of the variance in estimates across different points is not due to Monte Carlo effects).
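As a concrete illustration, a minimal R sketch of this quantity, assuming we already have the full-forest estimate and the B leave-one-tree-out estimates at a point (all names here are illustrative, not grf internals):

```r
# Jackknife-over-trees estimate of Monte Carlo variance at a single point x.
# theta_hat: scalar forest estimate \hat{\theta}(x) using all B trees.
# theta_loo: length-B vector of leave-one-tree-out estimates \hat{\theta}^{(-b)}(x).
jackknife_mc_variance <- function(theta_hat, theta_loo) {
  B <- length(theta_loo)
  (B - 1) / B * sum((theta_loo - theta_hat)^2)
}
```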

As a simple motivation for this idea, consider the case of the regression forest, where the forest actually predicts using an average of trees: \hat{\mu}(x) = 1/B \sum_{b = 1}^B T_b(x), where T_b(.) is the b-th tree. In this case, we can verify that \hat{\mu}^{(-b)}(x) = 1/(B - 1) \sum_{b' != b} T_{b'}(x), and so the jackknife recovers the usual variance estimate:

\hat{V}_MC(x) = 1 / (B (B-1)) \sum_{b = 1}^B (T_b(x) - \hat{\mu}(x))^2
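A quick numerical check of this equivalence, using a made-up vector of per-tree predictions at a single point (purely illustrative):

```r
# Verify that, for a plain average of trees, the jackknife Monte Carlo variance
# equals the usual variance estimate of the mean.
set.seed(1)
B <- 200
tree_preds <- rnorm(B)                               # stand-in for T_b(x), b = 1, ..., B

mu_hat <- mean(tree_preds)                           # \hat{\mu}(x)
mu_loo <- (sum(tree_preds) - tree_preds) / (B - 1)   # \hat{\mu}^{(-b)}(x)

v_jackknife <- (B - 1) / B * sum((mu_loo - mu_hat)^2)
v_usual     <- sum((tree_preds - mu_hat)^2) / (B * (B - 1))
all.equal(v_jackknife, v_usual)                      # TRUE
```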

In other cases, the jackknife estimate of Monte Carlo variance won't have such a simple form anymore, but it will still be straightforward to compute; see, e.g., this snippet for a part of the code where we compute something very similar for causal forests.

One important thing is that, if we're estimating MC variance at a training point, we need to properly account for out-of-bag sampling. In particular, the whole forest no longer uses B trees (but rather B_i trees, the number of trees for which the i-th training example is OOB), and then the leave-one-tree-out forests would use B_i - 1 trees (where we remove one additional OOB tree at a time), etc.
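For the simple averaging case, an illustrative sketch of the OOB-aware computation (again not grf internals; the names are made up):

```r
# Monte Carlo variance at training point i, using only trees for which i is OOB.
# tree_preds_i: length-B vector of per-tree predictions for point i.
# is_oob_i:     logical length-B vector, TRUE where point i was not used to grow the tree.
oob_mc_variance <- function(tree_preds_i, is_oob_i) {
  preds <- tree_preds_i[is_oob_i]
  B_i <- length(preds)                          # number of OOB trees for point i
  mu_hat <- mean(preds)                         # OOB forest prediction
  mu_loo <- (sum(preds) - preds) / (B_i - 1)    # leave one additional OOB tree out
  (B_i - 1) / B_i * sum((mu_loo - mu_hat)^2)
}
```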

Some further questions:

In the long run, we'd want to use this variance measure to warn the user if they haven't grown enough trees; then, we should also provide a means to add more trees to a forest (see #272).

susanathey commented 6 years ago

I think we should always report it.

I was not sure about the question of reporting everywhere or just an average: if we compute it everywhere, we might as well include the full vector, but we could have an option if overhead is an issue. We could also let the user pass a vector that says at which points to calculate it. That use case might arise if the analyst was interested in a subset of points, or a type of point (e.g. close to the boundary). Though this functionality might more naturally go into predict, where we make predictions, estimate standard errors, and report Monte Carlo errors at test points, leaving the user to average them.

swager commented 6 years ago

Makes sense. Currently, we do OOB prediction at training time unless the user specifically asks us not to (because most workflows use OOB prediction, and this avoids ever having to do OOB prediction twice). So a natural option is that, by default, we report MC variance at all points for OOB predictions. The typical user will then simply look at the average to check whether the forest has converged, but some users might be interested in whether we're more/less stable near the edges, etc.
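The convergence check a user would run might then look like the sketch below; the excess.error column name is an assumption about the eventual output, not existing API at the time of this thread:

```r
# Sketch of the intended workflow: compare the average Monte Carlo variance of
# the OOB predictions against the sample variance of the predictions themselves.
library(grf)

n <- 2000; p <- 10
X <- matrix(rnorm(n * p), n, p)
Y <- X[, 1] + rnorm(n)
forest <- regression_forest(X, Y, num.trees = 500)

oob <- predict(forest)                    # OOB predictions on the training data
mc_var     <- mean(oob$excess.error)      # assumed column: per-point Monte Carlo variance
signal_var <- var(oob$predictions)        # spread of the OOB estimates themselves

# If this ratio is not close to zero, grow more trees.
mc_var / signal_var
```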

halflearned commented 6 years ago

Let me see if I got this right. Sorry in advance if some of what follows is off.

  1. The quantity we're talking about is already being computed in the compute_debiased_error functions here (for regression forests) and here (for causal forests).

  2. The problem is that this error is being used to compute the debiased error, but is not being output to the user. So the big required actions are:

C++ side

R side


Now, some questions

halflearned commented 6 years ago

Another point: although these quantities are being computed when the user calls regression_train (or equivalent), they are not being calculated at "training time" in the sense of "before we leave C++". Within regression_train, the code flows to C++ and back once to fit the forest, and then a second time to compute predictions.

Should we plan to move this prediction step into C++, so that there would be only one C++ call? That does make some sense, especially if we have in mind a user who will keep adding trees until they hit an MC error bound. However, this would likely require larger changes to the C++ codebase.

jtibshirani commented 6 years ago

My intuition is that we should slightly rework the API of OptimizedPredictionStrategy. In particular, we can split the current compute_debiased_error method into two methods, compute_error and compute_excess_error, and return these two values separately as part of the Prediction object. Callers would then be responsible for combining these two estimates in the appropriate way. The tuning procedure, for example, will calculate the debiased error through a simple subtraction. Since the API will not be as self-explanatory, I think that it will be important to document what each of these values represents, and give examples of how we expect them to be used.
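On the caller side the recombination is just a subtraction; a tiny illustrative R sketch (the names are assumptions, not the final Prediction fields):

```r
# With the raw error and the excess (Monte Carlo) error exposed separately,
# a caller such as the tuning code recovers the debiased error by subtraction.
debiased_error <- function(error, excess_error) {
  error - excess_error
}
```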

Other thoughts:

> (By the way, it looks like quantile_forest does not have this option)

Right, quantile_forest does not currently support debiased error estimates. This is true of all forests that are based on DefaultPredictionStrategy: without precomputed sufficient statistics, the debiased error is not as convenient to calculate.

> When a user calls predict on unseen, entirely new data, should we also output Monte Carlo error?

It seems reasonable to allow the 'excess error' to be computed for a new dataset, through some option like compute.excess.error. I think that by default, we should always compute this estimate for OOB prediction, but only compute it for standard prediction if the user explicitly requests it.
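A hypothetical usage sketch of that default, using the option name floated above (not a settled interface):

```r
# OOB prediction would report the excess (Monte Carlo) error by default;
# prediction on new data would only compute it on request.
library(grf)

X <- matrix(rnorm(500 * 5), 500, 5)
Y <- X[, 1] + rnorm(500)
X_test <- matrix(rnorm(100 * 5), 100, 5)

forest <- regression_forest(X, Y)
oob_preds  <- predict(forest)                                        # excess error included by default
test_preds <- predict(forest, X_test, compute.excess.error = TRUE)   # opt in for unseen data
```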

> Should we plan to move this prediction step into C++, so that there would be only one C++ call?

In this first pass at the feature, I don't think we need to focus on this optimization. If you profile the code, however, and find that combining the two C++ calls could provide a speed advantage, then it'd be a good change to make in a follow-up.

jtibshirani commented 5 years ago

Closed by #327 (thanks @halflearned!)