gsganden / model_inspector

A uniform interface to a curated set of methods for inspecting machine learning models
https://gsganden.github.io/model_inspector/
Apache License 2.0

Ensure appropriate behavior for `GridSearchCV`, etc. #7

Closed: gsganden closed this issue 2 years ago

gsganden commented 2 years ago

The following code plots the model as if it were a classification model, even though the wrapped estimator is a regression model:

```python
import sklearn.datasets
from model_inspector import get_inspector
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GridSearchCV

X_diabetes, y_diabetes = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)
# Keep only the first feature so the plot is one-dimensional.
X_diabetes = X_diabetes.iloc[:, [0]]

# An empty parameter grid fits a single LinearRegression candidate.
grid = GridSearchCV(LinearRegression(), {}).fit(X_diabetes, y_diabetes)

# Passing the search object itself, not the fitted estimator.
inspector = get_inspector(grid, X_diabetes, y_diabetes)

ax = inspector.plot()
```

[Screenshot, 2021-07-29 7:40:34 AM: the resulting plot]

Ideally, it would recognize that `grid.best_estimator_` is a regression model and plot it accordingly, as it does when given that estimator directly:

```python
inspector = get_inspector(grid.best_estimator_, X_diabetes, y_diabetes)

ax = inspector.plot()
```

[Screenshot, 2021-07-29 7:41:24 AM: the resulting plot]
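One way to get this behavior might be to unwrap the search object before deciding how to plot. The sketch below is only an illustration, not `model_inspector`'s actual fix: `unwrap_search_cv` and `task_type` are hypothetical helpers, and only the two public scikit-learn search classes (`GridSearchCV`, `RandomizedSearchCV`) are handled. The task type is then read with scikit-learn's `is_regressor` / `is_classifier`.

```python
from sklearn.base import is_classifier, is_regressor
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV


def unwrap_search_cv(model):
    """Hypothetical helper: return the refit best estimator of a fitted
    search object, or the model itself if it is not a search object."""
    # `best_estimator_` exists only after fitting with refit=True (the default);
    # otherwise this attribute access raises AttributeError.
    if isinstance(model, (GridSearchCV, RandomizedSearchCV)):
        return model.best_estimator_
    return model


def task_type(model):
    """Hypothetical helper: label a (possibly wrapped) estimator's task."""
    est = unwrap_search_cv(model)
    if is_regressor(est):
        return "regression"
    if is_classifier(est):
        return "classification"
    return "unknown"
```

With the `grid` from the example above, `task_type(grid)` would report `"regression"`, so a plotting routine could branch on the unwrapped estimator instead of the wrapper.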

Alternatively, it could refuse to plot and require the user to create the inspector from `grid.best_estimator_`.
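The stricter option could take the form of a guard that rejects search objects up front. This is a hypothetical sketch: `check_not_search_cv` is an illustrative name rather than part of `model_inspector`, and it checks only `GridSearchCV` and `RandomizedSearchCV`, so any other wrappers covered by the "etc." in the title would have to be added.

```python
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV


def check_not_search_cv(model):
    """Hypothetical guard: reject hyperparameter-search wrappers and ask the
    caller to pass the fitted best estimator instead."""
    if isinstance(model, (GridSearchCV, RandomizedSearchCV)):
        raise TypeError(
            f"{type(model).__name__} is not supported directly; "
            "pass `model.best_estimator_` instead."
        )
    # Any other estimator is passed through unchanged.
    return model
```

Compared with unwrapping silently, this makes the user's intent explicit, at the cost of an extra step when working with search objects.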