zaneselvans closed 2 years ago
sample_weight
The hyperparameter grid and sample weights are inserted at different steps in the process.
The hyperparameter space to explore is defined within GridSearchCV:
# Define hyperparameter space to search:
grid = GridSearchCV(
    pipe,
    param_grid={
        "hist_gbr__max_depth": [3, 5, 7],
        "hist_gbr__max_leaf_nodes": [7, 15, 31],
        "hist_gbr__learning_rate": [0.1, 0.3, 0.9],
    },
    cv=KFold(n_splits=4, shuffle=True, random_state=0),
    n_jobs=-1,
)
Model fitting parameters that are uniform across the whole search space, on the other hand, are passed in during the fit() step:
# Weight samples by `sample_weight` during training.
result = grid.fit(
    X=frc_train_test[cat_cols + num_cols],
    y=frc_target,
    hist_gbr__sample_weight=sample_weight,
)
plant_id_eia or other criteria
It looks like GroupShuffleSplit can be used to partition the test/train data on whatever ID column you want, to avoid overlap when necessary. In the example below, the gss_split generator swaps in for e.g. KFold:
gss = GroupShuffleSplit(test_size=0.2, n_splits=5, random_state=0)
# Split based on Plant ID to avoid plant information leaking between test / training
gss_split = gss.split(frc_train_test, groups=frc_train_test["plant_id_eia"])

param_grid = {
    "hist_gbr__max_depth": [3, 5, 7],
    "hist_gbr__max_leaf_nodes": [7, 15, 31],
    "hist_gbr__learning_rate": [0.1, 0.3, 0.9],
}
# Define hyperparameter space to search:
grid = GridSearchCV(pipe, param_grid=param_grid, cv=gss_split, n_jobs=-1)
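A quick way to convince yourself that the group-based split really keeps plants apart is to check each fold for overlapping IDs. This is only a sketch on made-up data: the ten integer "plant IDs" and the dummy feature matrix are assumptions, not real `plant_id_eia` values.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# 10 hypothetical plant IDs with 5 records each (stand-in for plant_id_eia).
groups = np.repeat(np.arange(10), 5)
X = np.arange(len(groups)).reshape(-1, 1)  # dummy feature matrix

gss = GroupShuffleSplit(test_size=0.2, n_splits=5, random_state=0)

# gss.split() returns a generator, which can only be consumed once.
# Materializing it as a list lets the same folds be inspected and reused.
splits = list(gss.split(X, groups=groups))

for train_idx, test_idx in splits:
    overlap = set(groups[train_idx]) & set(groups[test_idx])
    # No plant should appear on both sides of a split.
    assert not overlap
```

The resulting `splits` list can be passed as `cv=` in the same way as the generator, with the advantage that it is not exhausted after a single search.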
It's been a while since I played with scikit-learn, and I need to re-familiarize myself with the tools and methods for composing these moving parts together, in the context of the EIA-923 fuel price imputations. I'm sure some of this is super basic and I just don't understand how it works.
Tasks / Issues
- […] sample_weight and model hyper parameters into the cross validation.
- […] plant_id_eia from each other.
- […] test_score from GridSearchCV (outlier values were throwing this way off. Clipping the most egregious ones and using the median error made it much better behaved)
- […] frc_eia923, coalmine_eia923 and plants_entity_eia values.
- […] fuel_cost_per_mmbtu values in the frc_eia923 table. See #1712
- […] NA values in categorical columns to pass through the OrdinalEncoder so this HistGBR model can access them.