dmlc / xgboost

Scalable, Portable and Distributed Gradient Boosting (GBDT, GBRT or GBM) Library, for Python, R, Java, Scala, C++ and more. Runs on single machine, Hadoop, Spark, Dask, Flink and DataFlow
https://xgboost.readthedocs.io/en/stable/
Apache License 2.0

Federated Learning RFC #7778

Closed rongou closed 1 year ago

rongou commented 2 years ago

Here is a design proposal for initial support of Federated Learning. Please take a look and share your feedback. Thanks!

Partially addresses #5430. @piiswrong @hcho3 @trivialfis @RAMitchell @tqchen @podcastinator

Motivation

In many machine learning use cases, data owners (hospitals, banks, etc.) cannot share their private data with each other or with third parties, due to privacy regulations such as GDPR, HIPAA, and CCPA. Federated learning can be used to train a joint model without the need to aggregate data in a centralized location.

For certain types of problems (tabular data, time series forecasting), XGBoost has been shown to be competitive with deep learning models while requiring less tuning and computation. Given the additional overhead of a federated setting, this efficiency makes XGBoost well suited for real-world applications.

Goals

For the initial version, the goals are limited and modest:

Non-Goals

Assumptions

Federated learning can refer to a wide range of problems, from a few close partners collaborating in a relatively secure environment, to millions of mobile devices connected over the Internet without encryption. For this design we assume we will start from the former, and can implement more security and privacy features with additional effort.

A typical use case is not that much different from the current distributed setting:

Risks

The current XGBoost codebase is fairly complicated and hard to modify. Some code refactoring needs to happen first, before support for federated learning can be added. Care must be taken to not break existing functionality, or make regular training harder.

Design

Distributed XGBoost is a good starting point for horizontal federated learning. Data is split evenly between workers, which collaboratively train a joint model. For histogram based tree construction methods, only histograms of gradients are communicated between workers. However, unlike a distributed environment, in a federated setting we cannot assume workers will be connected directly to each other. Instead, we will have a trusted central aggregator, and connect workers in a star topology.

[Figure: Federated Learning with XGBoost]

To train a federated model:

Histogram based tree construction in XGBoost is a two step process. First, in the preprocessing step, each input feature is divided into quantiles and put into bins. Second, trees are constructed using histograms of gradients. For both steps, data needs to be aggregated across all workers. In distributed training, this aggregation is done using allreduce (rabit for CPU, NCCL for GPU). In the federated setting, we need to simulate allreduce using RPC calls (data is sent to the aggregator from each worker, once aggregated together, it’s sent back to each worker).
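For illustration, here is a minimal sketch of how the aggregator side of such a simulated allreduce could look (the class and method names are assumptions made for this example, not XGBoost's API): each worker submits its local buffer, for example a gradient histogram, via an RPC handler; once all workers have contributed, the element-wise sum is returned to every caller. For simplicity the sketch handles a single allreduce round.

#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <vector>

// Illustrative aggregator for a single simulated allreduce round.
class FederatedAggregator {
 public:
  explicit FederatedAggregator(int num_workers) : num_workers_(num_workers) {}

  // Called once per worker from the RPC layer; blocks until every worker has
  // contributed its local buffer, then returns the aggregated sum to all.
  std::vector<double> AllReduceSum(const std::vector<double>& local) {
    std::unique_lock<std::mutex> lock(mu_);
    if (sum_.empty()) sum_.assign(local.size(), 0.0);
    for (std::size_t i = 0; i < local.size(); ++i) sum_[i] += local[i];
    ++arrived_;
    if (arrived_ == num_workers_) {
      cv_.notify_all();
    } else {
      cv_.wait(lock, [this] { return arrived_ == num_workers_; });
    }
    return sum_;
  }

 private:
  const int num_workers_;
  int arrived_{0};
  std::vector<double> sum_;
  std::mutex mu_;
  std::condition_variable cv_;
};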

One possible approach is to refactor the existing communication code into an abstraction layer, and provide different implementations depending on whether we are training with CPUs or GPUs in a distributed environment, or in a federated environment.

[Figure: XGBoost Communicators]

The communicator is a simple interface with two methods, init() and allreduce(). Note that these communicators need to be in C++.
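A minimal sketch of what such an interface might look like (the exact names and signatures below are assumptions; the design only calls for init() and allreduce()):

#include <cstddef>

class Communicator {
 public:
  virtual ~Communicator() = default;
  // Set up the backend: rabit for CPU, NCCL for GPU, or an RPC client that
  // talks to the central aggregator in the federated case.
  virtual void Init(int world_size, int rank) = 0;
  // Element-wise sum across all workers; every worker receives the result.
  virtual void AllReduce(double* buffer, std::size_t count) = 0;
};

// Concrete implementations (e.g. RabitCommunicator, NcclCommunicator,
// FederatedCommunicator) would be selected depending on the environment.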

Once this is done, some Python glue code can be added for setting up the federated learning environment.

Rough sketch of the federated worker code (these APIs need to be hashed out more):

client = nvflare.Client(hostname, port)  # connect to the central aggregator
dtrain = xgb.FederatedDMatrix(client, 'data/path/data.txt.train')  # local data; never leaves the worker
param = {...}  # usual XGBoost training parameters
num_round = 100
output = xgb.federated.train(
  client,
  param,
  dtrain,
  num_round,
  evals=[(dtrain, "train")],
)

Alternatives Considered

In theory we could add a new engine in Rabit, or implement a net plugin in NCCL, to support federated communication. But either approach would limit federated XGBoost to CPU-only or GPU-only training. In addition, even with a net plugin, NCCL still needs to open socket connections between peers during init and reconnect, so it’s not suitable for the federated environment.

An alternative is to add more fine-grained C APIs and Python wrappers to XGBoost, so that tree construction and gradient aggregation can be controlled from Python. This actually fits better with the current NVFlare design, but it would require significant refactoring of XGBoost, and would probably be riskier than factoring out the communicator.

Finally, it is always possible to implement gradient boosting directly on top of a federated learning platform (as in FATE), but this would lose the ecosystem benefits of XGBoost.

hcho3 commented 2 years ago

I like the idea of abstracting the AllReduce primitive to support multiple AllReduce backends, including NVFlare. One possible source of complication is the data sketching step, where workers exchange details of the data matrix, including the number of features as well as information about data distribution (quantiles). @trivialfis is knowledgeable about this step.
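As a concrete, if highly simplified, illustration of what this step exchanges, the toy sketch below merges each worker's local values for a feature and derives global cut points from the merged sample. XGBoost's actual implementation uses weighted quantile sketches rather than shipping raw values; the function name and shape here are purely illustrative.

#include <algorithm>
#include <cstddef>
#include <vector>

// Toy version of the sketching step: merge per-worker feature values and pick
// evenly spaced quantiles as global bin boundaries (assumes non-empty input).
std::vector<double> MergeAndComputeCuts(
    const std::vector<std::vector<double>>& per_worker_values, int max_bins) {
  std::vector<double> merged;
  for (const auto& values : per_worker_values) {
    merged.insert(merged.end(), values.begin(), values.end());
  }
  std::sort(merged.begin(), merged.end());
  std::vector<double> cuts;
  for (int b = 1; b < max_bins; ++b) {
    cuts.push_back(merged[merged.size() * b / max_bins]);
  }
  return cuts;
}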

rongou commented 2 years ago

@hcho3 yeah, with this initial version there would probably be some leakage in the quantile sketching step. We'll need to beef up security and privacy with further effort.

rongou commented 1 year ago

This is implemented in the 1.7.0 release. Closing this and will create additional RFCs for other federated learning related issues.