microsoft / LightGBM

A fast, distributed, high performance gradient boosting (GBT, GBDT, GBRT, GBM or MART) framework based on decision tree algorithms, used for ranking, classification and many other machine learning tasks.
MIT License

[dask] Dask LightGBM best practices #5840

Open adfea9c0 opened 1 year ago

adfea9c0 commented 1 year ago

Hey! I've been experimenting a bit with Dask LightGBM and I have a couple questions. Apologies if this is not the right forum for questions like this.

1) First, what is the recommended way to pass data into LightGBM? Initially I followed the advice here [1] and persisted my data before fitting the DaskLGBMRegressor. Specifically, I would use roughly something like:

df_train = dask.dataframe.read_parquet("/my/glob/*/*.parquet", columns=features+[resp], engine="pyarrow")
X_train = df_train[features].to_dask_array(lengths=True)
y_train = df_train[resp].to_dask_array(lengths=True)
X_train, y_train = X_train.persist(), y_train.persist()
_ = dask.distributed.wait([X_train, y_train])

model = lgb.DaskLGBMRegressor(**blah).fit(X_train, y_train)

But based on early observations, I'm finding this to be slower than dropping the persist and to_dask_array calls and passing LightGBM the lazy read_parquet dataframe directly. I don't really understand why persisting would help anyway: from what I understand, LightGBM first rearranges the data to ensure that X[i] and y[i] are on the same worker for all i [2], so any work done beforehand to load data onto the workers seems wasted. I.e. I now run something like this:

df_train = dask.dataframe.read_parquet("/my/glob/*/*.parquet", columns=features+[resp], engine="pyarrow")
model = lgb.DaskLGBMRegressor(**blah).fit(df_train[features], df_train[resp])

2) I'm having trouble getting any logging out of Dask LightGBM. I'd be happy just to know which iteration training has reached, but passing in something like callbacks=[lambda env: print(env.iteration)] doesn't show anything in either the scheduler or the worker logs. Is there a way to do some sort of logging in the Dask regressor?
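
For reference, here is a minimal sketch of the kind of callback meant above, written against Python's logging module instead of print in case that makes the output easier to find. It assumes the callback is actually invoked on the workers and that the cluster forwards records from this logger, neither of which is confirmed here. IterationLogger, FakeEnv and the logger name are hypothetical; FakeEnv only mimics the iteration, begin_iteration and end_iteration fields of the env object that LightGBM passes to callbacks, so the snippet runs without a cluster.

```python
import logging
from collections import namedtuple

# Hypothetical logger name; any name the cluster's logging config forwards would do.
logger = logging.getLogger("lgb_progress")


class IterationLogger:
    """Callable that logs training progress every `period` iterations."""

    def __init__(self, period=10):
        self.period = period
        self.seen = []  # one-based iterations that were logged, kept for inspection

    def __call__(self, env):
        # env.iteration is zero-based, so report it one-based.
        current = env.iteration + 1
        if current % self.period == 0 or current == env.end_iteration:
            self.seen.append(current)
            logger.info("iteration %d of %d", current, env.end_iteration)


# Stand-in for the env object handed to callbacks (assumption: only these
# three fields are needed for progress reporting).
FakeEnv = namedtuple("FakeEnv", ["iteration", "begin_iteration", "end_iteration"])

cb = IterationLogger(period=10)
for i in range(25):
    cb(FakeEnv(iteration=i, begin_iteration=0, end_iteration=25))
```

In real use the instance would be passed as callbacks=[IterationLogger(period=10)] in place of the lambda. Whether anything then appears in the worker logs is exactly the open question.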

[1] [2]

adfea9c0 commented 1 year ago

Tangentially, I'm wondering if there is some way to affect load balancing? I think either Dask or LightGBM is currently causing a very unbalanced distribution of data for my use case.

I'm loading 500 parquet files of about 1GB each onto 160 workers, each with 50GB of memory. I imagine that would be plenty of memory, but some workers get very little data while others get so much that it spills to disk. How does this happen? How can I fix it?
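
As a back-of-the-envelope check of the numbers above, the sketch below works out what an even spread would look like. It assumes one partition per file and uses the on-disk sizes; in-memory size is typically larger than the parquet size on disk, so these are lower bounds and not measurements. All variable names are illustrative.

```python
n_files = 500
file_gb = 1.0          # approximate on-disk size per file
n_workers = 160
worker_mem_gb = 50.0

total_gb = n_files * file_gb                 # data to distribute
even_share_gb = total_gb / n_workers         # per-worker load if perfectly balanced
cluster_mem_gb = n_workers * worker_mem_gb   # aggregate worker memory

# With whole files as the unit of distribution, an even spread gives every
# worker `min_files` files and `extra` workers one additional file.
min_files, extra = divmod(n_files, n_workers)

print(f"total data: {total_gb:.0f} GB, cluster memory: {cluster_mem_gb:.0f} GB")
print(f"even share: {even_share_gb:.3f} GB per worker")
print(f"{extra} workers would hold {min_files + 1} files, "
      f"{n_workers - extra} would hold {min_files}")
```

So an even spread would put only about 3 to 4 files on each worker, far below the 50GB limit, which is why the spilling is surprising.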
