DistrictDataLabs / yellowbrick

Visual analysis and diagnostic tools to facilitate machine learning model selection.
http://www.scikit-yb.org/
Apache License 2.0

ClassificationReport axes not rendered properly when in a subplot #1201

Closed h-joshi closed 2 years ago

h-joshi commented 3 years ago

Describe the bug

When drawn in a subplot, ClassificationReport does not display the classification outcomes on the Y axis or the measures (Precision, Recall, etc.) on the X axis.

To Reproduce

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

from yellowbrick.classifier import ClassificationReport, ConfusionMatrix, ClassPredictionError, DiscriminationThreshold
from yellowbrick.datasets import load_spam

dataset = load_spam()

X_train, X_test, y_train, y_test = train_test_split(dataset[0], dataset[1], test_size=0.33, random_state=42)

def pref_report(estimator, title, X_train, y_train, X_test, y_test):
    plt.rcParams["font.family"] = "Arial"

    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(17, 12))
    fig.suptitle(title)

    visualizer1 = ClassificationReport(estimator(random_state=1), support=True, encoder={0:'Not spam', 1:'Spam'}, ax=ax[0,0])
    visualizer1.fit(X_train, y_train)
    visualizer1.score(X_test, y_test)

    visualizer2 = ConfusionMatrix(estimator(random_state=1), encoder={0:'Not spam', 1:'Spam'}, ax=ax[0,1])
    visualizer2.fit(X_train, y_train)
    visualizer2.score(X_test, y_test)

    for tick in ax[0,1].get_xticklabels():
        tick.set_rotation(30)

    visualizer3 = DiscriminationThreshold(estimator(random_state=1), encoder={0:'Not spam', 1:'Spam'}, ax=ax[1,0])
    visualizer3.fit(X_train, y_train)

    visualizer4 = ClassPredictionError(estimator(random_state=1), encoder={0:'Not spam', 1:'Spam'}, ax=ax[1,1])
    visualizer4.fit(X_train, y_train)
    visualizer4.score(X_test, y_test)

    for tick in ax[1,1].get_xticklabels():
        tick.set_rotation(30)

    fig.legend()

pref_report(RandomForestClassifier, "Baseline performance on test data (n = {})".format(X_test.shape[0]), X_train, y_train, X_test, y_test)

Expected behavior

Expected the classification report (top left-hand plot) to display the classification outcomes on the Y axis and Precision, Recall, etc. on the X axis.

bbengfort commented 3 years ago

@h-joshi thank you for submitting this bug report - we've been noticing issues with the ClassificationReport as well. By any chance do you have time to look into this and possibly submit a PR to fix it?

bbengfort commented 2 years ago

Closed by #1210