dask / dask-xgboost


Add evals result #60

Closed kylejn27 closed 4 years ago

kylejn27 commented 4 years ago

What

Closes https://github.com/dask/dask-xgboost/issues/59

Adds a local_evals dictionary to each worker's xgb.train call. When all workers return and their results are aggregated, the evals_result dict passed into dask_xgboost.train is updated with the resulting evaluation history.
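
For illustration, here is a minimal sketch of the per-worker side, assuming each worker calls xgb.train with a fresh dict; train_on_worker and the final update step are illustrative, not the actual dask-xgboost internals:

import xgboost as xgb

def train_on_worker(params, dtrain, evals, num_boost_round):
    # Each worker records its own history through xgb.train's
    # evals_result keyword, then returns it alongside the booster.
    local_evals = {}
    booster = xgb.train(
        params,
        dtrain,
        num_boost_round=num_boost_round,
        evals=evals,
        evals_result=local_evals,
    )
    return booster, local_evals

# Back on the driver, the caller-supplied dict is then updated with the
# aggregated history, roughly: evals_result.update(local_evals)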

Why

After training a model, it is often useful to inspect the value of the evaluation metric at each boosting iteration. This feature already exists in dmlc/xgboost and would be nice to have in dask-xgboost.
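
For context, this is the behavior in dmlc/xgboost that the PR mirrors: xgb.train fills a caller-supplied dict with one metric value per boosting round (the toy data below is made up for illustration):

import numpy as np
import xgboost as xgb

X = np.random.rand(20, 3)
y = np.random.randint(0, 2, 20)
dtrain = xgb.DMatrix(X, label=y)

evals_result = {}
xgb.train(
    {"objective": "reg:squarederror", "eval_metric": "rmse"},
    dtrain,
    num_boost_round=10,
    evals=[(dtrain, "validation_0")],
    evals_result=evals_result,
)
print(evals_result["validation_0"]["rmse"])  # ten RMSE values, one per round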

Test

import dask.array as da
import pandas as pd
from dask.distributed import Client, LocalCluster

import dask_xgboost as dxgb
import xgboost as xgb

df = pd.DataFrame(
    {"x": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "y": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]}
)
labels = pd.Series([1, 0, 1, 0, 1, 0, 1, 1, 1, 1])

X = df.values
y = labels.values

X2 = da.from_array(X, 5)
y2 = da.from_array(y, 5)

cluster = LocalCluster()
client = Client(cluster)  # registered as the default client

# dask-xgboost estimator (this PR): evaluation history is now recorded
a = dxgb.XGBRegressor(eval_metric="rmse", random_state=1, seed=1, verbosity=0)
a.fit(X2, y2, eval_set=[(X, y)])

# plain xgboost estimator, for comparison
b = xgb.XGBRegressor(eval_metric="rmse", random_state=1, seed=1, verbosity=0)
b.fit(X, y, eval_set=[(X, y)])

# xgboost's native dask estimator, for comparison
c = xgb.dask.DaskXGBRegressor(eval_metric="rmse", random_state=1, seed=1, verbosity=0)
c.fit(X2, y2, eval_set=[(X2, y2)])

# all three should produce the same per-iteration evaluation history
assert a.evals_result() == b.evals_result()
assert a.evals_result() == c.evals_result()
assert b.evals_result() == c.evals_result()
>>> print(a.evals_result())
{'validation_0': {'rmse': [0.461261, 0.425567, 0.392674, 0.36236, 0.334421, 0.308667, 0.284927, 0.263039, 0.242859, 0.22425, 0.207089, 0.191262, 0.176663, 0.163196, 0.150772, 0.139231, 0.128591, 0.118781, 0.109739, 0.101406, 0.093727, 0.086653, 0.080138, 0.074139, 0.068618, 0.063538, 0.058867, 0.054574, 0.050631, 0.047013, 0.043661, 0.04055, 0.037662, 0.034981, 0.032493, 0.030182, 0.028037, 0.026046, 0.024197, 0.02248, 0.020885, 0.019405, 0.01803, 0.016754, 0.015568, 0.014467, 0.013444, 0.012494, 0.011612, 0.010792, 0.010031, 0.009323, 0.008666, 0.008056, 0.007489, 0.006962, 0.006472, 0.006017, 0.005595, 0.005202, 0.004837, 0.004498, 0.004182, 0.003889, 0.003621, 0.003367, 0.003132, 0.002913, 0.002709, 0.00252, 0.002344, 0.00218, 0.002033, 0.001892, 0.00176, 0.001637, 0.00153, 0.001428, 0.001336, 0.001243, 0.001164, 0.001083, 0.001017, 0.000956, 0.000891, 0.00084, 0.000794, 0.000752, 0.000715, 0.000708, 0.000702, 0.000697, 0.000664, 0.00066, 0.000657, 0.000654, 0.000652, 0.00065, 0.000649, 0.000648]}}
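
With the history exposed, one natural use is plotting a learning curve; a quick sketch (matplotlib assumed installed, not part of this PR):

import matplotlib.pyplot as plt

history = a.evals_result()["validation_0"]["rmse"]
plt.plot(range(1, len(history) + 1), history)
plt.xlabel("boosting round")
plt.ylabel("rmse")
plt.show()
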
TomAugspurger commented 4 years ago

Thanks @kylejn27. Looks great.