Open hardyshoo opened 4 weeks ago
Hey @hardyshoo, thanks for using LightGBM and for the reproducible example.
When early stopping happens the model is trimmed to the best iteration, so in your example if you run lgb.booster_.current_iteration() you'll see that it matches lgb.best_iteration_, which means that it can only predict up to that number of iterations.
Please let us know if you have further doubts.
Hi @jmoralez, thank you for the response. Can I ask why this is the case? Given that the model has to be trained up to best_iteration_ plus the number of specified early stopping rounds anyway, I assume the extra iterations are removed before the booster object is returned. I can't think of any drawback to keeping those iterations, and I think having the option to use them if desired would be useful.
The trim happens here: https://github.com/microsoft/LightGBM/blob/5151fe85f08e5dccff7d48242dddace51f9c8ede/python-package/lightgbm/engine.py#L349-L350. It is controlled by the keep_training_booster argument, which isn't available in the scikit-learn API, but if you can use the training API you can keep those iterations, e.g.
from lightgbm import Dataset, early_stopping, train
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import log_loss

X, y = load_breast_cancer(return_X_y=True)
train_ds = Dataset(X[:400], y[:400])
valid_ds = train_ds.create_valid(X[400:], y[400:])

bst = train(
    {'objective': 'binary'},
    train_ds,
    num_boost_round=1000,
    valid_sets=[valid_ds],
    callbacks=[early_stopping(10)],
    keep_training_booster=True,  # skip the trim to best_iteration
)

# Validation log loss for a few iterations on either side of the best one.
for i in range(bst.best_iteration - 5, bst.best_iteration + 5):
    print(log_loss(y[400:], bst.predict(X[400:], num_iteration=i)))
Description
When an LGBMClassifier model is stopped by early stopping and predict_proba is then called with a 'num_iteration' value greater than the model's 'best_iteration_' attribute, the method appears to still use the number of iterations given by 'best_iteration_' rather than the number passed in the parameter.
Reproducible example
Environment info
LightGBM version 4.5