samuelefiorini opened this issue 1 year ago
The `KeyError` above is raised only on the first prediction attempt. See the following snippet for an ugly workaround.
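The essence of the workaround is simply to retry the call: the first `predict` raises `KeyError: 'x_design_info'`, the second one succeeds. A minimal sketch (the estimator and test frame are assumed to come from the full snippet below):

def predict_with_retry(estimator, test_df):
    """Calls predict twice, since only the first attempt raises KeyError: 'x_design_info'."""
    try:
        return estimator.predict(test_df)
    except KeyError:
        # As shown in the output further down, the second attempt goes through.
        return estimator.predict(test_df)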
import warnings
import pandas as pd
from greykite.framework.benchmark.data_loader_ts import DataLoaderTS
from greykite.framework.templates.autogen.forecast_config import (
EvaluationPeriodParam,
ForecastConfig,
MetadataParam,
ModelComponentsParam,
)
from greykite.framework.templates.forecaster import Forecaster
from greykite.framework.templates.model_templates import ModelTemplateEnum
from greykite.framework.utils.result_summary import summarize_grid_search_results
warnings.filterwarnings("ignore")
def prepare_bikesharing_data():
"""Loads bike-sharing data and adds proper regressors."""
dl = DataLoaderTS()
agg_func = {"count": "sum", "tmin": "mean", "tmax": "mean", "pn": "mean"}
df = dl.load_bikesharing(agg_freq="daily", agg_func=agg_func)
# There are some zero values which cause issue for MAPE
# This adds a small number to all data to avoid that issue
value_col = "count"
df[value_col] += 10
# We drop last value as data might be incorrect as original data is hourly
df.drop(df.tail(1).index, inplace=True)
# We only use data from 2018 for demonstration purposes (run time is shorter)
df = df.loc[df["ts"] > "2018-01-01"]
df.reset_index(drop=True, inplace=True)
print(f"\n df.tail(): \n {df.tail()}")
# Creates useful regressors from existing raw regressors
df["bin_pn"] = (df["pn"] > 5).map(float)
df["bin_heavy_pn"] = (df["pn"] > 20).map(float)
df.columns = [
"ts",
value_col,
"regressor_tmin",
"regressor_tmax",
"regressor_pn",
"regressor_bin_pn",
"regressor_bin_heavy_pn",
]
forecast_horizon = 7
train_df = df.copy()
test_df = df.tail(forecast_horizon).reset_index(drop=True)
# When using the pipeline (as done in the ``fit_forecast`` below),
# fitting and prediction are done in one step
# Therefore for demonstration purpose we remove the response values of last 7 days.
# This is needed because we are using regressors,
# and future regressor data must be augmented to ``df``.
# We mimic that by removal of the values of the response.
train_df.loc[(len(train_df) - forecast_horizon) : len(train_df), value_col] = None
print(f"train_df shape: \n {train_df.shape}")
print(f"test_df shape: \n {test_df.shape}")
print(f"train_df.tail(14): \n {train_df.tail(14)}")
print(f"test_df: \n {test_df}")
return {"train_df": train_df, "test_df": test_df}
def fit_forecast(df, time_col, value_col):
"""Fits a daily model for this use case.
The daily model is a generic silverkite model with regressors.
"""
meta_data_params = MetadataParam(
time_col=time_col,
value_col=value_col,
freq="D",
)
# Autoregression to be used in the function
autoregression = {
"autoreg_dict": {
"lag_dict": {"orders": [1, 2, 3]},
"agg_lag_dict": {"orders_list": [[7, 7 * 2, 7 * 3]], "interval_list": [(1, 7), (8, 7 * 2)]},
"series_na_fill_func": lambda s: s.bfill().ffill(),
},
"fast_simulation": True,
}
# Changepoints configuration
# The config includes changepoints both in trend and seasonality
changepoints = {
"changepoints_dict": {
"method": "auto",
"yearly_seasonality_order": 15,
"resample_freq": "2D",
"actual_changepoint_min_distance": "100D",
"potential_changepoint_distance": "50D",
"no_changepoint_distance_from_end": "50D",
},
"seasonality_changepoints_dict": {
"method": "auto",
"yearly_seasonality_order": 15,
"resample_freq": "2D",
"actual_changepoint_min_distance": "100D",
"potential_changepoint_distance": "50D",
"no_changepoint_distance_from_end": "50D",
},
}
regressor_cols = [
"regressor_tmin",
"regressor_bin_pn",
"regressor_bin_heavy_pn",
]
# Model parameters
model_components = ModelComponentsParam(
growth=dict(growth_term="linear"),
seasonality=dict(
yearly_seasonality=[15],
quarterly_seasonality=[False],
monthly_seasonality=[False],
weekly_seasonality=[True],
daily_seasonality=[False],
),
custom=dict(
fit_algorithm_dict=dict(fit_algorithm="ridge"), extra_pred_cols=None, normalize_method="statistical"
),
regressors=dict(regressor_cols=regressor_cols),
autoregression=autoregression,
uncertainty=dict(uncertainty_dict=None),
events=dict(holiday_lookup_countries=["US"]),
changepoints=changepoints,
)
# Evaluation is done on same ``forecast_horizon`` as desired for output
forecast_horizon = 7
evaluation_period_param = EvaluationPeriodParam(
test_horizon=None,
cv_horizon=forecast_horizon,
cv_min_train_periods=365 * 2,
cv_expanding_window=True,
cv_use_most_recent_splits=False,
cv_periods_between_splits=None,
cv_periods_between_train_test=0,
cv_max_splits=5,
)
# Runs the forecast model using "SILVERKITE" template
forecaster = Forecaster()
result = forecaster.run_forecast_config(
df=df,
config=ForecastConfig(
model_template=ModelTemplateEnum.SILVERKITE.name,
coverage=0.95,
forecast_horizon=forecast_horizon,
metadata_param=meta_data_params,
evaluation_period_param=evaluation_period_param,
model_components_param=model_components,
),
)
# Gets cross-validation results
grid_search = result.grid_search
cv_results = summarize_grid_search_results(grid_search=grid_search, decimals=2, cv_report_metrics=None)
cv_results = cv_results.transpose()
cv_results = pd.DataFrame(cv_results)
cv_results.columns = ["err_value"]
cv_results["err_name"] = cv_results.index
cv_results = cv_results.reset_index(drop=True)
cv_results = cv_results[["err_name", "err_value"]]
print(f"\n cv_results: \n {cv_results}")
return forecaster
data = prepare_bikesharing_data()
df = data["train_df"]
time_col = "ts"
value_col = "count"
forecaster = fit_forecast(df=df, time_col=time_col, value_col=value_col)
# Dump forecaster without design info
forecaster.dump_forecast_result(
"/tmp/artifacts", overwrite_exist_dir=True, dump_design_info=False, object_name="forecast_result"
)
# Load a new forecaster object
new_forecaster = Forecaster()
new_forecaster.load_forecast_result("/tmp/artifacts")
new_trained_estimator = new_forecaster.forecast_result.model[-1]
try:
y_pred = new_trained_estimator.predict(data["test_df"])
except KeyError:
print("Second time's the charm")
y_pred = new_trained_estimator.predict(data["test_df"])
print(y_pred)
which produces the following output
df.tail():
ts count tmin tmax pn
602 2019-08-27 12216 17.2 26.7 0.0
603 2019-08-28 11401 18.3 27.8 0.0
604 2019-08-29 12685 16.7 28.9 0.0
605 2019-08-30 12097 14.4 32.8 0.0
606 2019-08-31 11281 17.8 31.1 0.0
train_df shape:
(607, 7)
test_df shape:
(7, 7)
train_df.tail(14):
ts count regressor_tmin regressor_tmax regressor_pn regressor_bin_pn regressor_bin_heavy_pn
593 2019-08-18 9655.0 22.2 35.6 0.3 0.0 0.0
594 2019-08-19 10579.0 21.1 37.2 0.0 0.0 0.0
595 2019-08-20 8898.0 22.2 36.1 0.0 0.0 0.0
596 2019-08-21 11648.0 21.7 35.0 1.8 0.0 0.0
597 2019-08-22 11724.0 21.7 35.0 30.7 1.0 1.0
598 2019-08-23 8158.0 17.8 23.3 1.8 0.0 0.0
599 2019-08-24 12475.0 16.7 26.1 0.0 0.0 0.0
600 2019-08-25 NaN 15.6 26.7 0.0 0.0 0.0
601 2019-08-26 NaN 17.2 25.0 0.0 0.0 0.0
602 2019-08-27 NaN 17.2 26.7 0.0 0.0 0.0
603 2019-08-28 NaN 18.3 27.8 0.0 0.0 0.0
604 2019-08-29 NaN 16.7 28.9 0.0 0.0 0.0
605 2019-08-30 NaN 14.4 32.8 0.0 0.0 0.0
606 2019-08-31 NaN 17.8 31.1 0.0 0.0 0.0
test_df:
ts count regressor_tmin regressor_tmax regressor_pn regressor_bin_pn regressor_bin_heavy_pn
0 2019-08-25 11634 15.6 26.7 0.0 0.0 0.0
1 2019-08-26 11747 17.2 25.0 0.0 0.0 0.0
2 2019-08-27 12216 17.2 26.7 0.0 0.0 0.0
3 2019-08-28 11401 18.3 27.8 0.0 0.0 0.0
4 2019-08-29 12685 16.7 28.9 0.0 0.0 0.0
5 2019-08-30 12097 14.4 32.8 0.0 0.0 0.0
6 2019-08-31 11281 17.8 31.1 0.0 0.0 0.0
Fitting 1 folds for each of 1 candidates, totalling 1 fits
cv_results:
err_name err_value
0 rank_test_MAPE 1
1 mean_test_MAPE 10.28
2 split_test_MAPE (10.28,)
3 mean_train_MAPE 21.75
4 params []
5 param_estimator__yearly_seasonality 15
6 param_estimator__weekly_seasonality True
7 param_estimator__uncertainty_dict None
8 param_estimator__training_fraction None
9 param_estimator__train_test_thresh None
10 param_estimator__time_properties {'period': 86400, 'simple_freq': SimpleTimeFre...
11 param_estimator__simulation_num 10
12 param_estimator__seasonality_changepoints_dict {'method': 'auto', 'yearly_seasonality_order':...
13 param_estimator__remove_intercept False
14 param_estimator__regressor_cols [regressor_tmin, regressor_bin_pn, regressor_b...
15 param_estimator__regression_weight_col None
16 param_estimator__quarterly_seasonality False
17 param_estimator__origin_for_time_vars None
18 param_estimator__normalize_method statistical
19 param_estimator__monthly_seasonality False
20 param_estimator__min_admissible_value None
21 param_estimator__max_weekly_seas_interaction_o... 2
22 param_estimator__max_daily_seas_interaction_order 5
23 param_estimator__max_admissible_value None
24 param_estimator__lagged_regressor_dict None
25 param_estimator__holidays_to_model_separately auto
26 param_estimator__holiday_pre_post_num_dict None
27 param_estimator__holiday_pre_num_days 2
28 param_estimator__holiday_post_num_days 2
29 param_estimator__holiday_lookup_countries [US]
30 param_estimator__growth_term linear
31 param_estimator__fit_algorithm_dict {'fit_algorithm': 'ridge'}
32 param_estimator__feature_sets_enabled auto
33 param_estimator__fast_simulation True
34 param_estimator__extra_pred_cols None
35 param_estimator__explicit_pred_cols None
36 param_estimator__drop_pred_cols None
37 param_estimator__daily_seasonality False
38 param_estimator__daily_event_shifted_effect None
39 param_estimator__daily_event_neighbor_impact None
40 param_estimator__daily_event_df_dict None
41 param_estimator__changepoints_dict {'method': 'auto', 'yearly_seasonality_order':...
42 param_estimator__autoreg_dict {'lag_dict': {'orders': [1, 2, 3]}, 'agg_lag_d...
43 param_estimator__auto_seasonality False
44 param_estimator__auto_holiday False
45 param_estimator__auto_growth False
46 split_train_MAPE (21.75,)
47 mean_fit_time 25.47
48 std_fit_time 0.0
49 mean_score_time 17.96
50 std_score_time 0.0
51 split0_test_MAPE 10.28
52 std_test_MAPE 0.0
53 split0_train_MAPE 21.75
54 std_train_MAPE 0.0
Second time's the charm
ts forecast quantile_summary err_std forecast_lower forecast_upper
0 2018-01-02 4867.991614 (1236.6364564500004, 8499.346772353505) 1852.766268 1236.636456 8499.346772
1 2018-01-03 5538.375314 (2582.389599005536, 8494.361028039748) 1508.183690 2582.389599 8494.361028
2 2018-01-04 5328.133261 (2085.659451338936, 8570.607070004411) 1654.353771 2085.659451 8570.607070
3 2018-01-05 4798.627174 (731.2554335326386, 8865.998913687017) 2075.227796 731.255434 8865.998914
4 2018-01-06 4300.328558 (-608.899093611637, 9209.556209578057) 2504.754011 -608.899094 9209.556210
.. ... ... ... ... ... ...
602 2019-08-27 11584.776481 (8293.792027862797, 14875.76093427293) 1679.104555 8293.792028 14875.760934
603 2019-08-28 11975.071674 (9294.749310736905, 14655.394038244023) 1367.536539 9294.749311 14655.394038
604 2019-08-29 11731.381551 (8433.28073238897, 15029.482369719592) 1682.735420 8433.280732 15029.482370
605 2019-08-30 11166.335793 (7029.302037145629, 15303.36954967026) 2110.770294 7029.302037 15303.369550
606 2019-08-31 10118.891291 (3968.9981066371283, 16268.784475116765) 3137.758261 3968.998107 16268.784475
[607 rows x 6 columns]
To reproduce the error above, see the following minimal script.
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from greykite.common.data_loader import DataLoader
from greykite.framework.templates.autogen.forecast_config import ForecastConfig
from greykite.framework.templates.autogen.forecast_config import MetadataParam
from greykite.framework.templates.autogen.forecast_config import ModelComponentsParam
from greykite.framework.templates.forecaster import Forecaster
from greykite.framework.templates.model_templates import ModelTemplateEnum
pd.options.plotting.backend = 'plotly'
# Defines inputs
df = DataLoader().load_bikesharing().tail(24*90) # Input time series (pandas.DataFrame)
df['ts'] = pd.to_datetime(df['ts'])
forecast_horizon = 24*2
df.loc[df.index[-forecast_horizon:], 'count'] = np.nan
print(df.head())
print(df.tail())
config = ForecastConfig(
metadata_param=MetadataParam(time_col="ts", value_col="count"), # Column names in `df_train`
model_template=ModelTemplateEnum.AUTO.name, # AUTO model configuration
forecast_horizon=forecast_horizon, # Forecasts all the missing steps
model_components_param=ModelComponentsParam(regressors={"regressor_cols": ['tmin', 'tmax', 'pn']}),
coverage=0.95, # 95% prediction intervals
)
# Creates forecasts
forecaster = Forecaster()
result = forecaster.run_forecast_config(df=df, config=config)
forecaster.dump_forecast_result(
'/tmp/forecaster',
object_name="object",
dump_design_info=False,
overwrite_exist_dir=True
)
# Recreate Forecaster
new_forecaster = Forecaster()
new_forecaster.load_forecast_result(
'/tmp/forecaster',
load_design_info=False
)
new_result = new_forecaster.forecast_result
new_pred = new_result.model.predict(df.rename(columns={'count': 'y'}))
print(new_pred)
which gives the following output
date ts count tmin tmax pn
76261 2019-06-03 2019-06-03 01:00:00 35.0 11.7 23.3 0.0
76262 2019-06-03 2019-06-03 02:00:00 20.0 11.7 23.3 0.0
76263 2019-06-03 2019-06-03 03:00:00 9.0 11.7 23.3 0.0
76264 2019-06-03 2019-06-03 04:00:00 14.0 11.7 23.3 0.0
76265 2019-06-03 2019-06-03 05:00:00 37.0 11.7 23.3 0.0
date ts count tmin tmax pn
78416 2019-08-31 2019-08-31 20:00:00 NaN 17.8 31.1 0.0
78417 2019-08-31 2019-08-31 21:00:00 NaN 17.8 31.1 0.0
78418 2019-08-31 2019-08-31 22:00:00 NaN 17.8 31.1 0.0
78419 2019-08-31 2019-08-31 23:00:00 NaN 17.8 31.1 0.0
78420 2019-09-01 2019-09-01 00:00:00 NaN 21.1 28.3 0.0
Fitting 3 folds for each of 1 candidates, totalling 3 fits
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-15-06ddd6580090> in <module>
----> 1 new_pred = new_result.model.predict(df.rename(columns={'count': 'y'}))
/opt/conda/lib/python3.7/site-packages/sklearn/utils/metaestimators.py in <lambda>(*args, **kwargs)
111
112 # lambda, but not partial, allows help() to work with update_wrapper
--> 113 out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs) # noqa
114 else:
115
/opt/conda/lib/python3.7/site-packages/sklearn/pipeline.py in predict(self, X, **predict_params)
468 for _, name, transform in self._iter(with_final=False):
469 Xt = transform.transform(Xt)
--> 470 return self.steps[-1][1].predict(Xt, **predict_params)
471
472 @available_if(_final_estimator_has("fit_predict"))
/opt/conda/lib/python3.7/site-packages/greykite/sklearn/estimator/base_silverkite_estimator.py in predict(self, X, y)
365 trained_model=self.model_dict,
366 past_df=self.past_df,
--> 367 new_external_regressor_df=None) # regressors are included in X
368 pred_df = pred_res["fut_df"]
369 x_mat = pred_res["x_mat"]
/opt/conda/lib/python3.7/site-packages/greykite/algo/forecast/silverkite/forecast_silverkite.py in predict(self, fut_df, trained_model, freq, past_df, new_external_regressor_df, include_err, force_no_sim, simulation_num, fast_simulation, na_fill_func)
2266 new_external_regressor_df=None,
2267 time_features_ready=False,
-> 2268 regressors_ready=True)
2269 fut_df0 = pred_res["fut_df"]
2270 x_mat0 = pred_res["x_mat"]
/opt/conda/lib/python3.7/site-packages/greykite/algo/forecast/silverkite/forecast_silverkite.py in predict_no_sim(self, fut_df, trained_model, past_df, new_external_regressor_df, time_features_ready, regressors_ready)
1226 pred_res = predict_ml_with_uncertainty(
1227 fut_df=features_df_fut,
-> 1228 trained_model=trained_model)
1229 fut_df = pred_res["fut_df"]
1230 x_mat = pred_res["x_mat"]
/opt/conda/lib/python3.7/site-packages/greykite/algo/common/ml_models.py in predict_ml_with_uncertainty(fut_df, trained_model)
558 pred_res = predict_ml(
559 fut_df=fut_df,
--> 560 trained_model=trained_model)
561
562 y_pred = pred_res["fut_df"][y_col]
/opt/conda/lib/python3.7/site-packages/greykite/algo/common/ml_models.py in predict_ml(fut_df, trained_model)
501 y_col = trained_model["y_col"]
502 ml_model = trained_model["ml_model"]
--> 503 x_design_info = trained_model["x_design_info"]
504 min_admissible_value = trained_model["min_admissible_value"]
505 max_admissible_value = trained_model["max_admissible_value"]
KeyError: 'x_design_info'
This is a serious issue, as it makes Greykite unusable when predictions need to be made with a model restored from disk.
I have discovered a solution to this issue. With the help of @andreaschiappacasse and his insightful intuition, we have implemented a slightly modified version of the model dump & load process.
The modification prevents the generation of pickle files with excessively long names. As a result, we can now dump models together with the design info matrix and load them back to perform predictions.
Please refer to the implementation below for more details; a short usage sketch follows it.
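For context, a toy illustration of why file names could get too long; the key string is made up and only mimics the kind of patsy terms found in the design info, it is not actual Greykite internals:

# Illustrative only: design-info keys can stringify to very long formula terms.
key = "C(Q('events_US'), levels=['New Years Day', 'Christmas Day', 'Independence Day'])"
old_name = f"{str(key)}__key__.pkl"            # grows with the term and may exceed OS file-name limits
new_name = f"{0}_{str(hash(key))}__key__.pkl"  # short, fixed-size name derived from the key's hash
print(len(old_name), len(new_name))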
# Imports needed by the helpers below.
import inspect
import os
import shutil
from collections import OrderedDict

import dill
from patsy.design_info import DesignInfo


def dump_obj(obj, dir_name, obj_name="obj", dump_design_info=True, overwrite_exist_dir=False, top_level=True):
"""See `greykite.framework.templates.pickle_utils.dump_obj`."""
# Checks if to dump design info.
if (not dump_design_info) and (isinstance(obj, DesignInfo) or (isinstance(obj, str) and obj == "x_design_info")):
return
# Checks if directory already exists.
if top_level:
dir_already_exist = os.path.exists(dir_name)
if dir_already_exist:
if not overwrite_exist_dir:
raise FileExistsError(
"The directory already exists. "
"Please either specify a new directory or "
"set overwrite_exist_dir to True to overwrite it."
)
else:
if os.path.isdir(dir_name):
# dir exists as a directory.
shutil.rmtree(dir_name)
else:
# dir exists as a file.
os.remove(dir_name)
# Creates the directory.
# None top-level may write to the same directory,
# so we allow existing directory in this case.
try:
os.mkdir(dir_name)
except FileExistsError:
pass
# Start dumping recursively.
try:
# Attempts to directly dump the object.
dill.dump(obj, open(os.path.join(dir_name, f"{obj_name}.pkl"), "wb"))
except NotImplementedError:
# Direct dumping fails.
# Removed the failed file.
try:
os.remove(os.path.join(dir_name, f"{obj_name}.pkl"))
except FileNotFoundError:
pass
# Attempts to do recursive dumping depending on the object type.
if isinstance(obj, OrderedDict):
            # For OrderedDict (there are a lot in `patsy.design_info.DesignInfo`),
# recursively dumps the keys and values, because keys can be class instances
# and unpicklable, too.
# The keys and values have index number appended to the front,
# so the order is kept.
dill.dump("ordered_dict", open(os.path.join(dir_name, f"{obj_name}.type"), "wb")) # type "ordered_dict"
for i, (key, value) in enumerate(obj.items()):
# name = str(key) # this is how it used to be in the "name too-long" version
name = f"{i}_{str(hash(key))}" # but we actually don't need to keep the key in name
dump_obj(
key,
os.path.join(dir_name, obj_name),
f"{name}__key__",
dump_design_info=dump_design_info,
top_level=False,
)
dump_obj(
value,
os.path.join(dir_name, obj_name),
f"{name}__value__",
dump_design_info=dump_design_info,
top_level=False,
)
elif isinstance(obj, dict):
# For regular dictionary,
# recursively dumps the keys and values, because keys can be class instances
# and unpicklable, too.
# The order is not important.
dill.dump("dict", open(os.path.join(dir_name, f"{obj_name}.type"), "wb")) # type "dict"
for key, value in obj.items():
# name = str(key) # this is how it used to be in the "name too-long" version
name = str(hash(key)) # but we actually don't need to keep the key in name
dump_obj(
key,
os.path.join(dir_name, obj_name),
f"{name}__key__",
dump_design_info=dump_design_info,
top_level=False,
)
dump_obj(
value,
os.path.join(dir_name, obj_name),
f"{name}__value__",
dump_design_info=dump_design_info,
top_level=False,
)
elif isinstance(obj, (list, tuple)):
# For list and tuples,
# recursively dumps the elements.
# The names have index number appended to the front,
# so the order is kept.
dill.dump(
type(obj).__name__, open(os.path.join(dir_name, f"{obj_name}.type"), "wb")
) # type "list"/"tuple"
for i, value in enumerate(obj):
dump_obj(
value,
os.path.join(dir_name, obj_name),
f"{i}_key",
dump_design_info=dump_design_info,
top_level=False,
)
elif hasattr(obj, "__class__") and not isinstance(obj, type):
# For class instance,
# recursively dumps the attributes.
dill.dump(obj.__class__, open(os.path.join(dir_name, f"{obj_name}.type"), "wb")) # type is class itself
for key, value in obj.__dict__.items():
dump_obj(
value, os.path.join(dir_name, obj_name), key, dump_design_info=dump_design_info, top_level=False
)
else:
# Other unrecognized unpicklable types, not common.
print(f"I Don't recognize type {type(obj)}")
def load_obj(dir_name, obj=None, load_design_info=True):
"""See `greykite.framework.templates.pickle_utils.load_obj`."""
# Checks if to load design info.
if (not load_design_info) and (isinstance(obj, type) and obj == DesignInfo):
return None
# Gets file names in the level.
files = os.listdir(dir_name)
if not files:
raise ValueError("dir is empty!")
# Gets the type files if any.
# Stores in a dictionary with key being the name and value being the loaded value.
obj_types = {
file.split(".")[0]: dill.load(open(os.path.join(dir_name, file), "rb")) for file in files if ".type" in file
}
# Gets directories and pickled files.
# Every type must have a directory with the same name.
directories = [file for file in files if os.path.isdir(os.path.join(dir_name, file))]
if not all([directory in obj_types for directory in directories]):
raise ValueError("type and directories do not match.")
pickles = [file for file in files if ".pkl" in file]
# Starts loading objects
if obj is None:
# obj is None indicates this is the top level directory.
# This directory can either have 1 .pkl file, or 1 .type file associated with the directory of same name.
if not obj_types:
# The only 1 .pkl file case.
if len(files) > 1:
raise ValueError("Multiple elements found in top level.")
return dill.load(open(os.path.join(dir_name, files[0]), "rb"))
else:
# The .type + dir case.
if len(obj_types) > 1:
raise ValueError("Multiple elements found in top level")
obj_name = list(obj_types.keys())[0]
obj_type = obj_types[obj_name]
return load_obj(os.path.join(dir_name, obj_name), obj_type, load_design_info=load_design_info)
else:
# If obj is not None, does recursive loading depending on the obj type.
if obj in ("list", "tuple"):
# Object is list or tuple.
# Fetches each element according to the number index to preserve orders.
result = []
# Order index is a number appended to the front.
elements = sorted(pickles + directories, key=lambda x: int(x.split("_")[0]))
# Recursively loads elements.
for element in elements:
if ".pkl" in element:
result.append(dill.load(open(os.path.join(dir_name, element), "rb")))
else:
result.append(
load_obj(
os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
)
)
if obj == "tuple":
result = tuple(result)
return result
elif obj == "dict":
# Object is a dictionary.
# Fetches keys and values recursively.
result = {}
elements = pickles + directories
keys = [element for element in elements if "__key__" in element]
values = [element for element in elements if "__value__" in element]
# Iterates through keys and finds the corresponding values.
for element in keys:
if ".pkl" in element:
key = dill.load(open(os.path.join(dir_name, element), "rb"))
else:
key = load_obj(
os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
)
# Value name could be either with .pkl or a directory.
value_name = element.replace("__key__", "__value__")
if ".pkl" in value_name:
value_name_alt = value_name.replace(".pkl", "")
else:
value_name_alt = value_name + ".pkl"
# Checks if value name is in the dir.
if (value_name not in values) and (value_name_alt not in values):
raise FileNotFoundError(f"Value not found for key {key}.")
value_name = value_name if value_name in values else value_name_alt
# Gets the value.
if ".pkl" in value_name:
value = dill.load(open(os.path.join(dir_name, value_name), "rb"))
else:
value = load_obj(
os.path.join(dir_name, value_name), obj_types[value_name], load_design_info=load_design_info
)
# Sets the key, value pair.
result[key] = value
return result
elif obj == "ordered_dict":
# Object is OrderedDict.
# Fetches keys and values according to the number index to preserve orders.
result = OrderedDict()
# Order index is a number appended to the front.
elements = sorted(pickles + directories, key=lambda x: int(x.split("_")[0]))
# elements = pickles + directories
keys = [element for element in elements if "__key__" in element]
values = [element for element in elements if "__value__" in element]
# Iterates through keys and finds the corresponding values.
for element in keys:
if ".pkl" in element:
key = dill.load(open(os.path.join(dir_name, element), "rb"))
else:
key = load_obj(
os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
)
value_name = element.replace("__key__", "__value__")
# Value name could be either with .pkl or a directory.
if ".pkl" in value_name:
value_name_alt = value_name.replace(".pkl", "")
else:
value_name_alt = value_name + ".pkl"
# Checks if value name is in the dir.
if (value_name not in values) and (value_name_alt not in values):
raise FileNotFoundError(f"Value not found for key {key}.")
value_name = value_name if value_name in values else value_name_alt
# Gets the value.
if ".pkl" in value_name:
value = dill.load(open(os.path.join(dir_name, value_name), "rb"))
else:
value = load_obj(
os.path.join(dir_name, value_name), obj_types[value_name], load_design_info=load_design_info
)
# Sets the key, value pair.
result[key] = value
return result
elif inspect.isclass(obj):
# Object is a class instance.
# Creates the class instance and sets the attributes.
# Some class has required args during initialization,
# these args are pulled from attributes.
init_params = list(inspect.signature(obj.__init__).parameters) # init args
elements = pickles + directories
# Gets the attribute names and their values in a dictionary.
values = {}
for element in elements:
if ".pkl" in element:
values[element.split(".")[0]] = dill.load(open(os.path.join(dir_name, element), "rb"))
else:
values[element] = load_obj(
os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
)
# Gets the init args from values.
init_dict = {key: value for key, value in values.items() if key in init_params}
# Some attributes has a "_" at the beginning.
init_dict.update(
{key[1:]: value for key, value in values.items() if (key[1:] in init_params and key[0] == "_")}
)
# ``design_info`` does not have column_names attribute,
# which is required during init.
# The column_names param is pulled from the column_name_indexes attribute.
# This can be omitted once we allow dumping @property attributes.
if "column_names" in init_params:
init_dict["column_names"] = values["column_name_indexes"].keys()
# Creates the instance.
result = obj(**init_dict)
# Sets the attributes.
for key, value in values.items():
setattr(result, key, value)
return result
else:
# Raises an error if the object is not recognized.
# This typically does not happen when the source file is dumped
# with the `dump_obj` function.
raise ValueError(f"Object {obj} is not recognized.")
Thank you for raising this issue and providing an alternative. My colleague @sayanpatra was working on model dumps this quarter. I'll ask him to take a look at it.
According to the documentation
However, trying to make predictions on new data using a model restored via `forecaster.load_forecast_result(path, load_design_info=False)` leads to the following error. This makes complete sense when looking at `greykite.algo.common.ml_models.predict_ml`, as the variable `x_design_info` is used by `patsy` to build the design matrix (see here).
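To illustrate, outside of Greykite, why the stored design info is required at prediction time, here is a small self-contained patsy sketch; the formula and data are made up:

import pandas as pd
from patsy import build_design_matrices, dmatrix

train = pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0]})
design = dmatrix("x + I(x**2)", train)   # design matrix built at fit time
design_info = design.design_info         # the object Greykite stores as `x_design_info`

future = pd.DataFrame({"x": [5.0, 6.0]})
# Prediction-time features are rebuilt from the stored DesignInfo;
# without it, the design matrix for new data cannot be reconstructed.
(x_mat,) = build_design_matrices([design_info], future)
print(x_mat)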
On the other hand, dumping `design_info` does not only mean dealing with a larger artifact; it may be outright impossible due to system limits on the length of the generated file names. As an example, this is what happens in my case.
Any ideas on how to work around this issue?