DistrictDataLabs / yellowbrick

Visual analysis and diagnostic tools to facilitate machine learning model selection.
http://www.scikit-yb.org/
Apache License 2.0

Add test for Sklearn Pipeline - CVScores #1253

Closed: lwgray closed this issue 2 years ago

lwgray commented 2 years ago

Describe the issue
ModelVisualizers should be tested to see whether they work within scikit-learn pipelines. This addresses issue #498 and PR #1245.

This is being addressed in PR #1254

The test should cover these visualizers:
- CVScores
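CVScores, the visualizer named in this issue's title, plots cross-validation scores: roughly, it runs scikit-learn's `cross_val_score` on the wrapped estimator and draws one bar per fold plus the mean. The sketch below reproduces only that computation with scikit-learn. The 5-fold shuffled `StratifiedKFold` and the `f1_weighted` scoring are illustrative assumptions, not taken from this issue. With a Pipeline as the model, the MinMaxScaler is refit inside every training fold, so the test folds do not influence the scaling.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

# same synthetic data as the second sample in this issue
X, y = make_classification(
    n_samples=400, n_features=20, n_informative=8, n_redundant=8,
    n_classes=2, n_clusters_per_class=4, random_state=27,
)

model = Pipeline([
    ("minmax", MinMaxScaler()),
    ("mlp", MLPClassifier(random_state=42)),
])

# illustrative CV settings (assumption, not from the issue)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# cross_val_score clones the whole pipeline for each fold, so the scaler
# is refit on each training split before the MLP sees the data
cv_scores = cross_val_score(model, X, y, cv=cv, scoring="f1_weighted")
cv_scores_mean = cv_scores.mean()

print(np.round(cv_scores, 3), round(cv_scores_mean, 3))
```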

A sample test could look like this, with the visualizer as the final pipeline step:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split as tts
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

from yellowbrick.classifier import ConfusionMatrix

iris = load_iris()
X = iris.data
y = iris.target
classes = iris.target_names

X_train, X_test, y_train, y_test = tts(X, y, test_size=0.2, random_state=42)

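# ConfusionMatrix is the last step, so it is fit and scored on the scaled data
iris_cm = Pipeline([
    ('minmax', MinMaxScaler()),
    ('confusion', ConfusionMatrix(
        LogisticRegression(multi_class="auto", solver="liblinear"),
        classes=classes,
        # yellowbrick 1.x names this label-mapping argument `encoder`
        encoder={0: 'setosa', 1: 'versicolor', 2: 'virginica'},
    )),
])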

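iris_cm.fit(X_train, y_train)
iris_cm.score(X_test, y_test)

# inside a test method: compare the fitted visualizer step, not the Pipeline itself;
# add tol=... to match the existing ConfusionMatrix tests if needed
self.assert_images_similar(iris_cm.named_steps['confusion'])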
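The first sample relies on `Pipeline.fit` and `Pipeline.score` pushing X through the earlier steps and then calling `fit` and `score` on the last step, which there is the visualizer. The sketch below checks that mechanism with scikit-learn alone. `RecordingStep` is a hypothetical stand-in for a visualizer that only records the data it receives; the fitted step is retrieved through `named_steps`, which is also how a test can reach the visualizer inside `iris_cm`.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler


class RecordingStep(BaseEstimator):
    """Hypothetical stand-in for a visualizer used as the final pipeline step."""

    def fit(self, X, y=None):
        # data seen by fit(), after the earlier pipeline steps
        self.fit_X_ = np.asarray(X)
        return self

    def score(self, X, y=None):
        # data seen by score(), after the earlier pipeline steps
        self.score_X_ = np.asarray(X)
        return 0.0


X = np.array([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0]])
y = np.array([0, 1, 0])

pipe = Pipeline([("minmax", MinMaxScaler()), ("viz", RecordingStep())])
pipe.fit(X, y)
pipe.score(X, y)

step = pipe.named_steps["viz"]
print(step.fit_X_)  # each column min-max scaled to [0, 1]
```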

A second sample, with the whole pipeline passed to the visualizer as its model:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split as tts
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

from yellowbrick.classifier import ConfusionMatrix

X, y = make_classification(
    n_samples=400,
    n_features=20,
    n_informative=8,
    n_redundant=8,
    n_classes=2,
    n_clusters_per_class=4,
    random_state=27,
)

X_train, X_test, y_train, y_test = tts(X, y, test_size=0.2, random_state=42)

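# the whole pipeline is the visualizer's model; random_state keeps the image reproducible
model = Pipeline([
    ('minmax', MinMaxScaler()),
    ('mlp', MLPClassifier(random_state=42)),
])
viz = ConfusionMatrix(model)
viz.fit(X_train, y_train)
viz.score(X_test, y_test)

# inside a test method; add tol=... to match the existing ConfusionMatrix tests if needed
self.assert_images_similar(viz)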

@DistrictDataLabs/team-oz-maintainers