Open AntonioAntovski opened 6 years ago
It would be easier to debug if you could post a minimal reproducible code sample that shows the error.
Here's the code I use:
`labels_test = test.label
pred_w2v_cnn = pd.read_csv("predicted_word2vec_cnn.csv", sep=',', header=0, names=['index', 0, 1, 2, 3, 4, 5, 6])
skplt.metrics.plot_roc(labels_test.values, pred_w2v_cnn.values)
plt.show()`
Shape of labels_test: (143455,)
Shape of pred_w2v_cnn: (143455, 8)
I tried reshaping labels_test to (143455, 1), but that didn't work either.
@AntonioAntovski The 'plot_roc' function is based on sklearn's 'roc_curve' function, which checks the shape of its input data. You probably should not pass the 'index' column: your labels have 7 classes, but the prediction probability array you give it has 8 columns, and that mismatch raises this error. Drop the 'index' column, then plot again.
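To make the suggestion above concrete, here is a minimal sketch of dropping the 'index' column before plotting. The CSV contents are made up for illustration (a two-row stand-in for predicted_word2vec_cnn.csv), and the names `csv_text`, `pred` and `probas` are assumptions, not part of the original code. Only the `read_csv` arguments are taken from the question.

```python
import io

import pandas as pd

# Hypothetical stand-in for predicted_word2vec_cnn.csv: an index column
# followed by seven per-class probability columns.
csv_text = (
    "index,0,1,2,3,4,5,6\n"
    "0,0.70,0.05,0.05,0.05,0.05,0.05,0.05\n"
    "1,0.10,0.60,0.06,0.06,0.06,0.06,0.06\n"
)

# Same read_csv arguments as in the question.
pred = pd.read_csv(
    io.StringIO(csv_text),
    sep=",",
    header=0,
    names=["index", 0, 1, 2, 3, 4, 5, 6],
)
print(pred.shape)  # (2, 8): the 'index' column counts as an extra column

# Keep only the seven probability columns.
probas = pred.drop(columns="index").values
print(probas.shape)  # (2, 7): one column per class

# With the real data, the plotting call would then be:
# skplt.metrics.plot_roc(labels_test.values, probas)
```

Passing `index_col=0` to `read_csv` should have a similar effect, since the first column then becomes the DataFrame index instead of a data column.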
Here is some new code for testing:
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import numpy as np
import scikitplot as skplt
from sklearn.linear_model import LogisticRegression
from collections import Counter
iris = load_iris()
x, y = iris.data, iris.target
lr = LogisticRegression()
lr.fit(x, y)
# These are the model's predicted probabilities, shape (150, 3).
prob = lr.predict_proba(x)
# A random array with 4 columns, one more than the number of classes.
tmp = np.random.random((len(y), 4))
# This works.
skplt.metrics.plot_roc(y, prob)
print('Different Classes Count res: ', Counter(y))
# The labels have 3 classes but 'tmp' has 4 columns, so this fails with
# ValueError: Found input variables with inconsistent numbers of samples
skplt.metrics.plot_roc(y, tmp)
plt.show()
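If you want a clearer message than the one above, you could check the shapes yourself before plotting. This is only a sketch: `check_proba_shape` is a hypothetical helper, not part of scikit-plot or sklearn, and it assumes every class appears at least once in the labels.

```python
import numpy as np


def check_proba_shape(y_true, y_probas):
    """Hypothetical helper: require one probability column per class."""
    y_true = np.asarray(y_true)
    y_probas = np.asarray(y_probas)
    # Assumes every class shows up at least once in y_true.
    n_classes = len(np.unique(y_true))
    if y_probas.ndim != 2 or y_probas.shape[0] != y_true.shape[0]:
        raise ValueError(
            f"expected {y_true.shape[0]} rows, got shape {y_probas.shape}"
        )
    if y_probas.shape[1] != n_classes:
        raise ValueError(
            f"labels have {n_classes} classes but the probabilities "
            f"have {y_probas.shape[1]} columns"
        )
    return n_classes


y = np.array([0, 1, 2, 0, 1, 2])
good = np.full((6, 3), 1 / 3)  # one column per class
bad = np.full((6, 4), 0.25)    # one column too many

print(check_proba_shape(y, good))  # 3
try:
    check_proba_shape(y, bad)
except ValueError as exc:
    print(exc)  # labels have 3 classes but the probabilities have 4 columns
```

Calling it right before `skplt.metrics.plot_roc` would point at the column count directly, which is easier to act on when both arrays have the same number of rows.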
I'm trying to plot the ROC curve, but I get ValueError: Found input variables with inconsistent numbers of samples. Here's the code I use:
`skplt.metrics.plot_roc(labels_test.values, pred_w2v_cnn.values)
plt.show()`
Both labels_test.values and pred_w2v_cnn.values have the same length and both are of type np.ndarray. I'd be thankful if anyone can help me to solve this problem.