rasbt / mlxtend

A library of extension and helper modules for Python's data analysis and machine learning libraries.
https://rasbt.github.io/mlxtend/

TypeError: plot_confusion_matrix() got an unexpected keyword argument 'ax' #733

Closed: Schmidtbit closed this issue 3 years ago

Schmidtbit commented 3 years ago

plot_confusion_matrix() doesn't appear to accept an axis argument, even though the documentation indicates that you can pass one in.

I am trying to plot into subplots of an existing figure, but I get this error:

TypeError: plot_confusion_matrix() got an unexpected keyword argument 'ax'

rasbt commented 3 years ago

Could you try plot_confusion_matrix(..., axis=...) instead of plot_confusion_matrix(..., ax=...) and see if that works?
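A quick way to check which keyword a function accepts is Python's inspect.signature. The sketch below uses a hypothetical stand-in, plot_confusion_matrix_stub, whose parameter names are copied from the traceback posted later in this thread; the default values are assumptions. With the real function, the equivalent check would be inspect.signature(plot_confusion_matrix) after importing it from mlxtend.plotting.

```python
import inspect

# Hypothetical stand-in: parameter names copied from the traceback in this thread,
# default values assumed. This is NOT the real mlxtend function.
def plot_confusion_matrix_stub(conf_mat, hide_spines=False, hide_ticks=False,
                               figsize=None, cmap=None, colorbar=False,
                               show_absolute=True, show_normed=False,
                               class_names=None, figure=None, axis=None):
    return conf_mat

params = inspect.signature(plot_confusion_matrix_stub).parameters
print("ax" in params)    # False
print("axis" in params)  # True

# Passing a keyword the function does not declare raises TypeError at call time.
try:
    plot_confusion_matrix_stub([[1]], ax=None)
    caught = ""
except TypeError as err:
    caught = str(err)
print(caught)
```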

Schmidtbit commented 3 years ago

I did as you said and got this new error:


/opt/anaconda3/envs/py37/lib/python3.7/site-packages/mlxtend/plotting/plot_confusion_matrix.py in plot_confusion_matrix(conf_mat, hide_spines, hide_ticks, figsize, cmap, colorbar, show_absolute, show_normed, class_names, figure, axis)
     84         fig, ax = figure, axis
     85 
---> 86     ax.grid(False)
     87     if cmap is None:
     88         cmap = plt.cm.Blues

AttributeError: 'numpy.ndarray' object has no attribute 'grid'

rasbt commented 3 years ago

Based on the error, it sounds like you passed a NumPy array instead of an Axes object. The following example should work:

from mlxtend.plotting import plot_confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

binary1 = np.array([[4, 1],
                   [1, 2]])

fig, axis = plt.subplots(1, 1)

fig, ax = plot_confusion_matrix(conf_mat=binary1, axis=axis)
plt.show()

If you have something like

fig, axis = plt.subplots(1, 2)

then "axis" will be a NumPy array of Axes objects. Since the confusion matrix is a single plot, it would be unclear which one to use. In that case, select a specific one, e.g., axis=axis[0] or axis=axis[1].
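A minimal matplotlib-only sketch of that distinction (no mlxtend needed; the variable names are illustrative):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

# A 1x2 grid returns one Figure plus a NumPy array holding two Axes objects.
fig, axes = plt.subplots(1, 2)
print(type(axes).__name__, axes.shape)  # ndarray (2,)

# Indexing the array yields a single Axes, which is what axis=... expects.
first = axes[0]
print(isinstance(first, Axes))  # True

# With a 1x1 grid, the default squeeze=True returns a bare Axes, not an array.
fig1, single = plt.subplots(1, 1)
print(isinstance(single, Axes))  # True

# squeeze=False always returns a 2D array, even for a single subplot.
fig2, grid = plt.subplots(1, 1, squeeze=False)
print(grid.shape)  # (1, 1)
```

Any of these single Axes objects (first, single, or grid[0, 0]) can then be passed as axis=... to plot_confusion_matrix.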

Schmidtbit commented 3 years ago

That worked! Thanks!

However, to get it to work I had to turn off the colorbar; otherwise I got:

/opt/anaconda3/envs/py37/lib/python3.7/site-packages/mlxtend/plotting/plot_confusion_matrix.py in plot_confusion_matrix(conf_mat, hide_spines, hide_ticks, figsize, cmap, colorbar, show_absolute, show_normed, class_names, figure, axis)
     97 
     98     if colorbar:
---> 99         fig.colorbar(matshow)
    100 
    101     for i in range(conf_mat.shape[0]):

AttributeError: 'NoneType' object has no attribute 'colorbar'
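The first traceback in this thread shows the line fig, ax = figure, axis, so when only axis is passed, fig is presumably still None by the time fig.colorbar(matshow) runs. Passing the figure as well (figure=fig, axis=ax[i]) may avoid the error in that version, though this is untested here. More generally, every Axes already knows its parent Figure through ax.figure, so drawing a colorbar for one subplot does not strictly require a separate figure argument. A plain-matplotlib sketch of that idea, using the binary1 matrix from above plus one made-up matrix:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Illustrative matrices only; the second one is made up.
conf_mats = [np.array([[4, 1], [1, 2]]), np.array([[3, 2], [0, 3]])]

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
colorbars = []
for ax, cm in zip(axes, conf_mats):
    im = ax.matshow(cm, cmap="Blues")
    # ax.figure is the Figure that owns this Axes, so no separate figure object is needed.
    # ax=ax makes the colorbar take its space from this subplot only.
    colorbars.append(ax.figure.colorbar(im, ax=ax))
```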

Schmidtbit commented 3 years ago

However, the axis labels only show up for the last subplot.

[Screenshot: 2020-10-01, 3:57:41 PM]

fig, ax = plt.subplots(1,3,figsize=(20,10))
for i, clf in enumerate(['DeepNeuralNetwork','RandomForest','GradientBoosted']):
    plot_confusion_matrix(conf_mat=cm_arr[clf],
                        colorbar=False,
                        show_absolute=False,
                        show_normed=True,
                        class_names=traindf['PartStatus'].unique(),
                        axis=ax[i])
    ax[i].set_title(clf)
plt.show()
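One plausible cause, not confirmed against the mlxtend source, is that the function sets its tick labels and axis labels through pyplot-level calls such as plt.xticks, which act on the current Axes (usually the most recently created subplot) rather than on the axis that was passed in. The sketch below avoids that by using only Axes methods. draw_conf_mat is a hypothetical helper, not part of mlxtend; the class names and matrices are random toy data, and it assumes rows are true labels and columns are predicted labels.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

def draw_conf_mat(ax, conf_mat, class_names):
    """Hypothetical helper: draw one confusion matrix on the given Axes.

    Uses only Axes methods, so nothing depends on which Axes is "current".
    Assumes rows are true labels and columns are predicted labels.
    """
    ax.matshow(conf_mat, cmap="Blues")
    ticks = np.arange(len(class_names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)  # matshow puts row 0 at the top, same order as x
    ax.set_xlabel("predicted label")
    ax.set_ylabel("true label")
    # Write each cell's count in the middle of the cell.
    for i in range(conf_mat.shape[0]):
        for j in range(conf_mat.shape[1]):
            ax.text(j, i, str(conf_mat[i, j]), ha="center", va="center")

class_names = ["ok", "warn", "fail"]  # hypothetical class names
titles = ["DeepNeuralNetwork", "RandomForest", "GradientBoosted"]
rng = np.random.default_rng(0)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, title in zip(axes, titles):
    cm = rng.integers(0, 20, size=(3, 3))  # random toy matrix, not real results
    draw_conf_mat(ax, cm, class_names)
    ax.set_title(title)
```

Calling plt.sca(ax) before each call is another way to make a pyplot-based function target a specific subplot, since it sets that Axes as the current one.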

Schmidtbit commented 3 years ago

I was able to work around the label formatting with:

fig, ax = plt.subplots(1,3,figsize=(25,10),sharex=True, sharey=True)
fig.text(0.5, 0.04, 'True Label', ha='center',fontsize=20)
fig.text(0.04, 0.5, 'Predicted Label', va='center', rotation='vertical',fontsize=20)
for i, clf in enumerate(['DeepNeuralNetwork','RandomForest','GradientBoosted']):
    plot_confusion_matrix(conf_mat=cm_arr[clf],
                        colorbar=False,
                        show_absolute=False,
                        show_normed=True,
                        class_names=None,
                        axis=ax[i])
    ax[i].set_title(clf, fontsize=20)
    ax[i].set_xticks([0,1,2])
    ax[i].set_xticklabels(traindf['PartStatus'].unique())
    ax[i].set_yticks([0,1,2])
    ax[i].set_yticklabels(traindf['PartStatus'].unique()[::-1])
    ax[i].set_ylabel('')
    ax[i].set_xlabel('')
plt.show()

[Screenshot: 2020-10-01, 4:21:12 PM]
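With matplotlib 3.4 or newer, Figure.supxlabel and Figure.supylabel place shared figure-level labels without hand-tuned fig.text coordinates. Note that mlxtend's own default labels put "predicted label" on the x-axis and "true label" on the y-axis, so it is worth checking that the shared labels (and the reversed y tick order) in the workaround above match how cm_arr was built. A layout-only sketch, assuming that convention and using placeholder matrices:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

titles = ["DeepNeuralNetwork", "RandomForest", "GradientBoosted"]
fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True)
for ax, title in zip(axes, titles):
    ax.matshow(np.eye(3), cmap="Blues")  # placeholder matrix, layout only
    ax.set_title(title, fontsize=20)

# Figure-level labels (matplotlib >= 3.4) replace hand-placed fig.text() calls.
xlab = fig.supxlabel("Predicted label", fontsize=20)
ylab = fig.supylabel("True label", fontsize=20)
```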

rasbt commented 3 years ago

Glad to hear that you got it to work. It looks really nice. We should probably add something like this to the documentation as an example of how to make a plot with multiple subpanels.