rasbt / mlxtend

A library of extension and helper modules for Python's data analysis and machine learning libraries.
https://rasbt.github.io/mlxtend/

Enhance plot #1094

Open · lcrmorin opened 2 months ago


Describe the workflow you want to enable

Improve plot_splits for time series splits.

Currently, the plot has some limitations. Here is an example:

import pandas as pd, numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_regression
from sklearn.dummy import DummyRegressor
from sklearn.metrics import mean_squared_error, make_scorer
from sklearn.model_selection import cross_val_score

from mlxtend.evaluate.time_series import GroupTimeSeriesSplit, plot_splits

X_test, y_test = [], []

start_year = 2010
end_year = 2020

for year in np.arange(start_year, end_year+1):
    X_year, y_year = make_regression(n_samples=5+(year-start_year), n_features=2, bias=0, noise=1, random_state=year)
    X_year = pd.DataFrame(X_year).rename(columns={0:'X1', 1:'X2'})
    X_year['year'] = year
    y_year = pd.Series(y_year)
    X_test.append(X_year)
    y_test.append(y_year)

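# ignore_index=True gives every row a unique index after concatenation
X, y = pd.concat(X_test, ignore_index=True), pd.concat(y_test, ignore_index=True)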

# modeling
model = DummyRegressor(strategy="mean")
metric = mean_squared_error
cv_args = {"test_size": 1, 'n_splits': len(np.unique(X['year'])) - 1, 'window_type': 'expanding'}
cv = GroupTimeSeriesSplit(**cv_args)

scores = cross_val_score(model, X, y, cv=cv, groups=X['year'], scoring=make_scorer(metric))

plot_splits(X, y, X['year'], **cv_args)
plt.show()  # display the figure when running as a script

gives the following plot:

[Figure: plot_splits output for the example above]

As you can see:

Describe your proposed solution

It might be a good idea to:

One option would be to plot the groups with a constant size:

[Figure: plot_splits output with constant-size groups]
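One possible reading of this proposal, sketched below under stated assumptions: convert the sample-level (train, test) index pairs produced by a splitter into a group-level matrix with one row per split and one column per group, then draw every group as an equal-width cell. In the example above the groups have unequal sizes, because `n_samples=5+(year-start_year)` gives 5 samples for 2010 up to 15 for 2020. The helpers `group_level_splits` and `plot_group_splits` and the 0/1/2 cell codes are hypothetical names for illustration, not part of mlxtend's API. The toy splits at the bottom are hand-written, not generated by mlxtend.

```python
import numpy as np

# Hypothetical cell codes used by this sketch.
UNUSED, TRAIN, TEST = 0, 1, 2


def group_level_splits(splits, groups):
    """Turn sample-level (train_idx, test_idx) pairs into a group-level matrix.

    Row i describes split i and column j the j-th distinct group (in order of
    first appearance), so each group occupies exactly one cell no matter how
    many samples it contains.
    """
    groups = np.asarray(groups)
    # Distinct groups in order of first appearance (time order if groups are sorted).
    _, first_pos = np.unique(groups, return_index=True)
    labels = groups[np.sort(first_pos)]
    column = {label: j for j, label in enumerate(labels)}

    matrix = np.full((len(splits), len(labels)), UNUSED, dtype=int)
    for i, (train_idx, test_idx) in enumerate(splits):
        for g in np.unique(groups[train_idx]):
            matrix[i, column[g]] = TRAIN
        for g in np.unique(groups[test_idx]):
            matrix[i, column[g]] = TEST
    return labels, matrix


def plot_group_splits(labels, matrix):
    """Draw one row per split and one equal-width column per group."""
    # Imported lazily so group_level_splits works without matplotlib installed.
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap

    fig, ax = plt.subplots(figsize=(8, 0.4 * len(matrix) + 1))
    cmap = ListedColormap(["white", "tab:blue", "tab:orange"])  # unused / train / test
    ax.imshow(matrix, cmap=cmap, vmin=UNUSED, vmax=TEST, aspect="auto")
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels([str(label) for label in labels], rotation=45)
    ax.set_yticks(range(len(matrix)))
    ax.set_xlabel("group")
    ax.set_ylabel("split")
    return fig, ax


# Toy data: 2 samples in 2010, 3 in 2011, 1 in 2012 (hand-written splits).
toy_groups = [2010, 2010, 2011, 2011, 2011, 2012]
toy_splits = [
    (np.array([0, 1]), np.array([2, 3, 4])),        # train 2010, test 2011
    (np.array([0, 1, 2, 3, 4]), np.array([5])),     # train 2010-2011, test 2012
]
labels, matrix = group_level_splits(toy_splits, toy_groups)
print(matrix.tolist())  # [[1, 2, 0], [1, 1, 2]]
```

With the example from this issue, the input would presumably be `list(cv.split(X, y, groups=X['year']))` together with `X['year']`. This assumes `split` yields positional index arrays, as scikit-learn-compatible splitters do, so the rows of `X` must already be in group order.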