microsoft / LightGBM

A fast, distributed, high performance gradient boosting (GBT, GBDT, GBRT, GBM or MART) framework based on decision tree algorithms, used for ranking, classification and many other machine learning tasks.
https://lightgbm.readthedocs.io/en/latest/
MIT License

[python package] init_score doesn't appear to work #1000

Closed: j-mark-hou closed this 7 years ago

j-mark-hou commented 7 years ago

Setting the init_score in an lgb.Dataset object doesn't appear to change anything. I expect it to behave similarly to xgboost's set_base_margin.

In particular, if you train with two lgb.Dataset objects, one with set_init_score and one without, the results are identical. Tested with the 'regression_l2' and 'poisson' objectives.

It also seems strange that 'init_score' appears nowhere in the LightGBM/python-package/lightgbm/engine.py file.

import numpy as np
import lightgbm as lgb

# generate some data
n = 1000
d = 10

traincut = int(.8*n)

X = np.random.randint(10, size=(n, d))
y = X.sum(axis=1) + np.random.normal(size=n)

# hold out the last column to use as the init_score, so it carries real signal about y
Xtrain, ytrain, init_score_train = X[:traincut, :-1], y[:traincut], X[:traincut, -1]
Xval, yval, init_score_val = X[traincut:, :-1], y[traincut:], X[traincut:, -1]

print(Xtrain.shape, ytrain.shape, init_score_train.shape)
print(Xval.shape, yval.shape, init_score_val.shape)

params = {
    'num_threads':5,
    'metric':'rmse',
#     'objective':'poisson',
    'objective':'regression_l2',
    'learning_rate':.01,
}

# 1. don't set any init_score
dat_train = lgb.Dataset(Xtrain, ytrain)
dat_val = lgb.Dataset(Xval, yval)
gbm = lgb.train(params, 
                dat_train, 
                num_boost_round=7000, 
                early_stopping_rounds=100,
                valid_sets=[dat_train, dat_val], 
                verbose_eval=100)

# 2. try it with setting init_score
dat_train = lgb.Dataset(Xtrain, ytrain)
dat_train.set_init_score(init_score_train)
dat_val = lgb.Dataset(Xval, yval)
dat_val.set_init_score(init_score_val)
gbm = lgb.train(params, 
                dat_train, 
                num_boost_round=7000, 
                early_stopping_rounds=100,
                valid_sets=[dat_train, dat_val], 
                verbose_eval=100,
                )

# In both cases, the optimal number of iterations and the sequence of training/validation RMSEs are identical.
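
For comparison, here is the analogous xgboost workflow mentioned above, where setting the base margin does affect training. This is a minimal sketch reusing the arrays from the script above; it assumes xgboost's DMatrix.set_base_margin API and is not LightGBM code.

import xgboost as xgb

# build DMatrix objects and attach the same offsets via set_base_margin
dtrain_xgb = xgb.DMatrix(Xtrain, label=ytrain)
dtrain_xgb.set_base_margin(init_score_train)  # analogous to init_score
dval_xgb = xgb.DMatrix(Xval, label=yval)
dval_xgb.set_base_margin(init_score_val)

bst = xgb.train({'objective': 'reg:linear', 'eta': .01},  # 'reg:squarederror' in newer xgboost
                dtrain_xgb,
                num_boost_round=7000,
                evals=[(dtrain_xgb, 'train'), (dval_xgb, 'val')],
                early_stopping_rounds=100,
                verbose_eval=100)

# Unlike the LightGBM runs above, the evaluation trace here differs from a run
# without set_base_margin, which is the behavior I expect from init_score.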
guolinke commented 7 years ago

@wxchan any ideas about this? I feel like we need a test for this.

wxchan commented 7 years ago

@j-mark-hou @guolinke init_score is currently only used for prediction; it's not used in _lazy_init.

guolinke commented 7 years ago

@wxchan I think it can be enabled in _lazy_init.

j-mark-hou commented 7 years ago

Is there an example or explanation somewhere of how this feature is supposed to work? Or could you point me to the parts of the code where this is implemented?

olofer commented 7 years ago

Some objectives support an "auto init score" via the boost_from_average option. I think explicitly setting init_score should override the auto init but otherwise leave things as they are; I believe I suggested something along those lines in the code comments. It may suffice to have ObtainAutomaticInitialScore first look for an explicit init_score and otherwise proceed as before (sketched in Python after the link below)...

https://github.com/Microsoft/LightGBM/blob/1ef3d43ecd2068811273eb9226572aec694ee000/src/boosting/gbdt.cpp#L288-L336
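
A minimal Python rendering of the precedence described above; the real logic is C++ (see the gbdt.cpp lines linked), so the names here are illustrative assumptions, not LightGBM API.

def resolve_init_score(explicit_init_score, boost_from_average, label_mean):
    # an explicit, user-supplied score takes precedence
    if explicit_init_score is not None:
        return explicit_init_score
    # otherwise fall back to the "auto init score" when enabled
    if boost_from_average:
        return label_mean
    # no initial score at all
    return 0.0

# e.g. resolve_init_score(None, True, 4.5) -> 4.5
#      resolve_init_score(2.0, True, 4.5)  -> 2.0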

wxchan commented 7 years ago

@j-mark-hou It's not working at the moment. You can add self.init_score = None after https://github.com/Microsoft/LightGBM/blob/master/python-package/lightgbm/basic.py#L606 and self.set_init_score(self.init_score) after https://github.com/Microsoft/LightGBM/blob/master/python-package/lightgbm/basic.py#L695 to see if it works.
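
In outline, that suggestion would look something like the sketch below. The surrounding class is elided and assumed (the real Dataset lives in python-package/lightgbm/basic.py), so treat this as an illustration rather than the actual implementation.

class Dataset(object):
    def __init__(self, data, label=None, init_score=None):
        self.data = data
        self.label = label
        self.init_score = init_score  # suggested: remember init_score at construction

    def set_init_score(self, init_score):
        self.init_score = init_score  # the real version also forwards to the C API

    def _lazy_init(self, data, label=None):
        # ... build the underlying C dataset here ...
        if self.init_score is not None:
            # suggested: re-apply the stored init_score once the C dataset exists
            self.set_init_score(self.init_score)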

guolinke commented 7 years ago

Refer to #1007.

github-actions[bot] commented 1 year ago

This issue has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.