linkedin / greykite

A flexible, intuitive and fast forecasting library
BSD 2-Clause "Simplified" License

`design_info` is needed to make predictions on new data #121

Open samuelefiorini opened 1 year ago

samuelefiorini commented 1 year ago

According to the documentation

The design info is not useful in general (to reproduce results, make predictions), and not dumping it will save a lot of time.

However, trying to make predictions on new data using a model restored via forecaster.load_forecast_result(path, load_design_info=False) leads to the following error:

File /var/lang/lib/python3.10/site-packages/greykite/algo/common/ml_models.py:715, in predict_ml(fut_df, trained_model)
    713 y_col = trained_model["y_col"]
    714 ml_model = trained_model["ml_model"]
--> 715 x_design_info = trained_model["x_design_info"]
    716 drop_intercept_col = trained_model["drop_intercept_col"]
    717 min_admissible_value = trained_model["min_admissible_value"]

KeyError: 'x_design_info'

This makes sense when looking at greykite.algo.common.ml_models.predict_ml, since patsy uses the variable x_design_info to build the design matrix for the new data (see here).
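To see why the formula alone is not enough, here is a standalone patsy sketch (not greykite code; the columns `dow` and `x` are hypothetical). The DesignInfo saved at training time remembers the categorical levels and the column layout, so re-applying it to new rows reproduces the training columns, while re-parsing the formula on the new rows alone can silently drop a column.

```python
import numpy as np
import pandas as pd
from patsy import build_design_matrices, dmatrix

# Training data: the categorical column has three levels (Mon, Tue, Wed).
train = pd.DataFrame({"dow": ["Mon", "Tue", "Wed", "Mon"], "x": [1.0, 2.0, 3.0, 4.0]})
train_mat = dmatrix("C(dow) + x", train)
design_info = train_mat.design_info  # plays the role of greykite's x_design_info

# New data that happens to contain only two of the three levels.
new = pd.DataFrame({"dow": ["Tue", "Mon"], "x": [5.0, 6.0]})

# Re-using the stored DesignInfo rebuilds exactly the training columns.
(new_mat,) = build_design_matrices([design_info], new)
print(design_info.column_names)
print(np.asarray(new_mat))

# Re-parsing the formula on the new rows only loses the "Wed" dummy column,
# so the matrix no longer lines up with the fitted coefficients.
naive_mat = dmatrix("C(dow) + x", new)
print(naive_mat.design_info.column_names)
```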

On the other hand, dumping design_info not only produces a bigger artifact, but may also be impossible because of operating-system limits on the length of the generated file names.

As an example, this is what happens in my case.

OSError: [Errno 36] File name too long: '/opt/ml/model/5f4cafc99b894af398c02013e13348e2/artifacts/forecast_result/grid_search/best_estimator_/steps/2_key/1_key/model_dict/x_design_info__value__/factor_infos/EvalFactor("C(Q(\'dow_hr\'), levels=[\'1_00\', \'1_01\', \'1_02\', \'1_03\', \'1_04\', \'1_05\', \'1_06\', \'1_07\', \'1_08\', \'1_09\', \'1_10\', \'1_11\', \'1_12\', \'1_13\', \'1_14\', \'1_15\', \'1_16\', \'1_17\', \'1_18\', \'1_19\', \'1_20\', \'1_21\', \'1_22\', \'1_23\', \'2_00\', \'2_01\', \'2_02\', \'2_03\', \'2_04\', \'2_05\', \'2_06\', \'2_07\', \'2_08\', \'2_09\', \'2_10\', \'2_11\', \'2_12\', \'2_13\', \'2_14\', \'2_15\', \'2_16\', \'2_17\', \'2_18\', \'2_19\', \'2_20\', \'2_21\', \'2_22\', \'2_23\', \'3_00\', \'3_01\', \'3_02\', \'3_03\', \'3_04\', \'3_05\', \'3_06\', \'3_07\', \'3_08\', \'3_09\', \'3_10\', \'3_11\', \'3_12\', \'3_13\', \'3_14\', \'3_15\', \'3_16\', \'3_17\', \'3_18\', \'3_19\', \'3_20\', \'3_21\', \'3_22\', \'3_23\', \'4_00\', \'4_01\', \'4_02\', \'4_03\', \'4_04\', \'4_05\', \'4_06\', \'4_07\', \'4_08\', \'4_09\', \'4_10\', \'4_11\', \'4_12\', \'4_13\', \'4_14\', \'4_15\', \'4_16\', \'4_17\', \'4_18\', \'4_19\', \'4_20\', \'4_21\', \'4_22\', \'4_23\', \'5_00\', \'5_01\', \'5_02\', \'5_03\', \'5_04\', \'5_05\', \'5_06\', \'5_07\', \'5_08\', \'5_09\', \'5_10\', \'5_11\', \'5_12\', \'5_13\', \'5_14\', \'5_15\', \'5_16\', \'5_17\', \'5_18\', \'5_19\', \'5_20\', \'5_21\', \'5_22\', \'5_23\', \'6_00\', \'6_01\', \'6_02\', \'6_03\', \'6_04\', \'6_05\', \'6_06\', \'6_07\', \'6_08\', \'6_09\', \'6_10\', \'6_11\', \'6_12\', \'6_13\', \'6_14\', \'6_15\', \'6_16\', \'6_17\', \'6_18\', \'6_19\', \'6_20\', \'6_21\', \'6_22\', \'6_23\', \'7_00\', \'7_01\', \'7_02\', \'7_03\', \'7_04\', \'7_05\', \'7_06\', \'7_07\', \'7_08\', \'7_09\', \'7_10\', \'7_11\', \'7_12\', \'7_13\', \'7_14\', \'7_15\', \'7_16\', \'7_17\', \'7_18\', \'7_19\', \'7_20\', \'7_21\', \'7_22\', \'7_23\'])")__key__.pkl'
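For a sense of scale, the sketch below rebuilds the offending path component from the error message and compares its length with the per-component limit of the filesystem (commonly 255 bytes on Linux). The reconstruction is an assumption based on the path shown above and on the `str(key)` naming noted in the comments of the patched dump_obj later in this thread; `os.pathconf` is Unix-only.

```python
import os
import tempfile

# Rebuild the 168 "day_hour" levels (7 days x 24 hours) visible in the file name above.
levels = [f"{d}_{h:02d}" for d in range(1, 8) for h in range(24)]

# Assumed reconstruction of the single path component greykite tried to create
# when naming a dictionary entry after str(key).
component = f"EvalFactor(\"C(Q('dow_hr'), levels={levels})\")__key__.pkl"

# Maximum length of one path component on the filesystem holding the temp dir.
name_max = os.pathconf(tempfile.gettempdir(), "PC_NAME_MAX")
print(f"component length: {len(component)}, limit: {name_max}")
```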

Any ideas on how to work around this issue?

samuelefiorini commented 1 year ago

The KeyError above is raised only on the first prediction attempt; calling predict a second time works. See, for instance, the following snippet for an ugly workaround.

import warnings

import pandas as pd
from greykite.framework.benchmark.data_loader_ts import DataLoaderTS
from greykite.framework.templates.autogen.forecast_config import (
    EvaluationPeriodParam,
    ForecastConfig,
    MetadataParam,
    ModelComponentsParam,
)
from greykite.framework.templates.forecaster import Forecaster
from greykite.framework.templates.model_templates import ModelTemplateEnum
from greykite.framework.utils.result_summary import summarize_grid_search_results

warnings.filterwarnings("ignore")

def prepare_bikesharing_data():
    """Loads bike-sharing data and adds proper regressors."""
    dl = DataLoaderTS()
    agg_func = {"count": "sum", "tmin": "mean", "tmax": "mean", "pn": "mean"}
    df = dl.load_bikesharing(agg_freq="daily", agg_func=agg_func)

    # There are some zero values, which cause issues for MAPE
    # This adds a small number to all data to avoid that issue
    value_col = "count"
    df[value_col] += 10
    # We drop last value as data might be incorrect as original data is hourly
    df.drop(df.tail(1).index, inplace=True)
    # We only use data from 2018 for demonstration purposes (run time is shorter)
    df = df.loc[df["ts"] > "2018-01-01"]
    df.reset_index(drop=True, inplace=True)

    print(f"\n df.tail(): \n {df.tail()}")

    # Creates useful regressors from existing raw regressors
    df["bin_pn"] = (df["pn"] > 5).map(float)
    df["bin_heavy_pn"] = (df["pn"] > 20).map(float)
    df.columns = [
        "ts",
        value_col,
        "regressor_tmin",
        "regressor_tmax",
        "regressor_pn",
        "regressor_bin_pn",
        "regressor_bin_heavy_pn",
    ]

    forecast_horizon = 7
    train_df = df.copy()
    test_df = df.tail(forecast_horizon).reset_index(drop=True)
    # When using the pipeline (as done in the ``fit_forecast`` below),
    # fitting and prediction are done in one step
    # Therefore for demonstration purpose we remove the response values of last 7 days.
    # This is needed because we are using regressors,
    # and future regressor data must be augmented to ``df``.
    # We mimic that by removal of the values of the response.
    train_df.loc[(len(train_df) - forecast_horizon) : len(train_df), value_col] = None

    print(f"train_df shape: \n {train_df.shape}")
    print(f"test_df shape: \n {test_df.shape}")
    print(f"train_df.tail(14): \n {train_df.tail(14)}")
    print(f"test_df: \n {test_df}")

    return {"train_df": train_df, "test_df": test_df}

def fit_forecast(df, time_col, value_col):
    """Fits a daily model for this use case.

    The daily model is a generic silverkite model with regressors.
    """
    meta_data_params = MetadataParam(
        time_col=time_col,
        value_col=value_col,
        freq="D",
    )

    # Autoregression to be used in the function
    autoregression = {
        "autoreg_dict": {
            "lag_dict": {"orders": [1, 2, 3]},
            "agg_lag_dict": {"orders_list": [[7, 7 * 2, 7 * 3]], "interval_list": [(1, 7), (8, 7 * 2)]},
            "series_na_fill_func": lambda s: s.bfill().ffill(),
        },
        "fast_simulation": True,
    }

    # Changepoints configuration
    # The config includes changepoints both in trend and seasonality
    changepoints = {
        "changepoints_dict": {
            "method": "auto",
            "yearly_seasonality_order": 15,
            "resample_freq": "2D",
            "actual_changepoint_min_distance": "100D",
            "potential_changepoint_distance": "50D",
            "no_changepoint_distance_from_end": "50D",
        },
        "seasonality_changepoints_dict": {
            "method": "auto",
            "yearly_seasonality_order": 15,
            "resample_freq": "2D",
            "actual_changepoint_min_distance": "100D",
            "potential_changepoint_distance": "50D",
            "no_changepoint_distance_from_end": "50D",
        },
    }

    regressor_cols = [
        "regressor_tmin",
        "regressor_bin_pn",
        "regressor_bin_heavy_pn",
    ]

    # Model parameters
    model_components = ModelComponentsParam(
        growth=dict(growth_term="linear"),
        seasonality=dict(
            yearly_seasonality=[15],
            quarterly_seasonality=[False],
            monthly_seasonality=[False],
            weekly_seasonality=[True],
            daily_seasonality=[False],
        ),
        custom=dict(
            fit_algorithm_dict=dict(fit_algorithm="ridge"), extra_pred_cols=None, normalize_method="statistical"
        ),
        regressors=dict(regressor_cols=regressor_cols),
        autoregression=autoregression,
        uncertainty=dict(uncertainty_dict=None),
        events=dict(holiday_lookup_countries=["US"]),
        changepoints=changepoints,
    )

    # Evaluation is done on same ``forecast_horizon`` as desired for output
    forecast_horizon = 7
    evaluation_period_param = EvaluationPeriodParam(
        test_horizon=None,
        cv_horizon=forecast_horizon,
        cv_min_train_periods=365 * 2,
        cv_expanding_window=True,
        cv_use_most_recent_splits=False,
        cv_periods_between_splits=None,
        cv_periods_between_train_test=0,
        cv_max_splits=5,
    )

    # Runs the forecast model using "SILVERKITE" template
    forecaster = Forecaster()
    result = forecaster.run_forecast_config(
        df=df,
        config=ForecastConfig(
            model_template=ModelTemplateEnum.SILVERKITE.name,
            coverage=0.95,
            forecast_horizon=forecast_horizon,
            metadata_param=meta_data_params,
            evaluation_period_param=evaluation_period_param,
            model_components_param=model_components,
        ),
    )

    # Gets cross-validation results
    grid_search = result.grid_search
    cv_results = summarize_grid_search_results(grid_search=grid_search, decimals=2, cv_report_metrics=None)
    cv_results = cv_results.transpose()
    cv_results = pd.DataFrame(cv_results)
    cv_results.columns = ["err_value"]
    cv_results["err_name"] = cv_results.index
    cv_results = cv_results.reset_index(drop=True)
    cv_results = cv_results[["err_name", "err_value"]]

    print(f"\n cv_results: \n {cv_results}")

    return forecaster

data = prepare_bikesharing_data()
df = data["train_df"]
time_col = "ts"
value_col = "count"

forecaster = fit_forecast(df=df, time_col=time_col, value_col=value_col)
# Dump forecaster without design info
forecaster.dump_forecast_result(
    "/tmp/artifacts", overwrite_exist_dir=True, dump_design_info=False, object_name="forecast_result"
)

# Load a new forecaster object
new_forecaster = Forecaster()
new_forecaster.load_forecast_result("/tmp/artifacts")
new_trained_estimator = new_forecaster.forecast_result.model[-1]
try:
    y_pred = new_trained_estimator.predict(data["test_df"])
except KeyError:
    print("Second time's the charm")
    y_pred = new_trained_estimator.predict(data["test_df"])

print(y_pred)

which produces the following output


 df.tail():
             ts  count  tmin  tmax   pn
602 2019-08-27  12216  17.2  26.7  0.0
603 2019-08-28  11401  18.3  27.8  0.0
604 2019-08-29  12685  16.7  28.9  0.0
605 2019-08-30  12097  14.4  32.8  0.0
606 2019-08-31  11281  17.8  31.1  0.0
train_df shape:
 (607, 7)
test_df shape:
 (7, 7)
train_df.tail(14):
             ts    count  regressor_tmin  regressor_tmax  regressor_pn  regressor_bin_pn  regressor_bin_heavy_pn
593 2019-08-18   9655.0            22.2            35.6           0.3               0.0                     0.0
594 2019-08-19  10579.0            21.1            37.2           0.0               0.0                     0.0
595 2019-08-20   8898.0            22.2            36.1           0.0               0.0                     0.0
596 2019-08-21  11648.0            21.7            35.0           1.8               0.0                     0.0
597 2019-08-22  11724.0            21.7            35.0          30.7               1.0                     1.0
598 2019-08-23   8158.0            17.8            23.3           1.8               0.0                     0.0
599 2019-08-24  12475.0            16.7            26.1           0.0               0.0                     0.0
600 2019-08-25      NaN            15.6            26.7           0.0               0.0                     0.0
601 2019-08-26      NaN            17.2            25.0           0.0               0.0                     0.0
602 2019-08-27      NaN            17.2            26.7           0.0               0.0                     0.0
603 2019-08-28      NaN            18.3            27.8           0.0               0.0                     0.0
604 2019-08-29      NaN            16.7            28.9           0.0               0.0                     0.0
605 2019-08-30      NaN            14.4            32.8           0.0               0.0                     0.0
606 2019-08-31      NaN            17.8            31.1           0.0               0.0                     0.0
test_df:
           ts  count  regressor_tmin  regressor_tmax  regressor_pn  regressor_bin_pn  regressor_bin_heavy_pn
0 2019-08-25  11634            15.6            26.7           0.0               0.0                     0.0
1 2019-08-26  11747            17.2            25.0           0.0               0.0                     0.0
2 2019-08-27  12216            17.2            26.7           0.0               0.0                     0.0
3 2019-08-28  11401            18.3            27.8           0.0               0.0                     0.0
4 2019-08-29  12685            16.7            28.9           0.0               0.0                     0.0
5 2019-08-30  12097            14.4            32.8           0.0               0.0                     0.0
6 2019-08-31  11281            17.8            31.1           0.0               0.0                     0.0
Fitting 1 folds for each of 1 candidates, totalling 1 fits

 cv_results:
                                              err_name                                          err_value
0                                      rank_test_MAPE                                                  1
1                                      mean_test_MAPE                                              10.28
2                                     split_test_MAPE                                           (10.28,)
3                                     mean_train_MAPE                                              21.75
4                                              params                                                 []
5                 param_estimator__yearly_seasonality                                                 15
6                 param_estimator__weekly_seasonality                                               True
7                   param_estimator__uncertainty_dict                                               None
8                  param_estimator__training_fraction                                               None
9                  param_estimator__train_test_thresh                                               None
10                   param_estimator__time_properties  {'period': 86400, 'simple_freq': SimpleTimeFre...
11                    param_estimator__simulation_num                                                 10
12     param_estimator__seasonality_changepoints_dict  {'method': 'auto', 'yearly_seasonality_order':...
13                  param_estimator__remove_intercept                                              False
14                    param_estimator__regressor_cols  [regressor_tmin, regressor_bin_pn, regressor_b...
15             param_estimator__regression_weight_col                                               None
16             param_estimator__quarterly_seasonality                                              False
17              param_estimator__origin_for_time_vars                                               None
18                  param_estimator__normalize_method                                        statistical
19               param_estimator__monthly_seasonality                                              False
20              param_estimator__min_admissible_value                                               None
21  param_estimator__max_weekly_seas_interaction_o...                                                  2
22  param_estimator__max_daily_seas_interaction_order                                                  5
23              param_estimator__max_admissible_value                                               None
24             param_estimator__lagged_regressor_dict                                               None
25      param_estimator__holidays_to_model_separately                                               auto
26         param_estimator__holiday_pre_post_num_dict                                               None
27              param_estimator__holiday_pre_num_days                                                  2
28             param_estimator__holiday_post_num_days                                                  2
29          param_estimator__holiday_lookup_countries                                               [US]
30                       param_estimator__growth_term                                             linear
31                param_estimator__fit_algorithm_dict                         {'fit_algorithm': 'ridge'}
32              param_estimator__feature_sets_enabled                                               auto
33                   param_estimator__fast_simulation                                               True
34                   param_estimator__extra_pred_cols                                               None
35                param_estimator__explicit_pred_cols                                               None
36                    param_estimator__drop_pred_cols                                               None
37                 param_estimator__daily_seasonality                                              False
38        param_estimator__daily_event_shifted_effect                                               None
39       param_estimator__daily_event_neighbor_impact                                               None
40               param_estimator__daily_event_df_dict                                               None
41                 param_estimator__changepoints_dict  {'method': 'auto', 'yearly_seasonality_order':...
42                      param_estimator__autoreg_dict  {'lag_dict': {'orders': [1, 2, 3]}, 'agg_lag_d...
43                  param_estimator__auto_seasonality                                              False
44                      param_estimator__auto_holiday                                              False
45                       param_estimator__auto_growth                                              False
46                                   split_train_MAPE                                           (21.75,)
47                                      mean_fit_time                                              25.47
48                                       std_fit_time                                                0.0
49                                    mean_score_time                                              17.96
50                                     std_score_time                                                0.0
51                                   split0_test_MAPE                                              10.28
52                                      std_test_MAPE                                                0.0
53                                  split0_train_MAPE                                              21.75
54                                     std_train_MAPE                                                0.0
Second time's the charm
            ts      forecast                          quantile_summary      err_std  forecast_lower  forecast_upper
0   2018-01-02   4867.991614   (1236.6364564500004, 8499.346772353505)  1852.766268     1236.636456     8499.346772
1   2018-01-03   5538.375314    (2582.389599005536, 8494.361028039748)  1508.183690     2582.389599     8494.361028
2   2018-01-04   5328.133261    (2085.659451338936, 8570.607070004411)  1654.353771     2085.659451     8570.607070
3   2018-01-05   4798.627174    (731.2554335326386, 8865.998913687017)  2075.227796      731.255434     8865.998914
4   2018-01-06   4300.328558    (-608.899093611637, 9209.556209578057)  2504.754011     -608.899094     9209.556210
..         ...           ...                                       ...          ...             ...             ...
602 2019-08-27  11584.776481    (8293.792027862797, 14875.76093427293)  1679.104555     8293.792028    14875.760934
603 2019-08-28  11975.071674   (9294.749310736905, 14655.394038244023)  1367.536539     9294.749311    14655.394038
604 2019-08-29  11731.381551    (8433.28073238897, 15029.482369719592)  1682.735420     8433.280732    15029.482370
605 2019-08-30  11166.335793    (7029.302037145629, 15303.36954967026)  2110.770294     7029.302037    15303.369550
606 2019-08-31  10118.891291  (3968.9981066371283, 16268.784475116765)  3137.758261     3968.998107    16268.784475

[607 rows x 6 columns]
samuelefiorini commented 1 year ago

In order to reproduce the error above, see the following minimal script.

import warnings

warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

from greykite.common.data_loader import DataLoader
from greykite.framework.templates.autogen.forecast_config import ForecastConfig
from greykite.framework.templates.autogen.forecast_config import MetadataParam
from greykite.framework.templates.autogen.forecast_config import ModelComponentsParam
from greykite.framework.templates.forecaster import Forecaster
from greykite.framework.templates.model_templates import ModelTemplateEnum

pd.options.plotting.backend = 'plotly'

# Defines inputs
df = DataLoader().load_bikesharing().tail(24*90)  # Input time series (pandas.DataFrame)
df['ts'] = pd.to_datetime(df['ts'])

forecast_horizon = 24*2

df.loc[df.index[-forecast_horizon:], 'count'] = np.nan

print(df.head())
print(df.tail())

config = ForecastConfig(
    metadata_param=MetadataParam(time_col="ts", value_col="count"),  # Column names in `df`
    model_template=ModelTemplateEnum.AUTO.name,  # AUTO model configuration
    forecast_horizon=forecast_horizon,  # Forecasts all the missing steps
    model_components_param=ModelComponentsParam(regressors={"regressor_cols": ['tmin', 'tmax', 'pn']}),
    coverage=0.95,  # 95% prediction intervals
)

# Creates forecasts
forecaster = Forecaster()
result = forecaster.run_forecast_config(df=df, config=config)

forecaster.dump_forecast_result(
    '/tmp/forecaster',
    object_name="object",
    dump_design_info=False,
    overwrite_exist_dir=True
)

# Recreate Forecaster
new_forecaster = Forecaster()
new_forecaster.load_forecast_result(
    '/tmp/forecaster',
    load_design_info=False
)
new_result = new_forecaster.forecast_result

new_pred = new_result.model.predict(df.rename(columns={'count': 'y'}))

print(new_pred)

which gives the following output

 date                  ts  count  tmin  tmax   pn
76261  2019-06-03 2019-06-03 01:00:00   35.0  11.7  23.3  0.0
76262  2019-06-03 2019-06-03 02:00:00   20.0  11.7  23.3  0.0
76263  2019-06-03 2019-06-03 03:00:00    9.0  11.7  23.3  0.0
76264  2019-06-03 2019-06-03 04:00:00   14.0  11.7  23.3  0.0
76265  2019-06-03 2019-06-03 05:00:00   37.0  11.7  23.3  0.0
             date                  ts  count  tmin  tmax   pn
78416  2019-08-31 2019-08-31 20:00:00    NaN  17.8  31.1  0.0
78417  2019-08-31 2019-08-31 21:00:00    NaN  17.8  31.1  0.0
78418  2019-08-31 2019-08-31 22:00:00    NaN  17.8  31.1  0.0
78419  2019-08-31 2019-08-31 23:00:00    NaN  17.8  31.1  0.0
78420  2019-09-01 2019-09-01 00:00:00    NaN  21.1  28.3  0.0
Fitting 3 folds for each of 1 candidates, totalling 3 fits
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-15-06ddd6580090> in <module>
----> 1 new_pred = new_result.model.predict(df.rename(columns={'count': 'y'}))

/opt/conda/lib/python3.7/site-packages/sklearn/utils/metaestimators.py in <lambda>(*args, **kwargs)
    111 
    112             # lambda, but not partial, allows help() to work with update_wrapper
--> 113             out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)  # noqa
    114         else:
    115 

/opt/conda/lib/python3.7/site-packages/sklearn/pipeline.py in predict(self, X, **predict_params)
    468         for _, name, transform in self._iter(with_final=False):
    469             Xt = transform.transform(Xt)
--> 470         return self.steps[-1][1].predict(Xt, **predict_params)
    471 
    472     @available_if(_final_estimator_has("fit_predict"))

/opt/conda/lib/python3.7/site-packages/greykite/sklearn/estimator/base_silverkite_estimator.py in predict(self, X, y)
    365             trained_model=self.model_dict,
    366             past_df=self.past_df,
--> 367             new_external_regressor_df=None)  # regressors are included in X
    368         pred_df = pred_res["fut_df"]
    369         x_mat = pred_res["x_mat"]

/opt/conda/lib/python3.7/site-packages/greykite/algo/forecast/silverkite/forecast_silverkite.py in predict(self, fut_df, trained_model, freq, past_df, new_external_regressor_df, include_err, force_no_sim, simulation_num, fast_simulation, na_fill_func)
   2266                     new_external_regressor_df=None,
   2267                     time_features_ready=False,
-> 2268                     regressors_ready=True)
   2269                 fut_df0 = pred_res["fut_df"]
   2270                 x_mat0 = pred_res["x_mat"]

/opt/conda/lib/python3.7/site-packages/greykite/algo/forecast/silverkite/forecast_silverkite.py in predict_no_sim(self, fut_df, trained_model, past_df, new_external_regressor_df, time_features_ready, regressors_ready)
   1226             pred_res = predict_ml_with_uncertainty(
   1227                 fut_df=features_df_fut,
-> 1228                 trained_model=trained_model)
   1229             fut_df = pred_res["fut_df"]
   1230             x_mat = pred_res["x_mat"]

/opt/conda/lib/python3.7/site-packages/greykite/algo/common/ml_models.py in predict_ml_with_uncertainty(fut_df, trained_model)
    558     pred_res = predict_ml(
    559         fut_df=fut_df,
--> 560         trained_model=trained_model)
    561 
    562     y_pred = pred_res["fut_df"][y_col]

/opt/conda/lib/python3.7/site-packages/greykite/algo/common/ml_models.py in predict_ml(fut_df, trained_model)
    501     y_col = trained_model["y_col"]
    502     ml_model = trained_model["ml_model"]
--> 503     x_design_info = trained_model["x_design_info"]
    504     min_admissible_value = trained_model["min_admissible_value"]
    505     max_admissible_value = trained_model["max_admissible_value"]

KeyError: 'x_design_info'

This is a serious issue, as it makes Greykite unusable when predictions need to be made with a model restored from disk.

samuelefiorini commented 1 year ago

I have found a solution to the issue. With the help of @andreaschiappacasse and his insightful intuition, we implemented a slightly modified version of the model dump & load process.

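This modification prevents the system from generating pickle files with excessively long names: each dictionary entry is now named after a hash of its key (prefixed with its position for an OrderedDict) instead of the key's full string representation. As a result, we can now dump models together with their design info and later load them to make predictions.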
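A minimal sketch of the naming idea using only the standard library. It is not the code from this thread: `short_name` is a hypothetical helper, and it swaps the built-in `hash()` used in the implementation for `hashlib.sha1` over `repr(key)`. Built-in string hashing is salted per process (PYTHONHASHSEED), so `hash()`-based names change from run to run; that does not break loading, since each key is pickled into its own `__key__` file, but a content digest keeps the artifact layout reproducible. It assumes distinct keys in one dictionary have distinct `repr`s.

```python
import hashlib
from collections import OrderedDict


def short_name(index, key, digest_chars=16):
    """Hypothetical helper: bounded-length, deterministic file-name stem for a dict entry.

    The ``{index}_`` prefix preserves insertion order; the digest only has to be
    unique within one directory, because the key itself is pickled separately.
    """
    digest = hashlib.sha1(repr(key).encode("utf-8")).hexdigest()[:digest_chars]
    return f"{index}_{digest}"


# A key shaped like the patsy factor name from the OSError (7 days x 24 hours).
levels = [f"{d}_{h:02d}" for d in range(1, 8) for h in range(24)]
long_key = f"C(Q('dow_hr'), levels={levels})"
factor_infos = OrderedDict([(long_key, "info for dow_hr"), ("regressor_tmin", "info for tmin")])

names = [short_name(i, key) for i, key in enumerate(factor_infos)]
for name in names:
    file_name = f"{name}__key__.pkl"
    print(file_name, len(file_name))  # 29 characters each, far below a 255-byte limit

# The numeric prefix alone is enough to restore the original key order.
restored = sorted(names, key=lambda n: int(n.split("_")[0]))
print(restored == names)
```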

Please refer to the implementation below for more details

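import os
import shutil
from collections import OrderedDict

import dill
from patsy.design_info import DesignInfo


def dump_obj(obj, dir_name, obj_name="obj", dump_design_info=True, overwrite_exist_dir=False, top_level=True):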
    """See `greykite.framework.templates.pickle_utils.dump_obj`."""
    # Checks if to dump design info.
    if (not dump_design_info) and (isinstance(obj, DesignInfo) or (isinstance(obj, str) and obj == "x_design_info")):
        return

    # Checks if directory already exists.
    if top_level:
        dir_already_exist = os.path.exists(dir_name)
        if dir_already_exist:
            if not overwrite_exist_dir:
                raise FileExistsError(
                    "The directory already exists. "
                    "Please either specify a new directory or "
                    "set overwrite_exist_dir to True to overwrite it."
                )
            else:
                if os.path.isdir(dir_name):
                    # dir exists as a directory.
                    shutil.rmtree(dir_name)
                else:
                    # dir exists as a file.
                    os.remove(dir_name)

    # Creates the directory.
    # Non-top-level calls may write to the same directory,
    # so we allow an existing directory in this case.
    try:
        os.mkdir(dir_name)
    except FileExistsError:
        pass

    # Start dumping recursively.
    try:
        # Attempts to directly dump the object.
        with open(os.path.join(dir_name, f"{obj_name}.pkl"), "wb") as f:
            dill.dump(obj, f)
    except NotImplementedError:
        # Direct dumping fails.
        # Removes the failed file.
        try:
            os.remove(os.path.join(dir_name, f"{obj_name}.pkl"))
        except FileNotFoundError:
            pass
        # Attempts to do recursive dumping depending on the object type.
        if isinstance(obj, OrderedDict):
            # For OrderedDict (there are a lot in `patsy.design_info.DesignInfo`),
            # recursively dumps the keys and values, because keys can be class instances
            # and unpicklable, too.
            # The keys and values have index number appended to the front,
            # so the order is kept.
            dill.dump("ordered_dict", open(os.path.join(dir_name, f"{obj_name}.type"), "wb"))  # type "ordered_dict"
            for i, (key, value) in enumerate(obj.items()):
                # name = str(key) # this is how it used to be in the "name too-long" version
                name = f"{i}_{str(hash(key))}"  # but we actually don't need to keep the key in name
                dump_obj(
                    key,
                    os.path.join(dir_name, obj_name),
                    f"{name}__key__",
                    dump_design_info=dump_design_info,
                    top_level=False,
                )
                dump_obj(
                    value,
                    os.path.join(dir_name, obj_name),
                    f"{name}__value__",
                    dump_design_info=dump_design_info,
                    top_level=False,
                )
        elif isinstance(obj, dict):
            # For regular dictionary,
            # recursively dumps the keys and values, because keys can be class instances
            # and unpicklable, too.
            # The order is not important.
            dill.dump("dict", open(os.path.join(dir_name, f"{obj_name}.type"), "wb"))  # type "dict"
            for key, value in obj.items():
                # name = str(key) # this is how it used to be in the "name too-long" version
                name = str(hash(key))  # but we actually don't need to keep the key in name
                dump_obj(
                    key,
                    os.path.join(dir_name, obj_name),
                    f"{name}__key__",
                    dump_design_info=dump_design_info,
                    top_level=False,
                )
                dump_obj(
                    value,
                    os.path.join(dir_name, obj_name),
                    f"{name}__value__",
                    dump_design_info=dump_design_info,
                    top_level=False,
                )
        elif isinstance(obj, (list, tuple)):
            # For list and tuples,
            # recursively dumps the elements.
            # The names have index number appended to the front,
            # so the order is kept.
            dill.dump(
                type(obj).__name__, open(os.path.join(dir_name, f"{obj_name}.type"), "wb")
            )  # type "list"/"tuple"
            for i, value in enumerate(obj):
                dump_obj(
                    value,
                    os.path.join(dir_name, obj_name),
                    f"{i}_key",
                    dump_design_info=dump_design_info,
                    top_level=False,
                )
        elif hasattr(obj, "__class__") and not isinstance(obj, type):
            # For class instance,
            # recursively dumps the attributes.
            dill.dump(obj.__class__, open(os.path.join(dir_name, f"{obj_name}.type"), "wb"))  # type is class itself
            for key, value in obj.__dict__.items():
                dump_obj(
                    value, os.path.join(dir_name, obj_name), key, dump_design_info=dump_design_info, top_level=False
                )
        else:
            # Other unrecognized unpicklable types, not common.
            print(f"I don't recognize type {type(obj)}")


def load_obj(dir_name, obj=None, load_design_info=True):
    """See `greykite.framework.templates.pickle_utils.load_obj`."""
    # Returns None for design info objects when they should not be loaded.
    if (not load_design_info) and (isinstance(obj, type) and obj == DesignInfo):
        return None

    # Gets file names in the level.
    files = os.listdir(dir_name)
    if not files:
        raise ValueError("dir is empty!")

    # Gets the type files if any.
    # Stores in a dictionary with key being the name and value being the loaded value.
    obj_types = {
        file.split(".")[0]: dill.load(open(os.path.join(dir_name, file), "rb")) for file in files if ".type" in file
    }

    # Gets directories and pickled files.
    # Every directory must have a .type file with the same name.
    directories = [file for file in files if os.path.isdir(os.path.join(dir_name, file))]
    if not all([directory in obj_types for directory in directories]):
        raise ValueError("type and directories do not match.")
    pickles = [file for file in files if ".pkl" in file]

    # Starts loading objects
    if obj is None:
        # obj is None indicates this is the top level directory.
        # This directory contains either a single .pkl file, or a single .type file plus the directory of the same name.
        if not obj_types:
            # Single .pkl file case.
            if len(files) > 1:
                raise ValueError("Multiple elements found in top level.")
            return dill.load(open(os.path.join(dir_name, files[0]), "rb"))
        else:
            # The .type + dir case.
            if len(obj_types) > 1:
                raise ValueError("Multiple elements found in top level")
            obj_name = list(obj_types.keys())[0]
            obj_type = obj_types[obj_name]
            return load_obj(os.path.join(dir_name, obj_name), obj_type, load_design_info=load_design_info)
    else:
        # If obj is not None, does recursive loading depending on the obj type.
        if obj in ("list", "tuple"):
            # Object is list or tuple.
            # Fetches each element according to the number index to preserve orders.
            result = []
            # Order index is a number appended to the front.
            elements = sorted(pickles + directories, key=lambda x: int(x.split("_")[0]))
            # Recursively loads elements.
            for element in elements:
                if ".pkl" in element:
                    result.append(dill.load(open(os.path.join(dir_name, element), "rb")))
                else:
                    result.append(
                        load_obj(
                            os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
                        )
                    )
            if obj == "tuple":
                result = tuple(result)
            return result
        elif obj == "dict":
            # Object is a dictionary.
            # Fetches keys and values recursively.
            result = {}
            elements = pickles + directories
            keys = [element for element in elements if "__key__" in element]
            values = [element for element in elements if "__value__" in element]
            # Iterates through keys and finds the corresponding values.
            for element in keys:
                if ".pkl" in element:
                    key = dill.load(open(os.path.join(dir_name, element), "rb"))
                else:
                    key = load_obj(
                        os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
                    )
                # Value name could be either with .pkl or a directory.
                value_name = element.replace("__key__", "__value__")
                if ".pkl" in value_name:
                    value_name_alt = value_name.replace(".pkl", "")
                else:
                    value_name_alt = value_name + ".pkl"
                # Checks if value name is in the dir.
                if (value_name not in values) and (value_name_alt not in values):
                    raise FileNotFoundError(f"Value not found for key {key}.")
                value_name = value_name if value_name in values else value_name_alt
                # Gets the value.
                if ".pkl" in value_name:
                    value = dill.load(open(os.path.join(dir_name, value_name), "rb"))
                else:
                    value = load_obj(
                        os.path.join(dir_name, value_name), obj_types[value_name], load_design_info=load_design_info
                    )
                # Sets the key, value pair.
                result[key] = value
            return result
        elif obj == "ordered_dict":
            # Object is OrderedDict.
            # Fetches keys and values according to the number index to preserve orders.
            result = OrderedDict()
            # Order index is a number appended to the front.
            elements = sorted(pickles + directories, key=lambda x: int(x.split("_")[0]))
            keys = [element for element in elements if "__key__" in element]
            values = [element for element in elements if "__value__" in element]
            # Iterates through keys and finds the corresponding values.
            for element in keys:
                if ".pkl" in element:
                    key = dill.load(open(os.path.join(dir_name, element), "rb"))
                else:
                    key = load_obj(
                        os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
                    )
                value_name = element.replace("__key__", "__value__")
                # Value name could be either with .pkl or a directory.
                if ".pkl" in value_name:
                    value_name_alt = value_name.replace(".pkl", "")
                else:
                    value_name_alt = value_name + ".pkl"
                # Checks if value name is in the dir.
                if (value_name not in values) and (value_name_alt not in values):
                    raise FileNotFoundError(f"Value not found for key {key}.")
                value_name = value_name if value_name in values else value_name_alt
                # Gets the value.
                if ".pkl" in value_name:
                    value = dill.load(open(os.path.join(dir_name, value_name), "rb"))
                else:
                    value = load_obj(
                        os.path.join(dir_name, value_name), obj_types[value_name], load_design_info=load_design_info
                    )
                # Sets the key, value pair.
                result[key] = value
            return result
        elif inspect.isclass(obj):
            # Object is a class instance.
            # Creates the class instance and sets the attributes.
            # Some class has required args during initialization,
            # these args are pulled from attributes.
            init_params = list(inspect.signature(obj.__init__).parameters)  # init args
            elements = pickles + directories
            # Gets the attribute names and their values in a dictionary.
            values = {}
            for element in elements:
                if ".pkl" in element:
                    values[element.split(".")[0]] = dill.load(open(os.path.join(dir_name, element), "rb"))
                else:
                    values[element] = load_obj(
                        os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
                    )
            # Gets the init args from values.
            init_dict = {key: value for key, value in values.items() if key in init_params}
            # Some attributes have a leading "_" (e.g. ``_foo`` stores the init arg ``foo``).
            init_dict.update(
                {key[1:]: value for key, value in values.items() if (key[1:] in init_params and key[0] == "_")}
            )
            # ``design_info`` does not have column_names attribute,
            # which is required during init.
            # The column_names param is pulled from the column_name_indexes attribute.
            # This can be omitted once we allow dumping @property attributes.
            if "column_names" in init_params:
                init_dict["column_names"] = values["column_name_indexes"].keys()
            # Creates the instance.
            result = obj(**init_dict)
            # Sets the attributes.
            for key, value in values.items():
                setattr(result, key, value)
            return result
        else:
            # Raises an error if the object is not recognized.
            # This typically does not happen when the source file is dumped
            # with the `dump_obj` function.
            raise ValueError(f"Object {obj} is not recognized.")
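
Why the `hash(key)` change works: the `dict` branch of `load_obj` pairs each `__key__` entry with its value only by swapping `__key__` for `__value__` in the name, so the prefix just has to be unique within the directory and never needs to encode the key itself. `str(hash(key))` satisfies this in practice (string hashes change between interpreter runs, which is harmless here because the key is read back from its own file), although two distinct keys with equal hashes would overwrite each other's files. The sketch below is a simplified, hypothetical illustration, not greykite's API: `dump_dict`/`load_dict` are made-up names, it uses the stdlib `pickle` instead of `dill`, and it assumes flat, picklable keys and values. It names files by enumeration index, which is collision-free and keeps names short.

```python
import os
import pickle
import tempfile


# Hypothetical helpers for illustration only; not part of greykite.
def dump_dict(d, dir_name):
    """Dumps each key/value pair of ``d`` to its own pickle file.

    File names use the enumeration index instead of ``str(key)``,
    so their length does not depend on how long the key's repr is.
    """
    os.makedirs(dir_name, exist_ok=True)
    for i, (key, value) in enumerate(d.items()):
        with open(os.path.join(dir_name, f"{i}__key__.pkl"), "wb") as f:
            pickle.dump(key, f)
        with open(os.path.join(dir_name, f"{i}__value__.pkl"), "wb") as f:
            pickle.dump(value, f)


def load_dict(dir_name):
    """Rebuilds the dict by pairing each ``__key__`` file with its ``__value__`` file."""
    result = {}
    for file in os.listdir(dir_name):
        if "__key__" not in file:
            continue
        with open(os.path.join(dir_name, file), "rb") as f:
            key = pickle.load(f)
        value_file = file.replace("__key__", "__value__")
        with open(os.path.join(dir_name, value_file), "rb") as f:
            result[key] = pickle.load(f)
    return result


# A key whose string form is far longer than the 255-byte file name limit common on Linux.
long_key = (
    "C(Q('dow_hr'), levels=["
    + ", ".join(f"'{d}_{h:02d}'" for d in range(1, 8) for h in range(24))
    + "])"
)
data = {long_key: [1, 2, 3], "short": 4.5}

with tempfile.TemporaryDirectory() as tmp:
    dump_dict(data, tmp)
    restored = load_dict(tmp)
    longest_name = max(len(name) for name in os.listdir(tmp))

print(restored == data, longest_name)  # prints "True 14"
```

The round trip is lossless and the longest file name stays at 14 characters, even though the key's string form is over 1,300 characters long.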

pjgaudre commented 11 months ago

Thank you for raising this issue and providing an alternative. My colleague @sayanpatra was working on model dumps this quarter. I'll ask him to take a look at it.