TeamHG-Memex / eli5

A library for debugging/inspecting machine learning classifiers and explaining their predictions
http://eli5.readthedocs.io
MIT License

BUG ? `eli5.show_weights` displayed standard deviation does not agree with the values in `feature_importances_std_` #365

Closed seralouk closed 4 years ago

seralouk commented 4 years ago

The PermutationImportance object has some nice attributes like feature_importances_ and feature_importances_std_.

To visualize these attributes as HTML, I used the eli5.show_weights function. However, I noticed that the displayed standard deviation does not agree with the values in feature_importances_std_.

_**More specifically, I can see that the displayed HTML values are equal to `feature_importances_std_ * 2`. Why is that?**_

Code:

from sklearn import datasets
import eli5
from eli5.sklearn import PermutationImportance
from sklearn.svm import SVC

# import some data to play with
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features.
y = iris.target

clf = SVC()
perms = PermutationImportance(clf, n_iter=1000, cv=10).fit(X, y)

print(perms.feature_importances_)
print(perms.feature_importances_std_)
print(perms.feature_importances_std_ * 2)
eli5.show_weights(perms)

Screenshot: https://pasteboard.co/IWLg5Fe.png

seralouk commented 4 years ago

Found the *2 but WHY?

It's in the template generating the feature importances html table.

https://github.com/TeamHG-Memex/eli5/blob/63e99182dc682bbf225355c80a24807396a747b6/eli5/templates/feature_importances.html

        {% if not fw.std is none %}
            ± {{ "%0.4f"|format(2 * fw.std) }}
        {% endif %}
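
The template expression above can be mimicked in plain Python to see where the displayed number comes from. This is only a sketch: `displayed_spread` is a hypothetical helper name, not part of eli5, and the sample value is made up for illustration rather than taken from the run above.

```python
def displayed_spread(std):
    """Mimic the template expression: "%0.4f"|format(2 * fw.std)."""
    if std is None:
        # The template renders nothing when no std is available.
        return ""
    return "± %0.4f" % (2 * std)


example_std = 0.0123  # made-up value, for illustration only
print(displayed_spread(example_std))  # ± 0.0246
```

So the number after ± in the HTML table is twice the corresponding entry of feature_importances_std_, formatted to four decimals.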

kmike commented 4 years ago

More specifically, I can see that the displayed HTML values are equal to feature_importances_std_ * 2. Why is that?

value ± 2*sigma gives a ~95% confidence interval; I think it's common to use such a confidence interval. See https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule
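
The 68-95-99.7 rule linked above can be checked numerically. The sketch below assumes the values are normally distributed (an assumption; nothing in this thread establishes that for permutation importances) and uses only the standard library. `normal_coverage` is a name introduced here for illustration.

```python
import math


def normal_coverage(k):
    """Probability that a normal variable lies within k standard
    deviations of its mean: P(|Z| < k) = erf(k / sqrt(2))."""
    return math.erf(k / math.sqrt(2))


for k in (1, 2, 3):
    print(k, round(normal_coverage(k), 4))
# 1 0.6827
# 2 0.9545
# 3 0.9973
```

Under that assumption, ± 2*sigma covers about 95.45% of the distribution, which is where the "~95%" comes from.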

seralouk commented 4 years ago

That's not correct. value ± 2 * (sigma/sqrt(n)) is the 95% interval. See the last sentence in the "Cumulative distribution function" section. I think this should be corrected in the eli5 source code.

aghoshpub commented 4 years ago

Thanks @seralouk for digging into this. I think this issue is completely related to https://github.com/TeamHG-Memex/eli5/issues/316 and can be merged.

lopuhin commented 4 years ago

That's not correct. value ± 2 * (sigma/sqrt(n)) is the 95% interval.

@seralouk this is the confidence interval for the value of the mean of n values drawn from the given normal distribution, which indeed becomes tighter as n increases - but I don't see how this is relevant here.
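
The difference between the two quantities discussed here can be illustrated with a small simulation. This is a sketch under stated assumptions: the draws come from a normal distribution with made-up parameters (`sigma`, `n`, `trials` are illustrative choices), and it says nothing about how eli5 itself computes its numbers.

```python
import math
import random
import statistics

random.seed(0)  # fixed seed so the run is repeatable
sigma, n, trials = 1.0, 100, 2000  # illustrative values only

# Spread of individual draws: stays close to sigma however large n is.
draws = [random.gauss(0.0, sigma) for _ in range(n)]
spread_of_draws = statistics.stdev(draws)

# Spread of the mean of n draws across repeated experiments:
# close to sigma / sqrt(n), so it shrinks as n grows.
means = [
    statistics.fmean([random.gauss(0.0, sigma) for _ in range(n)])
    for _ in range(trials)
]
spread_of_means = statistics.stdev(means)

print("std of individual draws:", round(spread_of_draws, 3))
print("std of the mean of n draws:", round(spread_of_means, 3))
print("sigma / sqrt(n):", sigma / math.sqrt(n))
```

The first number stays near sigma, while the second lands near sigma / sqrt(n). The ± 2*sigma range describes the first kind of spread, and ± 2*sigma/sqrt(n) describes the uncertainty of the mean.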

When we give x ± y for the weights, we mean to say that the true value of the weights lies in the given range with 95% probability. The issue is that our choice of 2 sigma is rather arbitrary. Do you know what would be the standard way to show means along with standard deviations? I've seen mean (std) being used but I think it would be confusing for more people.

lopuhin commented 4 years ago

I think this issue is completely related to #316 and can be merged.

Indeed, this is the same issue. Since #316 was posted earlier, let's continue there.