reiinakano / scikit-plot

An intuitive library to add plotting functionality to scikit-learn objects.
MIT License
2.43k stars, 284 forks

Throws error "IndexError: too many indices for array" when trying to plot roc for binary classification #79

TarunTater closed this issue 6 years ago

TarunTater commented 6 years ago

For binary classification, when I pass NumPy arrays containing the test labels and the predicted test probabilities, it throws the following error:


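import numpy as np
import matplotlib.pyplot as plt
import scikitplot as skplt
y_true = np.array(ytest)     # test labels
y_probas = np.array(p_test)  # predicted test probabilities
skplt.metrics.plot_roc_curve(y_true, y_probas)
plt.show()
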
IndexError                                Traceback (most recent call last)
<ipython-input-49-1b02f082006a> in <module>()
----> 1 skplt.metrics.plot_roc_curve(y_true,y_probas)
      2 plt.show()

/Users/tarun/anaconda/envs/gl-env/lib/python2.7/site-packages/scikitplot/metrics.pyc in plot_roc_curve(y_true, y_probas, title, curves, ax, figsize, cmap, title_fontsize, text_fontsize)
    247     roc_auc = dict()
    248     for i in range(len(classes)):
--> 249         fpr[i], tpr[i], _ = roc_curve(y_true, probas[:, i],
    250                                       pos_label=classes[i])
    251         roc_auc[i] = auc(fpr[i], tpr[i])

IndexError: too many indices for array
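
The failing line indexes probas[:, i], which assumes a 2-D array with one column per class. The sketch below, using made-up values and plain NumPy (no scikit-plot needed), reproduces the same kind of IndexError when the array is 1-D. The helper name column_index_fails is hypothetical, and the exact error message wording varies between NumPy versions.

```python
import numpy as np

def column_index_fails(probas):
    """Hypothetical helper: True if probas[:, 0] raises IndexError."""
    try:
        probas[:, 0]  # same indexing pattern as probas[:, i] in the traceback
    except IndexError:
        return True
    return False

# Made-up 1-D probabilities of the positive class, shape (4,)
y_probas = np.array([0.1, 0.4, 0.35, 0.8])
print(column_index_fails(y_probas))  # 1-D input: the column index fails

# Made-up 2-D probabilities, shape (2, 2): one column per class
print(column_index_fails(np.array([[0.9, 0.1], [0.2, 0.8]])))
```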

lugq1990 commented 6 years ago

In scikit-plot, plot_roc_curve expects the probabilities as a 2-D array of shape (n_samples, n_classes), even for a binary classification problem; you probably passed a 1-D array of probabilities. Convert it to 2-D with one column per class, for example: np.column_stack((1 - y_probas, y_probas)). Try this.
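A minimal sketch of that conversion, assuming the 1-D array holds the probability of the positive class (label 1) for labels 0 and 1; the values are made up. Column order matters: the traceback shows column i being paired with pos_label=classes[i], so the negative-class column has to come first. Note that np.concatenate(..., axis=1) only works here after reshaping each input into a column vector, because a 1-D array has no axis 1.

```python
import numpy as np

# Made-up 1-D probabilities of the positive class (label 1), shape (4,)
y_probas = np.array([0.1, 0.4, 0.35, 0.8])

# Column 0: P(class 0), column 1: P(class 1) -> shape (n_samples, 2)
y_probas_2d = np.column_stack((1 - y_probas, y_probas))
print(y_probas_2d.shape)  # (4, 2)

# Equivalent with np.concatenate: turn each 1-D array into a column first
y_probas_2d_alt = np.concatenate(
    ((1 - y_probas)[:, None], y_probas[:, None]), axis=1
)
print(np.allclose(y_probas_2d, y_probas_2d_alt))
```

The 2-D array can then be passed in place of the 1-D one, e.g. skplt.metrics.plot_roc_curve(y_true, y_probas_2d).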

reiinakano commented 6 years ago

@lugq1990 is correct. I never liked scikit-learn's roc_curve because I had to pick a positive class. I designed plot_roc_curve to work directly with the output of predict_proba.
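For reference, here is a sketch of the per-class loop visible in the traceback, run on predict_proba output with scikit-learn's roc_curve and auc. The synthetic dataset and the LogisticRegression model are illustrative assumptions, and this is not scikit-plot's actual source. It shows why a 2-D (n_samples, n_classes) array is needed: each class gets its own column and is used as the positive label in turn.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve

# Synthetic binary data (illustrative only)
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
clf = LogisticRegression().fit(X, y)

# predict_proba returns shape (n_samples, n_classes);
# column i holds P(clf.classes_[i])
probas = clf.predict_proba(X)
classes = np.unique(y)

fpr, tpr, roc_auc = {}, {}, {}
for i in range(len(classes)):
    # One curve per class: scores from column i, classes[i] as the positive label
    fpr[i], tpr[i], _ = roc_curve(y, probas[:, i], pos_label=classes[i])
    roc_auc[i] = auc(fpr[i], tpr[i])

print(probas.shape, roc_auc)
```

For a binary problem the two curves carry the same information (their AUCs match), which is the redundancy that picking a single positive class avoids.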

That said, I now realize it would be useful to support binary classifiers that only provide decision_function, whose output is a 1-D array. Perhaps somebody could extend plot_roc_curve to accept a 1-D array as well.
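One possible approach, sketched under assumptions: decision_function scores are not probabilities, so 1 - scores is not meaningful, but ROC curves depend only on how samples are ranked. Negating the scores gives a column that ranks samples by how strongly they favor the negative class. The helper scores_to_two_columns is hypothetical and not part of scikit-plot, and LinearSVC on synthetic data is only an illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
from sklearn.svm import LinearSVC

def scores_to_two_columns(scores):
    """Hypothetical helper: 1-D binary decision scores -> (n_samples, 2).

    Column 1 keeps the scores for the positive class; column 0 is their
    negation, which ranks samples toward the negative class. Only the
    ranking is meaningful, so the values must not be read as probabilities.
    """
    scores = np.asarray(scores, dtype=float)
    if scores.ndim != 1:
        raise ValueError("expected a 1-D array of scores")
    return np.column_stack((-scores, scores))

X, y = make_classification(n_samples=200, n_features=5, random_state=0)
svm = LinearSVC(dual=False).fit(X, y)
scores = svm.decision_function(X)  # shape (n_samples,), not probabilities
scores_2d = scores_to_two_columns(scores)

# Same AUC whether computed from column 1 or from the raw 1-D scores
print(roc_auc_score(y, scores_2d[:, 1]), roc_auc_score(y, scores))
```

This keeps the ROC computation valid, but any code that treats the columns as probabilities (for example, expecting rows to sum to 1) would need separate handling.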

TarunTater commented 6 years ago

@lugq1990 - ya, changing the numpy array to 2D worked. Thanks for your suggestion!

victorheuer commented 2 years ago

This exact problem is occurring for me too. I'll try @lugq1990's solution. Thank you!