reiinakano / scikit-plot

An intuitive library to add plotting functionality to scikit-learn objects.
MIT License

Class mismatch in skplt.plot_confusion_matrix when test has fewer classes than training #28

Closed: ArmandGiraud closed this issue 7 years ago

ArmandGiraud commented 7 years ago

Hello, I have an issue when trying to plot a confusion matrix when my test set contains fewer classes than my training set. The class with 12,000+ occurrences in my sample should be labelled 'O'. Is it possible to get around this, or to pass the label set manually as an input?

[image: confusion matrix]

It's not a big issue, but it would be nice if it could be fixed. Thanks for your help!
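On the question of passing the label set manually: scikit-learn's `sklearn.metrics.confusion_matrix` accepts a `labels` argument that fixes which classes get a row and column, and in which order. The sketch below uses hypothetical toy data (class 0 never appears in `y_true`, and class 4 appears in neither array) to show that an explicit label list keeps every class in the matrix. It illustrates the scikit-learn function only and does not claim anything about how scikit-plot builds its plot.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical toy data: 0 is missing from y_true but is predicted once;
# 4 is a training class that shows up in neither array.
y_true = np.array([1, 1, 2, 2, 3, 3])
y_pred = np.array([1, 0, 2, 2, 3, 1])

# Default: rows/columns come from the union of labels in y_true and y_pred.
cm_default = confusion_matrix(y_true, y_pred)

# Explicit: fix the label set, e.g. to every class seen during training.
all_labels = [0, 1, 2, 3, 4]
cm_fixed = confusion_matrix(y_true, y_pred, labels=all_labels)

print(cm_default.shape)  # (4, 4)
print(cm_fixed.shape)    # (5, 5); rows 0 and 4 are all zeros
```

Fixing the label list also keeps row and column order identical across different test splits, which makes matrices from separate runs directly comparable.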

reiinakano commented 7 years ago

Hi @ArmandGiraud , could you give small sample code demonstrating the problem?

ArmandGiraud commented 7 years ago

Hi @reiinakano, I tried to reproduce the issue with the digits dataset:

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from scikitplot import plotters as skplt
import matplotlib.pyplot as plt

digits = load_digits()
X = digits.data
y = digits.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# Remove every test sample whose true label is 0
keep = y_test != 0
X_test_reduced = X_test[keep]
y_test_reduced = y_test[keep]
print(set(y_test_reduced))  # 0 is no longer among the test labels

# A very high tolerance stops training early, so the classifier still predicts some 0s
lr = LogisticRegression(tol=5)
lr.fit(X_train, y_train)

y_pred = lr.predict(X_test_reduced)
skplt.plot_confusion_matrix(y_test_reduced, y_pred)
plt.show()  # in a Jupyter notebook, %matplotlib inline renders the plot instead
```

[image: confusion matrix]

From the matrix, it looks as if the true values contain no occurrences of 1, but they actually do: the first row should refer to 0.
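The symptom is consistent with a mismatch between the labels used to build the matrix and the labels used to name its rows. This is an assumption, since the thread does not show the pre-#29 plotting code. By default, `confusion_matrix` builds rows from the union of labels in `y_true` and `y_pred`, so a class that is only predicted (here 0) still gets a row. If the axis names are taken from `y_true` alone, that set is one label shorter and starts at 1, so the empty row for 0 gets named "1" and every later row shifts by one. The toy data below is hypothetical.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical toy data mirroring the digits example:
# 0 is absent from y_true, but the classifier still predicts it.
y_true = np.array([1, 1, 2, 2, 3, 3])
y_pred = np.array([1, 0, 2, 2, 3, 1])

# Rows/columns of the matrix: union of labels from both arrays.
cm = confusion_matrix(y_true, y_pred)
matrix_labels = np.unique(np.concatenate([y_true, y_pred]))

# Labels taken from y_true alone miss the predicted-only class 0.
tick_labels = np.unique(y_true)

print(cm.shape)       # (4, 4)
print(matrix_labels)  # [0 1 2 3]
print(tick_labels)    # [1 2 3]
# Naming the rows with tick_labels would label the empty row for 0 as "1".
```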

Sorry for the ugly syntax, I'm kind of new to Python. Thanks!

reiinakano commented 7 years ago

Hi @ArmandGiraud, you're absolutely right, this was a bug in the implementation. Rest assured, it has been fixed in #29 and is now in the v0.2.5 release. Just run `pip install scikit-plot --upgrade` and you should be good to go. :)
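The thread does not show what #29 changed. For anyone plotting a confusion matrix by hand, a minimal sketch of one way to avoid the mismatch is to compute the label set once and use that same list both for `confusion_matrix(..., labels=...)` and for the axis tick labels. `plot_cm_consistent` is a hypothetical helper written for this illustration, not part of scikit-plot; it assumes plain matplotlib and uses the non-interactive `Agg` backend so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix


def plot_cm_consistent(y_true, y_pred, labels=None, ax=None):
    """Hypothetical helper: plot a confusion matrix whose tick labels match its rows/columns."""
    if labels is None:
        # Same label set confusion_matrix would use by default (union of both arrays)
        labels = np.unique(np.concatenate([y_true, y_pred]))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    if ax is None:
        _, ax = plt.subplots()
    ax.imshow(cm, cmap="Blues")
    ticks = np.arange(len(labels))
    # One tick per row/column, named from the very list used to build cm
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels([str(lab) for lab in labels])
    ax.set_yticklabels([str(lab) for lab in labels])
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    # Write each cell's count at its center
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, str(cm[i, j]), ha="center", va="center")
    return ax, cm, labels


# Hypothetical toy data: 0 is only predicted, never a true label
y_true = np.array([1, 1, 2, 2, 3, 3])
y_pred = np.array([1, 0, 2, 2, 3, 1])
ax, cm, labels = plot_cm_consistent(y_true, y_pred)
```

Passing an explicit `labels` list (for example, all classes seen during training) keeps classes that are absent from both arrays in the plot as well.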

Thanks for using scikit-plot!

ArmandGiraud commented 7 years ago

@reiinakano Thanks a lot, that was a fast fix!

Thank you as well for developing this useful package, it saves a lot of time!