wildtreetech / advanced-comp-2017

💻 Material for a course on applied machine-learning for scientists. Taught at EPFL in spring 2017

Definition of correlation of trees in a random forest #8

Open sharkovsky opened 7 years ago

sharkovsky commented 7 years ago

Hello,

I am going down the rabbit hole of the definition of correlation of decision trees in a random forest.

For those who don't have time to read this wall of text, here's a quick summary.

tl;dr: what is the exact definition of correlation between trees in a random forest?

And how do we interpret this value?

long explanation

At first, I naively thought one could define it as

definition 1: correlation = correlation in the predictions of all the trees in a forest

However, I had some doubts about my intuition, and Shaina Race's comment on this Quora question confirmed them. Essentially, the way I understand her comment is: this definition does not match intuition, because why would I want the correlation in the predictions to be low? If most of the trees get the right answer most of the time, the correlation would be high, but the model itself would be pretty good! Moreover, this would not give any indication of the robustness or the generalisation power of the ensemble. She seems to suggest another definition:

definition 2: correlation = correlation in the errors

This definition seems nicer, because it is intuitively closer to a notion of robustness of the whole ensemble.
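To make the difference between definition 1 and definition 2 concrete, here is a minimal sketch with two simulated 0/1 classifiers (hypothetical stand-ins for trees, not fitted models). Each is 90% accurate, but they make their mistakes on disjoint samples:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
y = rng.integers(0, 2, size=n)  # hypothetical true binary labels

# Two hypothetical "trees": each flips the label on 100 samples, on disjoint sets
idx = rng.permutation(n)
wrong_a, wrong_b = idx[:100], idx[100:200]
pred_a = y.copy()
pred_a[wrong_a] = 1 - pred_a[wrong_a]
pred_b = y.copy()
pred_b[wrong_b] = 1 - pred_b[wrong_b]

# Definition 1: Pearson correlation between the predictions
pred_corr = np.corrcoef(pred_a, pred_b)[0, 1]

# Definition 2: Pearson correlation between the error indicators (1 = wrong, 0 = right)
err_a = (pred_a != y).astype(float)
err_b = (pred_b != y).astype(float)
err_corr = np.corrcoef(err_a, err_b)[0, 1]

print(f"prediction correlation: {pred_corr:.2f}")
print(f"error correlation:      {err_corr:.2f}")
```

The predictions come out strongly correlated (roughly 0.6), while the error indicators are negatively correlated (exactly -1/9 for this construction, since the errors never coincide). So the two definitions can describe the same pair of trees very differently.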

Unfortunately, I was down the rabbit hole and could not stop sliding. I started wondering whether the exercise, although it asked about correlation, actually meant variance. What confused me is that in many sources (e.g. the sklearn user guide) random forests are cited as a method to reduce variance, not correlation. Variance, in this case, has a pretty precise meaning:

variance = variance in the predictions
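For regression, this definition can be measured directly: fit the same kind of model on many independent training sets and look at how much its prediction at fixed test points moves around. A minimal sketch, assuming scikit-learn's `make_friedman1` as a stand-in data generator and arbitrary sizes (20 training sets of 200 samples, 50 trees per forest):

```python
import numpy as np
from sklearn.datasets import make_friedman1
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor

# Fixed test points at which we measure the spread of the predictions
X_test, _ = make_friedman1(n_samples=100, random_state=1000)

tree_preds, forest_preds = [], []
for seed in range(20):
    # A fresh, independent (noisy) training set for every repetition
    X_train, y_train = make_friedman1(n_samples=200, noise=1.0, random_state=seed)
    tree = DecisionTreeRegressor(random_state=0).fit(X_train, y_train)
    forest = RandomForestRegressor(n_estimators=50, random_state=0).fit(X_train, y_train)
    tree_preds.append(tree.predict(X_test))
    forest_preds.append(forest.predict(X_test))

# Variance across training sets at each test point, then averaged over test points
tree_var = np.var(tree_preds, axis=0).mean()
forest_var = np.var(forest_preds, axis=0).mean()
print(f"single tree variance: {tree_var:.2f}")
print(f"forest variance:      {forest_var:.2f}")
```

The forest's predictions should vary noticeably less between training sets than those of a single fully grown tree, which is the sense in which forests "reduce variance".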

However, I was a bit lost because I wasn't sure how to extend this notion from a regression problem to a classification problem (especially a multi-class one). I found this paper by P. Domingos on the bias-variance-noise decomposition in a general setting; it seemed quite math-heavy but ultimately proved to be decently readable. That said, I have questions about it, in particular about how the constants in front of the noise and variance terms ($c_1$ and $c_2$ in the paper) affect our interpretation of their values.
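As I read the paper, the decomposition is $E[L] = c_1 N(x) + B(x) + c_2 V(x)$, with bias $B(x) = L(y_*, y_m)$ measured against the main prediction $y_m$ (the most frequent label for zero-one loss) and variance $V(x) = E_D[L(y_m, y)]$. In the two-class zero-one case, $c_2 = +1$ where the main prediction is correct and $c_2 = -1$ where it is biased, so variance hurts on unbiased points but helps on biased ones. The sketch below checks this numerically, assuming noise-free labels (so the $c_1 N(x)$ term drops out) and using simulated predictions; `zero_one_decomposition` is a hypothetical helper, not from the paper:

```python
import numpy as np

def zero_one_decomposition(preds, y_true):
    """Bias/variance for 0-1 loss (two classes, noise-free labels).

    preds has shape (n_training_sets, n_points): the label predicted at each
    point by a model trained on each training set.
    """
    preds = np.asarray(preds)
    # Main prediction: the most frequent label at each point (ties go to class 0)
    main = (preds.mean(axis=0) > 0.5).astype(int)
    bias = (main != y_true).astype(float)            # B(x) = L(y*, y_m)
    variance = (preds != main).mean(axis=0)          # V(x) = E_D[L(y_m, y)]
    c2 = np.where(bias == 0, 1.0, -1.0)              # two-class case
    expected_loss = (preds != y_true).mean(axis=0)   # E_D[L(y*, y)]
    return expected_loss, bias, variance, c2

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=50)
# Simulated predictions from 30 training sets: each label flipped with a point-specific probability
flip_prob = rng.uniform(0.0, 0.8, size=50)
flips = rng.random((30, 50)) < flip_prob
preds = np.where(flips, 1 - y_true, y_true)

loss, bias, variance, c2 = zero_one_decomposition(preds, y_true)
# Without noise, the decomposition states loss = B + c2 * V at every point
print(np.allclose(loss, bias + c2 * variance))
```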

betatim commented 7 years ago

Sorry for not being clearer: what I had in mind was simply the Pearson correlation coefficient between the output values of tree1 and tree2. If a forest has 100 trees in it, I'd calculate the average correlation between them and compare it to that of a different forest with different settings.


Here is a rough sketch of how I'd start looking at this (you may need to construct the 10 trees yourself to get fine control over the randomisation strategy, instead of delegating to RandomForestClassifier):

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=800, random_state=3)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, random_state=4)

n_trees = 10
clf = RandomForestClassifier(n_estimators=n_trees, max_features=None, bootstrap=False)
clf.fit(X_train, y_train)

# Predicted probability of class 0 on the test set, one entry per tree
probas = [tree.predict_proba(X_test)[:, 0] for tree in clf.estimators_]

# Pearson correlation coefficient between the outputs of every pair of trees
corr = np.zeros((n_trees, n_trees))
for i in range(n_trees):
    for j in range(n_trees):
        corr[i, j] = np.corrcoef(probas[i], probas[j])[0, 1]

plt.imshow(corr, interpolation='none')
plt.colorbar()
plt.show()
```

which produces the following plot:

[image: heat map of the 10 × 10 tree-to-tree correlation matrix]

betatim commented 7 years ago

(Feeling very old here: can you help me find the comments on the Quora question?)


The reason to think about this is the following. Like you say, if the trees are completely uncorrelated, they are probably just predicting randomly (not useful). However, when they are 100% correlated, averaging them will not reduce the variance of the predictions. And forests are all about reducing variance through averaging, because trees (as a generalisation) tend to be low-bias but high-variance models: a fully developed (aka very deep) tree will have low bias, but if you grow it on a new independent sample of the data you are likely to get a different tree (hence high variance).
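A standard way to make this quantitative: if each tree's prediction at a point has variance $\sigma^2$ and every pair of trees has correlation $\rho$, the average of $B$ trees has variance $\rho\sigma^2 + (1-\rho)\sigma^2/B$. With $\rho = 1$ averaging leaves the variance at $\sigma^2$; with $\rho = 0$ it shrinks as $1/B$. The sketch below checks the formula against a direct covariance-matrix computation, under the idealised assumption of equal variances and a single common $\rho$ (real forests only approximate this); the helper names are my own:

```python
import numpy as np

def variance_of_average(sigma2, rho, n_trees):
    """Variance of the mean of n_trees predictions, each with variance sigma2
    and pairwise correlation rho, computed from the covariance matrix."""
    # sigma2 on the diagonal, rho * sigma2 everywhere off the diagonal
    cov = sigma2 * ((1 - rho) * np.eye(n_trees) + rho * np.ones((n_trees, n_trees)))
    w = np.full(n_trees, 1.0 / n_trees)  # equal averaging weights
    return w @ cov @ w

def closed_form(sigma2, rho, n_trees):
    return rho * sigma2 + (1 - rho) * sigma2 / n_trees

for rho in (0.0, 0.5, 1.0):
    print(rho, variance_of_average(1.0, rho, 100), closed_form(1.0, rho, 100))
```

The first term, $\rho\sigma^2$, does not shrink no matter how many trees are added, which is why forests inject randomness (bootstrapping, `max_features`) to push $\rho$ down.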

sharkovsky commented 7 years ago

Thanks for your comment!

I have a follow-up question: at the end of the process you described above, I essentially get a symmetric n_estimators-by-n_estimators matrix with ones on the diagonal, where the entry in position i,j is the Pearson correlation coefficient between the outputs of the i-th and j-th trees.

Now, what would you use as a global measure of the correlation in the ensemble? Some norm of the matrix? Or the average of the values in the matrix (could this somehow be affected by n_estimators?)? What are the consequences of this choice?
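One concrete consequence: a plain average over the whole matrix includes the n ones on the diagonal, so it equals $1/n + \frac{n-1}{n}$ times the average over distinct pairs, and it is pulled towards 1 for small forests. A small sketch with a hypothetical matrix in which every off-diagonal entry is 0.5 (`full_mean` and `pair_mean` are illustrative names):

```python
import numpy as np

def full_mean(corr):
    # Average over all n*n entries, including the ones on the diagonal
    return corr.mean()

def pair_mean(corr):
    # Average over distinct pairs i < j only (upper triangle, diagonal excluded)
    i, j = np.triu_indices(corr.shape[0], k=1)
    return corr[i, j].mean()

for n in (10, 100):
    # Hypothetical ensemble in which every pair of trees has correlation 0.5
    corr = np.full((n, n), 0.5)
    np.fill_diagonal(corr, 1.0)
    print(n, full_mean(corr), pair_mean(corr))
```

For n = 10 the full-matrix mean is about 0.55 and for n = 100 about 0.505, while the pair mean stays at 0.5 in both cases.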

Thanks!

P.S. On the Quora page, if you scroll towards the bottom, there are links to upvote, downvote and comment. They don't look clickable, but you need to click on "comments" to see them.

betatim commented 7 years ago

I would use the average over all the pairs (unordered, so only i-j, not both i-j and j-i), but yeah: choices :) It would be interesting to see whether the conclusions change depending on what you do.
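A sketch of that measure, reusing the data set from the earlier snippet: average the Pearson correlation over unordered pairs of distinct trees (the upper triangle of the matrix, diagonal excluded) and compare two forests. The two settings, `n_estimators=10` and `random_state=0` are arbitrary choices for illustration, and `mean_pairwise_correlation` is a hypothetical helper:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

def mean_pairwise_correlation(forest, X):
    """Average Pearson correlation of predict_proba[:, 0] over unordered pairs of distinct trees."""
    probas = np.array([tree.predict_proba(X)[:, 0] for tree in forest.estimators_])
    corr = np.corrcoef(probas)  # rows are trees, so this is the tree-by-tree matrix
    i, j = np.triu_indices(len(probas), k=1)  # each pair once, diagonal excluded
    return corr[i, j].mean()

X, y = make_classification(n_samples=800, random_state=3)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, random_state=4)

settings = {
    "all features, no bootstrap": dict(max_features=None, bootstrap=False),
    "sqrt features, bootstrap": dict(max_features="sqrt", bootstrap=True),
}
results = {}
for name, params in settings.items():
    forest = RandomForestClassifier(n_estimators=10, random_state=0, **params)
    forest.fit(X_train, y_train)
    results[name] = mean_pairwise_correlation(forest, X_test)
    print(f"{name}: {results[name]:.3f}")
```

With all features and no bootstrapping the trees differ only through random tie-breaking, so their average correlation should be close to 1; adding bootstrapping and feature subsampling should bring it down.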