reiinakano / scikit-plot

An intuitive library to add plotting functionality to scikit-learn objects.
MIT License

ValueError: Found input variables with inconsistent numbers of samples #94

Open AntonioAntovski opened 6 years ago

AntonioAntovski commented 6 years ago

I'm trying to plot the ROC curve, but I get ValueError: Found input variables with inconsistent numbers of samples. Here's the code I use:

`skplt.metrics.plot_roc(labels_test.values, pred_w2v_cnn.values)
plt.show()`

Both labels_test.values and pred_w2v_cnn.values have the same length and both are of type np.ndarray. I'd be thankful if anyone can help me to solve this problem.

reiinakano commented 6 years ago

It would be easier to debug if you could post a minimal reproducible example that shows the error.

AntonioAntovski commented 6 years ago

Here's the code I use:

`labels_test = test.label
pred_w2v_cnn = pd.read_csv("predicted_word2vec_cnn.csv", sep=',', header=0, names=['index', 0, 1, 2, 3, 4, 5, 6])

test_labels = labels_test.values.reshape((len(labels_test.values), 1))

skplt.metrics.plot_roc(labels_test.values, pred_w2v_cnn.values)
plt.show()`

Shape of test_labels: (143455,). Shape of pred_w2v_cnn: (143455, 8).

I tried to reshape the test_labels to (143455, 1), but that didn't work either.

lugq1990 commented 6 years ago

@AntonioAntovski The `plot_roc` function is based on sklearn's `roc_curve` function, which checks the shape of its inputs. You probably should not pass the 'index' column: your labels have 7 classes, but the prediction probabilities you pass have 8 columns, and that raises this error. Drop that column, then plot again.

Here is some code to test this:

from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.linear_model import LogisticRegression
from collections import Counter
import numpy as np  # needed for np.random.random below

iris = load_iris()
x, y = iris.data, iris.target
lr = LogisticRegression()
lr.fit(x, y)
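# Predicted class probabilities from the model: one column per class (3 here).
prob = lr.predict_proba(x)
# Random array with 4 columns, one more than the number of classes.
tmp = np.random.random((len(y), 4))

# This works: 3 classes, 3 probability columns.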
skplt.metrics.plot_roc(y, prob)

print('Different Classes Count res: ', Counter(y))
# The labels have 3 classes but 'tmp' has 4 columns, so this call fails with
# ValueError: Found input variables with inconsistent numbers of samples
skplt.metrics.plot_roc(y, tmp)
plt.show()
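
For the data in the original post, a minimal sketch of the fix might look like the following. It assumes the 'index' column is only the row index saved in the CSV and that columns 0 to 6 hold the per-class probabilities. The data below is synthetic, and `check_proba_shape` is a hypothetical helper written for this example, not part of scikit-plot or scikit-learn.

```python
import numpy as np
import pandas as pd


def check_proba_shape(y_true, y_probas):
    """Hypothetical helper: is there exactly one column per distinct label?"""
    n_classes = len(np.unique(y_true))
    n_columns = np.asarray(y_probas).shape[1]
    return n_classes == n_columns


# Synthetic stand-in for the CSV: an 'index' column plus 7 probability columns.
rng = np.random.default_rng(0)
n_samples = 70
labels = np.arange(n_samples) % 7             # 7 classes, all present
probas = rng.random((n_samples, 7))
probas /= probas.sum(axis=1, keepdims=True)   # normalize so each row sums to 1

pred = pd.DataFrame(probas, columns=[0, 1, 2, 3, 4, 5, 6])
pred.insert(0, 'index', np.arange(n_samples))

print(pred.values.shape)                       # (70, 8): 'index' is included
print(check_proba_shape(labels, pred.values))  # False

# Drop the extra column so only the 7 probability columns remain.
pred_clean = pred.drop(columns=['index'])
print(check_proba_shape(labels, pred_clean.values))  # True

# With matching shapes, the original call would become (left commented out
# here because scikit-plot may not be installed):
# skplt.metrics.plot_roc(labels, pred_clean.values)
```

Running a check like this before calling `plot_roc` makes a mismatch between the number of classes and the number of probability columns visible right away.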