microsoft / LightGBM

A fast, distributed, high performance gradient boosting (GBT, GBDT, GBRT, GBM or MART) framework based on decision tree algorithms, used for ranking, classification and many other machine learning tasks.
https://lightgbm.readthedocs.io/en/latest/
MIT License

decay_rate parameter in booster.refit has no effect #2253

Closed: qjcstc0321 closed this issue 5 years ago

qjcstc0321 commented 5 years ago

I noticed that the decay_rate parameter of lightgbm.Booster.refit is never used in the source code. In my tests, changing this parameter also had no effect on the final predictions. Is the parameter broken, or am I using it incorrectly?
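For context, the LightGBM docs describe decay_rate as blending the old leaf outputs with the newly refitted ones, so different values should produce different predictions. A minimal sketch of that intended update (an illustration of the documented formula, not LightGBM's actual implementation):

```python
# Intended per-leaf update for Booster.refit, per the LightGBM docs:
#   leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output
def refit_leaf_output(old_leaf_output, new_leaf_output, decay_rate):
    # decay_rate=1.0 would keep the old model unchanged;
    # decay_rate=0.0 would fully replace it with the refitted values
    return decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output
```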

Environment info

Operating System:
CPU/GPU model:
C++/Python/R version: Python 3.6
LightGBM version or commit hash: 2.2.3

Reproducible examples

```python
import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer

# load dataset
X = load_breast_cancer().data
y = load_breast_cancer().target
feature_names = list(load_breast_cancer().feature_names)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=100)
train_data = lgb.Dataset(X_train, y_train, feature_name=feature_names)
test_data = lgb.Dataset(X_test, y_test, feature_name=feature_names)

# train and update model
params = {'learning_rate': 0.02, 'objective': 'binary'}
params.update({'bagging_fraction': 0.9, 'feature_fraction': 0.7, 'bagging_freq': 1})
params.update({'max_depth': 5, 'min_data_in_leaf': 10})
params.update({'lambda_l2': 0.3})
params.update({'metric': 'auc', 'use_missing': False})
clf1 = lgb.train(params, train_set=train_data, valid_sets=(train_data, test_data),
                 valid_names=['ins', 'oos'], num_boost_round=15, verbose_eval=1)
clf2 = clf1.refit(X_test, y_test, decay_rate=0.99)
clf3 = clf1.refit(X_test, y_test, decay_rate=0.5)

# inference
score1 = clf1.predict(X_test)
score2 = clf2.predict(X_test)
score3 = clf3.predict(X_test)
print(score1.mean())  # output 0.6342387193103918
print(score2.mean())  # output 0.6256702147069053
print(score3.mean())  # output 0.6256702147069053
```

guolinke commented 5 years ago

This is indeed a silly bug. It is fixed in #2254, please give it a try.
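A quick way to verify the fix, reusing the names from the repro above (a hypothetical sanity check, assuming a build that includes #2254): with decay_rate actually applied, the two refits should no longer produce identical predictions.

```python
# Hypothetical check after upgrading to a build with #2254:
# refits with decay_rate=0.99 and decay_rate=0.5 should now diverge.
import numpy as np

score2 = clf1.refit(X_test, y_test, decay_rate=0.99).predict(X_test)
score3 = clf1.refit(X_test, y_test, decay_rate=0.5).predict(X_test)
assert not np.allclose(score2, score3), "decay_rate still has no effect"
```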