microsoft / LightGBM

A fast, distributed, high performance gradient boosting (GBT, GBDT, GBRT, GBM or MART) framework based on decision tree algorithms, used for ranking, classification and many other machine learning tasks.
https://lightgbm.readthedocs.io/en/latest/
MIT License

Performance Gap between Non-Distributed and Distributed Lightgbm when Data is Sorted on Label #5025

Open rudra0713 opened 2 years ago

rudra0713 commented 2 years ago

Hi, I have a binary classification dataset where the labels are sorted (I know it's against standard ML practice to have sorted data, but I'm asking in the spirit of understanding distributed LightGBM better). When I trained a non-distributed LightGBM model and a distributed LightGBM model on this dataset and evaluated both on the same dataset, I observed a large gap in accuracy (0.68 vs. 0.5). I checked the data partitions for the distributed model: because the labels are fully sorted, almost all of the partitions contain only one label. However, when I shuffle the dataset, the two models perform quite similarly.

If this is not the expected behavior, I can share reproducible code. But if it is expected, how does distributed LightGBM deal with highly imbalanced datasets? For example, in a dataset with 10k rows where 9k rows have label 0 and only 1k rows have label 1, many partitions could end up with only one of the labels.
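To get a feel for how likely that is once the rows have been shuffled, one can compute the chance that a single partition receives no minority-class rows. This is a minimal sketch of a hypergeometric calculation, assuming that after a full shuffle each partition is a uniformly random subset of fixed size; `prob_partition_misses_class` is a made-up helper name.

```python
from math import comb


def prob_partition_misses_class(n_rows, n_minority, partition_size):
    """Chance that one partition of `partition_size` rows, drawn uniformly at
    random without replacement (i.e. after a full shuffle), contains no
    minority-class rows."""
    return comb(n_rows - n_minority, partition_size) / comb(n_rows, partition_size)


# The example from the question: 10k rows, 1k positives, 4 equal partitions.
p = prob_partition_misses_class(10_000, 1_000, 2_500)
print(p)  # far below 1e-100
# Union bound: P(any of the 4 partitions misses the class) <= 4 * p

# A very rare class is a different story:
print(prob_partition_misses_class(10_000, 4, 2_500))  # about 0.316
```

So with shuffled rows, a 10% minority class is effectively guaranteed to reach every partition; only extremely rare classes (or unshuffled, sorted data) are at risk.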

The following is a snippet of how I am creating the data:

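    import dask.dataframe as dd
    import numpy as np
    import pandas as pd

    # num_of_workers, invoke_local_lgbm() and invoke_distributed_lgbm() are
    # defined elsewhere in my script (not shown).
    train_data_len = 10000
    X = pd.DataFrame(np.random.rand(train_data_len, 4), columns=list('ABCD'))
    # Sorted labels: the first half are 0, the second half are 1.
    y = pd.Series([0] * (train_data_len // 2) + [1] * (train_data_len // 2))
    # Unsorted alternative: y = pd.Series(np.random.randint(0, 2, size=train_data_len))
    invoke_local_lgbm(X, y)

    X['my_target'] = y
    X = dd.from_pandas(X, npartitions=num_of_workers, sort=True)
    # Print the class counts in each partition:
    # for i in range(X.npartitions):
    #     y_ = X.get_partition(i).compute()
    #     print("class count for partition", i, y_['my_target'].value_counts())

    y = X['my_target']
    X = X.drop(columns=['my_target'])
    invoke_distributed_lgbm(X, y)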
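A minimal sketch of why sorting matters, assuming (as `dd.from_pandas` does with a sorted default index) that the rows are cut into contiguous, roughly equal-sized chunks, one per partition. `np.array_split` stands in for Dask here, and `partition_class_counts` is a made-up helper name.

```python
import numpy as np


def partition_class_counts(y, n_partitions):
    """Count the labels in each contiguous chunk of `y`."""
    counts = []
    for chunk in np.array_split(np.asarray(y), n_partitions):
        labels, freq = np.unique(chunk, return_counts=True)
        counts.append({int(k): int(v) for k, v in zip(labels, freq)})
    return counts


n = 10_000
y_sorted = np.array([0] * (n // 2) + [1] * (n // 2))
y_shuffled = np.random.default_rng(0).permutation(y_sorted)

print(partition_class_counts(y_sorted, 4))
# [{0: 2500}, {0: 2500}, {1: 2500}, {1: 2500}]: each partition sees one class
print(partition_class_counts(y_shuffled, 4))
# every partition sees both classes, roughly 1250 of each
```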
jmoralez commented 2 years ago

Thanks for raising this @rudra0713. I can reproduce the issue and I actually sometimes get an error. I used the following:

    import dask.dataframe as dd
    import lightgbm as lgb
    import numpy as np
    import pandas as pd
    from dask.distributed import Client
    from sklearn.metrics import accuracy_score

    if __name__ == '__main__':
        results = {}
        client = Client()
        n_workers = len(client.scheduler_info()['workers'])
        train_data_len = 10000
        rng = np.random.RandomState(0)
        X = pd.DataFrame(rng.rand(train_data_len, 4), columns=list('ABCD'))
        for order in ('sorted', 'scrambled'):
            y = (rng.rand(train_data_len) < 0.5).astype('int')
            if order == 'sorted':
                y = np.sort(y)
            reg = lgb.LGBMClassifier(verbosity=-1).fit(X, y)
            pred = reg.predict(X)
            results[f'acc_{order}'] = accuracy_score(y, pred)

            df = X.copy()
            df['y'] = y
            ddf = dd.from_pandas(df, npartitions=n_workers)
            dX, dy = ddf.drop(columns='y'), ddf['y']
            dreg = lgb.DaskLGBMClassifier().fit(dX, dy)
            dpred = dreg.predict(dX).compute()
            results[f'dacc_{order}'] = accuracy_score(dy.compute(), dpred)
        print(results)

And when I try to predict with the distributed model trained on the sorted labels, I sometimes get:

    Exception: ValueError('y contains previously unseen labels: [1]')

One thing that immediately seems odd is that I see:

    [LightGBM] [Info] Number of positive: 2470, number of negative: 7530
    [LightGBM] [Info] Number of positive: 2470, number of negative: 7530
    [LightGBM] [Info] Number of positive: 2470, number of negative: 7530
    [LightGBM] [Info] Number of positive: 2470, number of negative: 7530

when there should actually be 4,970 positives.

Versions:

jameslamb commented 2 years ago

Thanks for raising this, and for the tight reproducible example, @jmoralez!

In my opinion, in cases where the data are pre-partitioned (like when using lightgbm.dask), LightGBM shouldn't be responsible for re-shuffling data between workers. But we should make the expectations for how you partition your data much clearer.

I'd support changes (at the C++ level, not in the Python package) to raise a more informative error when the distributions of the target have no overlap (for regression) or each partition does not have sufficient data about each of the target classes (for classification).
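Until such a check exists inside LightGBM, a user could run something similar on the labels before training. The following is a hypothetical helper (`check_label_partitions` is not a LightGBM function) that interprets "sufficient data" in the simplest possible way: at least one row of every class in every partition for classification, and a shared target range for regression.

```python
import numpy as np


def check_label_partitions(label_parts, task="classification"):
    """Hypothetical pre-training check (not part of LightGBM): raise
    ValueError if the label partitions look unsuitable for distributed training."""
    parts = [np.asarray(p) for p in label_parts]
    if task == "classification":
        all_classes = np.unique(np.concatenate(parts))
        for i, part in enumerate(parts):
            missing = np.setdiff1d(all_classes, part)
            if missing.size:
                raise ValueError(
                    f"partition {i} has no rows for class(es) {missing.tolist()}")
    elif task == "regression":
        # For 1-D ranges, "every pair overlaps" is equivalent to "all share a
        # common point", so comparing the largest minimum with the smallest
        # maximum is enough.
        highest_min = max(p.min() for p in parts)
        lowest_max = min(p.max() for p in parts)
        if highest_min > lowest_max:
            raise ValueError("target ranges of the partitions do not all overlap")
    else:
        raise ValueError(f"unknown task: {task!r}")


check_label_partitions([[0, 1, 0], [1, 0, 1]])  # passes silently
try:
    check_label_partitions([[0, 0], [1, 1]])
except ValueError as err:
    print(err)  # partition 0 has no rows for class(es) [1]
```

With Dask, the per-partition labels could be collected with something like `[dy.get_partition(i).compute().to_numpy() for i in range(dy.npartitions)]` before calling `fit()`.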

For cases where you've set pre_partition=False and asked LightGBM to redistribute the data for you, I think it will ensure that each slice of the data has a roughly representative sample of the target, but I'm not sure about that. I would have to investigate this section of the code more:

https://github.com/microsoft/LightGBM/blob/83a41dabec3c97ae99a4820721e23d7422e7144e/src/io/dataset_loader.cpp#L542-L576

rudra0713 commented 2 years ago

Thanks a lot to both of you.

@jmoralez I have also sometimes seen the exception Exception: ValueError('y contains previously unseen labels: [1]') on my system when experimenting with sorted labels. My LightGBM version is 3.3.2.99, built from source.

@jameslamb I have not experimented with pre_partition=False before. Currently, I use the following line to create as many partitions as there are workers:

    X = dd.from_pandas(X, npartitions=num_of_workers)

I wonder whether, if I use pre_partition=False, LightGBM will decide to distribute the data across only a subset of the workers. If that happens, I believe there could be an issue when the machines parameter is used at the same time. I will do the experiment and let you know.

jameslamb commented 2 years ago

> I will do the experiment and let you know.

I just want to be sure you understand: setting pre_partition will not have any effect on lightgbm.dask. That parameter only affects distributed training using the CLI or lgb.train() in the Python package.

rudra0713 commented 2 years ago

Before asking about pre_partition, let me clarify this first: the fact that distributed LightGBM may fail or underperform when some partitions do not see all of the target classes is not a bug; rather, the user is expected to create partitions that satisfy that property. Is this statement correct?

Now, regarding pre_partition: if it does not affect lightgbm.dask, then I do not fully understand its role. My understanding was that if I use pre_partition=False, I no longer need the npartitions parameter and can let LightGBM distribute the data among its workers. If that's not the intended purpose of pre_partition, can you give me an example of how to use this parameter?

jameslamb commented 2 years ago

> Distributed LightGBM may fail or underperform if each partition does not see all the target classes is not a bug, rather it is expected from the user to create partitions in a way that satisfies that property, is this statement correct?

Correct, it's expected behavior. @shiyu1994 or @guolinke please correct me if I'm not right about that.
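One way a user could create partitions with that property, rather than relying on a random shuffle, is to order the rows so that every class is spread evenly from the first row to the last; contiguous equal-sized chunks then inherit the overall class mix. This is a minimal sketch under that assumption; `stratified_row_order` is a made-up name, not part of LightGBM or Dask.

```python
import numpy as np


def stratified_row_order(y, seed=0):
    """Return a row order in which each class is spread evenly over the whole
    range, so contiguous equal-sized chunks get (roughly) the overall class mix."""
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    key = np.empty(len(y))
    for cls in np.unique(y):
        rows = np.flatnonzero(y == cls)
        rng.shuffle(rows)  # randomize which rows of this class land where
        # Put the k-th row of this class at relative position (k + 0.5) / n_cls.
        key[rows] = (np.arange(len(rows)) + 0.5) / len(rows)
    return np.argsort(key, kind="stable")


# Imbalanced example from the question: 9k negatives, 1k positives, 4 partitions.
y = np.array([0] * 9_000 + [1] * 1_000)
order = stratified_row_order(y)
print([int(chunk.sum()) for chunk in np.array_split(y[order], 4)])  # [250, 250, 250, 250]
```

With pandas, the order could be applied with `df.iloc[order].reset_index(drop=True)` before calling `dd.from_pandas(df, npartitions=n_workers)`; this assumes `from_pandas` cuts the frame into contiguous, roughly equal-sized chunks of the new index.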

> can you give me an example of how to use this parameter?

Never use this parameter with lightgbm.dask or if you are using Dask.

You might use pre_partition=False if you are using the LightGBM CLI to run distributed training across multiple machines, where a complete copy of the data is stored in a file on each machine.
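For that CLI or lgb.train() setting, the relevant parameters might look roughly like the dict below. This is a sketch only: the IP addresses, port, and objective are placeholder assumptions, and the same dict would be used on every machine, each of which holds the full training file.

```python
# Sketch of parameters for data-parallel training where every machine
# holds a complete copy of the data (so the data is NOT pre-partitioned).
params = {
    "objective": "binary",
    "tree_learner": "data",                         # data-parallel learning
    "num_machines": 2,
    "machines": "10.0.0.1:12400,10.0.0.2:12400",    # hypothetical addresses
    "local_listen_port": 12400,                     # placeholder port
    "pre_partition": False,                         # full copy on each machine
}

# On each machine, something like:
# booster = lgb.train(params, lgb.Dataset("train.txt"))
```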

See https://github.com/microsoft/LightGBM/issues/3835#issuecomment-961607939 for more details. And since you are just trying to learn about distributed LightGBM, I recommend that you read all of the discussion in #3835.

rudra0713 commented 2 years ago

Thanks for the clarification @jameslamb.

shiyu1994 commented 2 years ago

@rudra0713 Thanks for using LightGBM, and for the reproducible example. The only known cause of a performance gap between the distributed and single-process versions is that different distributions of feature values across processes lead to different bin boundaries for the feature histograms, as we've discussed in https://github.com/microsoft/LightGBM/issues/3835#issuecomment-961607939.

I think it is necessary to support uniform bin boundaries across the different processes in distributed training.

However, I don't think that is the root cause of the significant gap in your example, because the features appear to be generated purely at random, so the distribution of feature values should be similar across processes.
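The bin-boundary argument can be illustrated with a small experiment. As a simplified stand-in for LightGBM's own bin construction (which uses its own algorithm, controlled by parameters such as max_bin), this sketch uses plain equal-frequency edges from `np.quantile` and compares the edges each partition would compute locally with the edges computed on the full column. The helper names are made up for illustration.

```python
import numpy as np


def equal_frequency_edges(values, n_bins=16):
    """Interior edges of equal-frequency (quantile) bins; a simplified
    stand-in for histogram binning, not LightGBM's actual algorithm."""
    return np.quantile(values, np.linspace(0.0, 1.0, n_bins + 1)[1:-1])


def max_edge_gap(feature, n_partitions, n_bins=16):
    """Largest distance between a partition's locally computed edges and
    the edges computed on the whole column."""
    global_edges = equal_frequency_edges(feature, n_bins)
    return max(
        float(np.max(np.abs(equal_frequency_edges(part, n_bins) - global_edges)))
        for part in np.array_split(feature, n_partitions)
    )


rng = np.random.default_rng(0)
random_feature = rng.random(10_000)           # like columns A-D in the issue
sorted_feature = np.sort(rng.random(10_000))  # a feature that follows the row order

print(max_edge_gap(random_feature, 4))  # small: local edges match global ones
print(max_edge_gap(sorted_feature, 4))  # large: each partition sees a narrow slice
```

Because the features in the reproducible example are independent of the (sorted) label, they behave like `random_feature` here, which is consistent with the point that bin boundaries alone should not explain the gap.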

This should be investigated further.