wenh06 / fl-sim

A Simple Simulation Framework for Federated Learning Based on PyTorch
MIT License

This repository was migrated from fl_seminar.

The main part of this code repository is a standalone simulation framework for federated training.

Installation

Run the following command to install the package:

pip install git+https://github.com/wenh06/fl-sim.git

or clone the repository and run the following command in the root directory of the repository:

pip install -e .

Alternatively, one can use the Docker image wenh06/fl-sim to run the code. The image is built with the Docker Image CI action. To pull the image, run the following command:

docker pull wenh06/fl-sim

To use the image in interactive mode, run the following command:

docker run -it wenh06/fl-sim bash

For more advanced usage (e.g., running a script), refer to the official Docker documentation.

Usage Examples

The following code snippet shows how to use the framework to train a model on the `FedProxFEMNIST` dataset using the `FedProx` algorithm.

```python
from fl_sim.data_processing.fedprox_femnist import FedProxFEMNIST
from fl_sim.algorithms.fedprox import (
    FedProxServer,
    FedProxClientConfig,
    FedProxServerConfig,
)

# create a FedProxFEMNIST dataset
ds = FedProxFEMNIST()
# choose a model
model = ds.candidate_models["cnn_femmist_tiny"]
# set up the server and client configurations
# num_iters=200, num_clients, clients_sample_ratio=0.7
server_config = FedProxServerConfig(200, ds.DEFAULT_TRAIN_CLIENTS_NUM, 0.7)
# batch_size, num_epochs=30
client_config = FedProxClientConfig(ds.DEFAULT_BATCH_SIZE, 30)
# create a FedProxServer object
s = FedProxServer(model, ds, server_config, client_config)

# normal centralized training (mainly for comparison)
s.train_centralized()

# federated training
s.train_federated()
```

Algorithms Implemented

| Algorithm | Paper | Upstream | Action Status | Validity on Standard Test |
| --- | --- | --- | --- | --- |
| FedAvg[^1] | AISTATS2017 | N/A | test-fedopt | :heavy_check_mark: |
| FedOpt[^2] | arXiv:2003.00295 | N/A | test-fedopt | :heavy_check_mark: |
| FedProx | MLSys2020 | GitHub | test-fedprox | :heavy_check_mark: :question: |
| pFedMe | NeurIPS2020 | GitHub | test-pfedme | :interrobang: |
| FedSplit | NeurIPS2020 | N/A | test-fedsplit | :heavy_check_mark: :question: |
| FedDR | NeurIPS2021 | GitHub | test-feddr | :interrobang: |
| FedPD | IEEE Trans. Signal Process | GitHub | test-fedpd | :interrobang: |
| SCAFFOLD | PMLR | N/A | test-scaffold | :heavy_check_mark: :question: |
| ProxSkip | PMLR | N/A | test-proxskip | :heavy_check_mark: :question: |
| Ditto | PMLR | GitHub | test-ditto | :heavy_check_mark: |
| IFCA | NeurIPS2020 | GitHub | test-ifca | :heavy_check_mark: |
| pFedMac | arXiv:2107.05330 | N/A | test-pfedmac | :interrobang: |
| FedDyn | ICLR2021 | N/A | test-feddyn | :question: |
| APFL | arXiv:2003.13461 | N/A | test-apfl | :question: |

[^1]: FedAvg is implemented as a special case of FedOpt.
[^2]: Including FedAdam, FedYogi, FedAdagrad.
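FedAvg is implemented as a special case of FedOpt. A minimal sketch of why, assuming the FedOpt formulation of arXiv:2003.00295: the server treats the sample-weighted average of the client models as defining a pseudo-gradient and applies a server optimizer to it; with plain SGD and a server learning rate of 1, the step reduces to FedAvg's weighted parameter averaging. The helper names below are hypothetical and work on flat lists of floats rather than the framework's tensors.

```python
from typing import List, Sequence


def weighted_average(params: Sequence[List[float]], weights: Sequence[float]) -> List[float]:
    """Average parameter vectors, weighting each client (e.g. by its number of training samples)."""
    total = float(sum(weights))
    return [
        sum(w * p[i] for p, w in zip(params, weights)) / total
        for i in range(len(params[0]))
    ]


def fedopt_sgd_step(
    server_params: List[float],
    client_params: Sequence[List[float]],
    weights: Sequence[float],
    server_lr: float = 1.0,
) -> List[float]:
    """One FedOpt round with a plain SGD server optimizer (hypothetical helper)."""
    avg = weighted_average(client_params, weights)
    # pseudo-gradient: how far the server model is from the averaged client model
    pseudo_grad = [s - a for s, a in zip(server_params, avg)]
    return [s - server_lr * g for s, g in zip(server_params, pseudo_grad)]


if __name__ == "__main__":
    server = [0.0, 0.0]
    clients = [[1.0, 2.0], [3.0, 6.0]]
    samples = [1, 3]
    # with server_lr = 1.0 the FedOpt step coincides with FedAvg
    print(fedopt_sgd_step(server, clients, samples))  # [2.5, 5.0]
    print(weighted_average(clients, samples))  # [2.5, 5.0]
```

Swapping the SGD step for an adaptive update on `pseudo_grad` is what turns this into variants such as FedAdam or FedYogi.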

Standard Test Status Images: client sample ratios of 10%, 30%, 70%, and 100%.

Main Modules

Nodes

`Node`s are the core of the simulation framework. `Node` has two subclasses: `Server` and `Client`. The `Server` class is the base class for all servers; it acts as the coordinator of the training process and maintains the status variables. The `Client` class is the base class for all clients.

The abstract base class `Node` provides the following basic functionalities:

- `get_detached_model_parameters`: get the model parameters of the node in a detached form.
- `compute_gradients`: compute the gradients at specified model parameters (default: current model parameters on the node) using specified data (default: training data on the node).
- `get_gradients`: get the gradients, or the norm of the gradients, of the model parameters of the node.
- `get_norm`: get the norm of given tensors or numpy arrays.
- `set_parameters`: set the model parameters of the node.
- ~~`aggregate_results_from_csv_log`: aggregate the experiment results from the csv log file.~~
- `aggregate_results_from_json_log`: aggregate the experiment results from the json log file.

It also declares the following abstract methods or properties, which need to be implemented by subclasses:

- `communicate`: communication procedure with other (types of) nodes in each iteration.
- `update`: update procedure in each iteration.
- `required_config_fields` (property): required fields in the configuration class, used to check the validity of the configuration in the `_post_init` method.
- `_post_init`: post-initialization procedure, called at the end of the `__init__` method; used together with `required_config_fields` to check the validity of the configuration.
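The interplay between `required_config_fields` and `_post_init` can be illustrated with a stripped-down sketch. This is not the framework's code: the classes `MiniNode` and `MiniClient` and the `SimpleNamespace` config are hypothetical stand-ins chosen to keep the example self-contained; only the two hook names come from the description above.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import List


class MiniNode(ABC):
    """Toy stand-in for `Node`, showing only the config-validation hooks."""

    def __init__(self, config: SimpleNamespace) -> None:
        self.config = config
        self._post_init()  # called at the end of __init__

    @property
    @abstractmethod
    def required_config_fields(self) -> List[str]:
        ...

    def _post_init(self) -> None:
        # reject configs that lack any field the subclass declares as required
        missing = [f for f in self.required_config_fields if not hasattr(self.config, f)]
        if missing:
            raise ValueError(f"missing config fields: {missing}")


class MiniClient(MiniNode):
    @property
    def required_config_fields(self) -> List[str]:
        return ["mu"]  # e.g. a FedProx-style client needs the proximal coefficient


if __name__ == "__main__":
    MiniClient(SimpleNamespace(mu=0.01, lr=0.03))  # passes validation
    try:
        MiniClient(SimpleNamespace(lr=0.03))
    except ValueError as err:
        print(err)  # missing config fields: ['mu']
```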
The `Server` class has signature

```python
Server(
    model: torch.nn.modules.module.Module,
    dataset: fl_sim.data_processing.fed_dataset.FedDataset,
    config: fl_sim.nodes.ServerConfig,
    client_config: fl_sim.nodes.ClientConfig,
    lazy: bool = False,
) -> None
```

and provides the following additional functionalities or properties:

- `_setup_clients`: set up (initialize) the clients and allocate devices to them.
- `_sample_clients`: sample a subset of clients from the client pool.
- `_communicate`: execute the `communicate` method of the clients and increment the global communication counter (`_num_communications`).
- `_update`: check the validity of the messages (`_received_messages`) received from the clients, execute the `update` method of the server, and finally clear the received messages.
- `train`: the main training procedure, which calls one of `train_centralized`, `train_federated`, or `train_local`, depending on the argument `mode` passed to this method.
- `train_centralized`: centralized training procedure, mainly used for comparison.
- `train_federated`: federated training procedure, which calls `_communicate` (to the clients), waits for the clients to execute `_update` and `_communicate`, and finally calls `_update` to update the server.
- `train_local`: local training procedure, which calls the `train` method of the clients **without** communicating with the server.
- `add_parameters`: add parameters (values) to the server model parameters.
- `avg_parameters`: average the model parameters in the received messages.
- `update_gradients`: update the gradients of the server model parameters using the received gradients.
- `get_client_data`: helper function to get the data of the clients.
- `get_client_model`: helper function to get the model of the clients.
- `get_cached_metrics`: helper function to get the cached aggregated metrics of the clients stored on the server.
- `_reset`: reset the server to its initial state. Before a new training process starts, the flag `_complete_experiment` is checked; if it is `True`, this method is called to reset the server.
- `is_convergent` (property): check whether the training process has converged. Currently, this property is **NOT** fully implemented.

It also declares the following **abstract properties, which need to be implemented by subclasses**:

- `client_cls`: the client class used when initializing the clients via `_setup_clients`.
- `config_cls`: a dictionary of configuration classes for the server and clients, used in the `__init__` method.
- `doi`: the DOI of the paper that proposes the algorithm.

The `Client` class has signature

```python
Client(
    client_id: int,
    device: torch.device,
    model: torch.nn.modules.module.Module,
    dataset: fl_sim.data_processing.fed_dataset.FedDataset,
    config: fl_sim.nodes.ClientConfig,
) -> None
```

and provides the following additional functionalities:

- `_communicate`: execute the `communicate` method of the client, increment the communication counter (`_num_communications`), and clear the cached local evaluation results.
- `_update`: execute the `update` method of the client and clear the messages received from the server.
- `evaluate`: evaluate the model on the local test data.
- `get_all_data`: helper function to get all the data of the client.

It also declares the following **abstract method, which needs to be implemented by subclasses**:

- `train`: training procedure of the client.

The configuration classes `ServerConfig` and `ClientConfig` store the configuration of the server and the clients, respectively. These two classes are similar to a [`dataclass`](https://docs.python.org/3/library/dataclasses.html), but accept arbitrary additional fields.
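A minimal sketch of a configuration class that behaves like a `dataclass` but also keeps arbitrary extra keyword arguments as attributes, assuming a plain-Python implementation. `MiniClientConfig` is hypothetical; its positional fields mirror the documented `ClientConfig` signature, but the real class may store and validate fields differently.

```python
from typing import Any, Dict


class MiniClientConfig:
    """Dataclass-like config that also accepts arbitrary extra fields (hypothetical)."""

    def __init__(
        self,
        algorithm: str,
        optimizer: str,
        batch_size: int,
        num_epochs: int,
        lr: float,
        verbose: int = 1,
        **kwargs: Any,
    ) -> None:
        self.algorithm = algorithm
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.lr = lr
        self.verbose = verbose
        # extra, algorithm-specific fields (e.g. FedProx's ``mu``) become attributes too
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        return dict(vars(self))

    def __repr__(self) -> str:
        fields = ", ".join(f"{k}={v!r}" for k, v in vars(self).items())
        return f"{type(self).__name__}({fields})"


if __name__ == "__main__":
    cfg = MiniClientConfig("FedProx", "FedProx", 32, 10, 0.03, mu=0.01)
    print(cfg.mu)  # 0.01
    print(cfg)
```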
The signature of `ServerConfig` is

```python
ServerConfig(
    algorithm: str,
    num_iters: int,
    num_clients: int,
    clients_sample_ratio: float,
    txt_logger: bool = True,
    csv_logger: bool = False,
    json_logger: bool = True,
    eval_every: int = 1,
    verbose: int = 1,
    **kwargs: Any,
) -> None
```

and the signature of `ClientConfig` is

```python
ClientConfig(
    algorithm: str,
    optimizer: str,
    batch_size: int,
    num_epochs: int,
    lr: float,
    verbose: int = 1,
    **kwargs: Any,
) -> None
```

To implement a **new federated algorithm**, one needs to implement subclasses of `Server`, `Client`, `ServerConfig`, and `ClientConfig`. For example, the following implementation of FedProx is provided in the file [fedprox](fl_sim/algorithms/fedprox/_fedprox.py):
```python
import warnings
from copy import deepcopy
from typing import List, Dict, Any

import torch
from torch_ecg.utils.misc import add_docstring
from tqdm.auto import tqdm

from fl_sim.nodes import Server, Client, ServerConfig, ClientConfig, ClientMessage
from fl_sim.algorithms import register_algorithm


@register_algorithm("FedProx")
class FedProxServerConfig(ServerConfig):
    """Server config for the FedProx algorithm.

    Parameters
    ----------
    num_iters : int
        The number of (outer) iterations.
    num_clients : int
        The number of clients.
    clients_sample_ratio : float
        The ratio of clients to sample for each iteration.
    vr : bool, default False
        Whether to use variance reduction.
    **kwargs : dict, optional
        Additional keyword arguments:

        - ``log_dir`` : str or Path, optional
            The log directory.
            If not specified, will use the default log directory.
            If not absolute, will be relative to the default log directory.
        - ``txt_logger`` : bool, default True
            Whether to use txt logger.
        - ``json_logger`` : bool, default True
            Whether to use json logger.
        - ``eval_every`` : int, default 1
            The number of iterations between two evaluations of the model.
        - ``visiable_gpus`` : Sequence[int], optional
            Visible GPU IDs for allocating devices for clients.
            Defaults to use all GPUs if available.
        - ``extra_observes`` : List[str], optional
            Extra attributes to observe during training.
        - ``seed`` : int, default 0
            The random seed.
        - ``tag`` : str, optional
            The tag of the experiment.
        - ``verbose`` : int, default 1
            The verbosity level.
        - ``gpu_proportion`` : float, default 0.2
            The proportion of clients to use GPU.
            Used to simulate the system heterogeneity of the clients.
            Not used in the current version, reserved for future use.

    """

    __name__ = "FedProxServerConfig"

    def __init__(
        self,
        num_iters: int,
        num_clients: int,
        clients_sample_ratio: float,
        vr: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            "FedProx",
            num_iters,
            num_clients,
            clients_sample_ratio,
            vr=vr,
            **kwargs,
        )


@register_algorithm("FedProx")
class FedProxClientConfig(ClientConfig):
    """Client config for the FedProx algorithm.

    Parameters
    ----------
    batch_size : int
        The batch size.
    num_epochs : int
        The number of epochs.
    lr : float, default 1e-2
        The learning rate.
    mu : float, default 0.01
        Coefficient for the proximal term.
    vr : bool, default False
        Whether to use variance reduction.
    **kwargs : dict, optional
        Additional keyword arguments:

        - ``scheduler`` : dict, optional
            The scheduler config.
            None for no scheduler, using constant learning rate.
        - ``extra_observes`` : List[str], optional
            Extra attributes to observe during training,
            which would be recorded in evaluated metrics,
            sent to the server, and written to the log file.
        - ``verbose`` : int, default 1
            The verbosity level.
        - ``latency`` : float, default 0.0
            The latency of the client.
            Not used in the current version, reserved for future use.

    """

    __name__ = "FedProxClientConfig"

    def __init__(
        self,
        batch_size: int,
        num_epochs: int,
        lr: float = 1e-2,
        mu: float = 0.01,
        vr: bool = False,
        **kwargs: Any,
    ) -> None:
        optimizer = "FedProx" if not vr else "FedProx_VR"
        if kwargs.pop("algorithm", None) is not None:
            warnings.warn(
                "The `algorithm` argument is fixed to `FedProx`.", RuntimeWarning
            )
        if kwargs.pop("optimizer", None) is not None:
            warnings.warn(
                "The `optimizer` argument is fixed to `FedProx` or `FedProx_VR`.",
                RuntimeWarning,
            )
        super().__init__(
            "FedProx",
            optimizer,
            batch_size,
            num_epochs,
            lr,
            mu=mu,
            vr=vr,
            **kwargs,
        )


@register_algorithm("FedProx")
@add_docstring(
    Server.__doc__.replace(
        "The class to simulate the server node.",
        "Server node for the FedProx algorithm.",
    )
    .replace("ServerConfig", "FedProxServerConfig")
    .replace("ClientConfig", "FedProxClientConfig")
)
class FedProxServer(Server):
    """Server node for the FedProx algorithm."""

    __name__ = "FedProxServer"

    def _post_init(self) -> None:
        """
        check if all required fields in the config are set,
        and check compatibility of server and client configs
        """
        super()._post_init()
        assert self.config.vr == self._client_config.vr

    @property
    def client_cls(self) -> type:
        return FedProxClient

    @property
    def required_config_fields(self) -> List[str]:
        return []

    def communicate(self, target: "FedProxClient") -> None:
        target._received_messages = {"parameters": self.get_detached_model_parameters()}
        if target.config.vr:
            target._received_messages["gradients"] = [
                p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)
                for p in target.model.parameters()
            ]

    def update(self) -> None:
        # average of received parameters, with self.model.parameters() as its container
        self.avg_parameters()
        if self.config.vr:
            self.update_gradients()

    @property
    def config_cls(self) -> Dict[str, type]:
        return {
            "server": FedProxServerConfig,
            "client": FedProxClientConfig,
        }

    @property
    def doi(self) -> List[str]:
        return ["10.48550/ARXIV.1812.06127"]


@register_algorithm("FedProx")
@add_docstring(
    Client.__doc__.replace(
        "The class to simulate the client node.",
        "Client node for the FedProx algorithm.",
    ).replace("ClientConfig", "FedProxClientConfig")
)
class FedProxClient(Client):
    """Client node for the FedProx algorithm."""

    __name__ = "FedProxClient"

    def _post_init(self) -> None:
        """
        check if all required fields in the config are set,
        and set attributes for maintaining intermediate states
        """
        super()._post_init()
        if self.config.vr:
            self._gradient_buffer = [
                torch.zeros_like(p) for p in self.model.parameters()
            ]
        else:
            self._gradient_buffer = None

    @property
    def required_config_fields(self) -> List[str]:
        return ["mu"]

    def communicate(self, target: "FedProxServer") -> None:
        message = {
            "client_id": self.client_id,
            "parameters": self.get_detached_model_parameters(),
            "train_samples": len(self.train_loader.dataset),
            "metrics": self._metrics,
        }
        if self.config.vr:
            message["gradients"] = [
                p.grad.detach().clone() for p in self.model.parameters()
            ]
        target._received_messages.append(ClientMessage(**message))

    def update(self) -> None:
        try:
            self._cached_parameters = deepcopy(self._received_messages["parameters"])
        except KeyError:
            warnings.warn("No parameters received from server")
            warnings.warn("Using current model parameters as initial parameters")
            self._cached_parameters = self.get_detached_model_parameters()
        except Exception as err:
            raise err
        self._cached_parameters = [p.to(self.device) for p in self._cached_parameters]
        if (
            self.config.vr
            and self._received_messages.get("gradients", None) is not None
        ):
            self._gradient_buffer = [
                gd.clone().to(self.device)
                for gd in self._received_messages["gradients"]
            ]
        self.solve_inner()  # alias of self.train()

    def train(self) -> None:
        self.model.train()
        with tqdm(
            range(self.config.num_epochs),
            total=self.config.num_epochs,
            mininterval=1.0,
            disable=self.config.verbose < 2,
        ) as pbar:
            for epoch in pbar:
                # local update
                self.model.train()
                for X, y in self.train_loader:
                    X, y = X.to(self.device), y.to(self.device)
                    self.optimizer.zero_grad()
                    output = self.model(X)
                    loss = self.criterion(output, y)
                    loss.backward()
                    self.optimizer.step(
                        local_weights=self._cached_parameters,
                        variance_buffer=self._gradient_buffer,
                    )
```
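The `self.optimizer.step(local_weights=...)` call at the end of `FedProxClient.train` hands the cached server parameters to the `FedProx` optimizer. In FedProx, each client approximately minimizes its local loss plus the proximal term (mu/2)·||w − w_global||², whose gradient is mu·(w − w_global). A minimal pure-Python sketch of one such step on a flat parameter vector follows; `fedprox_sgd_step` is hypothetical, not the optimizer in `fl_sim.optimizers`, which may differ (for instance in how the variance-reduction buffer is used).

```python
from typing import List, Sequence


def fedprox_sgd_step(
    w: Sequence[float],
    grad: Sequence[float],
    w_global: Sequence[float],
    lr: float = 0.01,
    mu: float = 0.01,
) -> List[float]:
    """One SGD step on  f(w) + (mu / 2) * ||w - w_global||^2  (hypothetical helper)."""
    return [
        wi - lr * (gi + mu * (wi - wgi))  # loss gradient + proximal-term gradient
        for wi, gi, wgi in zip(w, grad, w_global)
    ]


if __name__ == "__main__":
    # with a zero loss gradient, the proximal term alone pulls w back toward w_global
    print(fedprox_sgd_step([1.0, -1.0], [0.0, 0.0], [0.0, 0.0], lr=0.5, mu=1.0))  # [0.5, -0.5]
```

Setting `mu=0` recovers plain local SGD, which is why FedProx is often described as FedAvg with a proximal regularizer on the local problem.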

Data Processing

The module (folder) [data_processing](fl_sim/data_processing) contains code for data preprocessing, IO, etc. The following datasets are included in this module:

1. [`FedCIFAR`](fl_sim/data_processing/fed_cifar.py)
2. [`FedCIFAR100`](fl_sim/data_processing/fed_cifar.py)
3. [`FedEMNIST`](fl_sim/data_processing/fed_emnist.py)
4. [`FedMNIST`](fl_sim/data_processing/fed_mnist.)
5. [`FedShakespeare`](fl_sim/data_processing/fed_shakespeare.py)
6. [`FedSynthetic`](fl_sim/data_processing/fed_synthetic.py)
7. [`FedProxFEMNIST`](fl_sim/data_processing/fedprox_femnist.py)
8. [`FedProxMNIST`](fl_sim/data_processing/fedprox_mnist.py)
9. [`FedProxSent140`](fl_sim/data_processing/fedprox_sent140.py)

Each dataset is wrapped in a class that provides the following functionalities:

1. Automatic data downloading and preprocessing
2. Data partitioning (into clients) via the method `get_dataloader`
3. A list of candidate [models](#models) via the property `candidate_models`
4. A criterion and a method for evaluating the performance of a model from its output on the dataset, via the method `evaluate`
5. Several helper methods for data visualization and citation (biblatex format)

Additionally, one can get the list of `LIBSVM` datasets via

```python
import pandas as pd

pd.read_html("https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/")[0]
```

**NEW**: Some of the vision datasets support dynamic data augmentation for the train subset. The base class `FedVisionDataset` has signature

```python
FedVisionDataset(
    datadir: Union[str, pathlib.Path, NoneType] = None,
    transform: Union[str, Callable, NoneType] = "none",
) -> None
```

By setting `transform="none"` (default), the train subset is wrapped with a static `TensorDataset`. By setting `transform=None`, the train subset uses built-in dynamic augmentation; for example, `FedCIFAR100` uses `torchvision.transforms.RandAugment`.

**NOTE** that most of the federated vision datasets are provided as processed values rather than raw pixels, and hence do not support dynamic data augmentation using `torchvision.transforms`.
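Partitioning into clients is handled by each dataset class via `get_dataloader`. For the simplest IID case, the idea can be sketched as shuffling sample indices and dealing them out to clients. `iid_partition` is a hypothetical helper, not part of `fl_sim`; many federated datasets instead come with a natural per-user (non-IID) split.

```python
import random
from typing import Dict, List


def iid_partition(num_samples: int, num_clients: int, seed: int = 0) -> Dict[int, List[int]]:
    """Shuffle sample indices and deal them out to clients as evenly as possible."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # reproducible shuffle for a given seed
    return {cid: indices[cid::num_clients] for cid in range(num_clients)}


if __name__ == "__main__":
    parts = iid_partition(10, 3)
    print({cid: len(idx) for cid, idx in parts.items()})  # {0: 4, 1: 3, 2: 3}
```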

Models

The module (folder) [models](fl_sim/models) contains pre-defined (neural network) models, most of which are very simple:

1. `MLP`
2. `FedPDMLP`
3. `CNNMnist`
4. `CNNFEMnist`
5. `CNNFEMnist_Tiny`
6. `CNNCifar`
7. `RNN_OriginalFedAvg`
8. `RNN_StackOverFlow`
9. `RNN_Sent140`
10. `ResNet18`
11. `ResNet10`
12. `LogisticRegression`
13. `SVC`
14. `SVR`

Most of these models were proposed or suggested in previous literature. One can call the `module_size` or `module_size_` properties to check the size of a model (in terms of the number of parameters and memory consumption, respectively).
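As a back-of-the-envelope check on what `module_size` and `module_size_` report, the parameter count of a fully connected network can be computed by hand. The sketch below assumes a plain MLP with biases and 4-byte float32 parameters; `mlp_size` is hypothetical and the layer sizes are illustrative, not those of the `MLP` model.

```python
from typing import Sequence, Tuple


def mlp_size(layer_sizes: Sequence[int], bytes_per_param: int = 4) -> Tuple[int, int]:
    """Return (number of parameters, memory in bytes) of a dense MLP with biases."""
    num_params = sum(
        fan_in * fan_out + fan_out  # weight matrix + bias vector
        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:])
    )
    return num_params, num_params * bytes_per_param


if __name__ == "__main__":
    # 784 -> 200 -> 10: (784*200 + 200) + (200*10 + 10) = 159,010 parameters
    print(mlp_size([784, 200, 10]))  # (159010, 636040)
```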

Optimizers

The module (folder) [optimizers](fl_sim/optimizers) contains optimizers for solving inner (local) optimization problems. In addition to the optimizers from `torch` and `torch_optimizer`, this module implements

1. `ProxSGD`
2. `FedPD_SGD`
3. `FedPD_VR`
4. `PSGD`
5. `PSVRG`
6. `pFedMe`
7. `FedProx`
8. `FedDR`

Most of the optimizers are derived from `ProxSGD`.
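Proximal SGD, the technique behind `ProxSGD`, takes a gradient step on the smooth loss and then applies the proximal operator of a (possibly non-smooth) regularizer. A minimal sketch, assuming the regularizer is lam·||w||₁, whose proximal operator is soft-thresholding; the function names are hypothetical and this is not the `ProxSGD` class itself.

```python
from typing import List, Sequence


def soft_threshold(x: float, t: float) -> float:
    """Proximal operator of t * |x|."""
    if x > t:
        return x - t
    if x < -t:
        return x + t
    return 0.0


def prox_sgd_step(w: Sequence[float], grad: Sequence[float], lr: float, lam: float) -> List[float]:
    """Gradient step on the loss, then the prox of lr * lam * ||w||_1 (hypothetical helper)."""
    return [soft_threshold(wi - lr * gi, lr * lam) for wi, gi in zip(w, grad)]


if __name__ == "__main__":
    # small entries are set exactly to zero, large ones shrink by lr * lam = 0.1
    print(prox_sgd_step([1.0, 0.05, -2.0], [0.0, 0.0, 0.0], lr=0.1, lam=1.0))
```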

Regularizers

The module (folder) [regularizers](fl_sim/regularizers) contains regularizers for model parameters (weights):

1. `L1Norm`
2. `L2Norm`
3. `L2NormSquared`
4. `NullRegularizer`

These regularizers are subclasses of the base class `Regularizer` and can be obtained by passing the name of the regularizer to the function `get_regularizer`. They share the common methods `eval` and `prox_eval`.
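A toy version of the `eval` / `prox_eval` interface for two of the regularizers, assuming `eval` returns the penalty value R(x) and `prox_eval` returns the proximal point argmin_y c·R(y) + ½||y − x||². The class bodies, the `coeff` parameter, the registry names, and the convention R(x) = coeff·||x||² for the squared norm are all assumptions for illustration, not the module's code.

```python
from typing import Dict, List, Sequence, Type


class Regularizer:
    """Toy base class: R(x) = coeff * penalty(x) (hypothetical interface)."""

    def __init__(self, coeff: float = 1.0) -> None:
        self.coeff = coeff

    def eval(self, x: Sequence[float]) -> float:
        raise NotImplementedError

    def prox_eval(self, x: Sequence[float], c: float = 1.0) -> List[float]:
        """argmin_y  c * R(y) + 0.5 * ||y - x||^2"""
        raise NotImplementedError


class L1Norm(Regularizer):
    def eval(self, x: Sequence[float]) -> float:
        return self.coeff * sum(abs(v) for v in x)

    def prox_eval(self, x: Sequence[float], c: float = 1.0) -> List[float]:
        t = c * self.coeff  # soft-thresholding level
        return [max(abs(v) - t, 0.0) * (1.0 if v >= 0 else -1.0) for v in x]


class L2NormSquared(Regularizer):
    def eval(self, x: Sequence[float]) -> float:
        return self.coeff * sum(v * v for v in x)

    def prox_eval(self, x: Sequence[float], c: float = 1.0) -> List[float]:
        # zero gradient of c*coeff*||y||^2 + 0.5*||y - x||^2 gives uniform shrinkage
        return [v / (1.0 + 2.0 * c * self.coeff) for v in x]


# hypothetical names; the real get_regularizer may accept different ones
_REGISTRY: Dict[str, Type[Regularizer]] = {"l1_norm": L1Norm, "l2_norm_squared": L2NormSquared}


def get_regularizer(name: str, coeff: float = 1.0) -> Regularizer:
    """Toy lookup of a regularizer by name."""
    return _REGISTRY[name.lower()](coeff)
```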

Compression

The module (folder) [compressors](fl_sim/compressors) contains code for constructing compressors.
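The compressors available in the module are not listed here. To illustrate what a compressor does to a client update, here is a sketch of top-k sparsification, a common choice in the federated learning literature: keep only the k largest-magnitude entries and transmit them as (index, value) pairs. Whether `fl_sim.compressors` provides this particular compressor is an assumption, and both function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def topk_compress(x: Sequence[float], k: int) -> List[Tuple[int, float]]:
    """Keep the k entries with the largest magnitude as (index, value) pairs."""
    order = sorted(range(len(x)), key=lambda i: abs(x[i]), reverse=True)
    return [(i, x[i]) for i in sorted(order[:k])]


def topk_decompress(pairs: Sequence[Tuple[int, float]], size: int) -> List[float]:
    """Rebuild a dense vector, filling the dropped entries with zeros."""
    dense = [0.0] * size
    for i, v in pairs:
        dense[i] = v
    return dense


if __name__ == "__main__":
    pairs = topk_compress([0.1, -3.0, 0.5, 2.0], k=2)
    print(pairs)  # [(1, -3.0), (3, 2.0)]
    print(topk_decompress(pairs, 4))  # [0.0, -3.0, 0.0, 2.0]
```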

Utils

The module (folder) [utils](fl_sim/utils) contains utility functions for [data downloading](fl_sim/utils/_download_data.py), [training metrics logging](fl_sim/utils/loggers.py), [experiment visualization](fl_sim/utils/viz.py), etc.

- `TxTLogger`: A logger for logging training metrics to a text file, as well as printing them to the console, in a human-readable format.
- ~~`CSVLogger`: A logger for logging training metrics to a CSV file. **NOT** recommended since it is not memory-efficient.~~
- `JsonLogger`: A logger for logging training metrics to a JSON file. The log can also be saved as a YAML file.
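A minimal sketch of the JSON-logging idea, assuming one record per evaluation is collected in memory and dumped to disk as a list. `MiniJsonLogger`, its `log_metrics` method, and the record layout are hypothetical and will not match the file format written by `JsonLogger`.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List


class MiniJsonLogger:
    """Collect per-round metrics in memory and write them to a JSON file (hypothetical)."""

    def __init__(self, path: Path) -> None:
        self.path = Path(path)
        self.records: List[Dict[str, Any]] = []

    def log_metrics(self, round_idx: int, part: str, metrics: Dict[str, float]) -> None:
        self.records.append({"round": round_idx, "part": part, **metrics})

    def flush(self) -> None:
        self.path.write_text(json.dumps(self.records, indent=2))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        logger = MiniJsonLogger(Path(tmp) / "log.json")
        logger.log_metrics(1, "val", {"acc": 0.62, "loss": 1.1})
        logger.flush()
        print(json.loads(logger.path.read_text()))
```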

Visualization Panel

The visualization panel is a GUI for visualizing the training results of federated learning algorithms. It is based on ipywidgets and matplotlib, and can be used in Jupyter notebooks. It has the following features:

  1. Automatically search for and display the log files of completed experiments in the specified directory.
  2. Automatically decode the log files and aggregate the training metrics into curves in a matplotlib figure.
  3. Support interactive operations on the figure, including zooming, font family selection, curve smoothing, etc.
  4. Support saving the figure as a PDF/SVG/PNG/JPEG/PS file.
  5. Support merging curves via tags (e.g. experiments on the algorithm FedAvg using different seeds can be merged into a single curve) into mean curves with error bounds (standard deviation, standard error of the mean, quantiles, interquartile range, etc.).
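Merging curves via tags amounts to aligning several runs (e.g. different seeds) round by round and computing a central curve with an error band. A minimal standard-library sketch with mean ± standard deviation or mean ± standard error of the mean; `merge_curves` is hypothetical and the panel's own implementation (including its quantile-based bands) may differ.

```python
import math
import statistics
from typing import List, Sequence, Tuple


def merge_curves(
    curves: Sequence[Sequence[float]], band: str = "std"
) -> Tuple[List[float], List[float], List[float]]:
    """Return (mean, lower, upper) per round for equally long curves from several runs."""
    mean, lower, upper = [], [], []
    for values in zip(*curves):  # values of all runs at one round
        m = statistics.fmean(values)
        s = statistics.stdev(values)  # sample standard deviation (needs >= 2 runs)
        if band == "sem":
            s /= math.sqrt(len(values))  # standard error of the mean
        mean.append(m)
        lower.append(m - s)
        upper.append(m + s)
    return mean, lower, upper


if __name__ == "__main__":
    seed0 = [0.50, 0.60, 0.70]
    seed1 = [0.54, 0.62, 0.74]
    print(merge_curves([seed0, seed1], band="sem"))
```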

The following GIF (created using ScreenToGif) shows a demo of the visualization panel:

FL-SIM Panel Demo GIF

NOTE: to use Windows fonts on a Linux machine (e.g. Ubuntu), one can execute the following commands:

sudo apt install ttf-mscorefonts-installer
sudo fc-cache -fv

Command Line Interface

A command line interface (CLI) is provided for running multiple federated learning experiments. The only argument is the path to the configuration file (in YAML format) for the experiments. Examples of configuration files can be found in the example-configs folder. For example, in the all-alg-fedprox-femnist.yml file, we have

```yaml
# Example config file for fl-sim command line interface

strategy:
  matrix:
    algorithm:
      - Ditto
      - FedDR
      - FedAvg
      - FedAdam
      - FedProx
      - FedPD
      - FedSplit
      - IFCA
      - pFedMac
      - pFedMe
      - ProxSkip
      - SCAFFOLD
    clients_sample_ratio:
      - 0.1
      - 0.3
      - 0.7
      - 1.0

algorithm:
  name: ${{ matrix.algorithm }}
  server:
    num_clients: null
    clients_sample_ratio: ${{ matrix.clients_sample_ratio }}
    num_iters: 100
    p: 0.3  # for FedPD, ProxSkip
    lr: 0.03  # for SCAFFOLD
    num_clusters: 10  # for IFCA
    log_dir: all-alg-fedprox-femnist
  client:
    lr: 0.03
    num_epochs: 10
    batch_size: null  # null for default batch size
    scheduler:
      name: step  # StepLR
      step_size: 1
      gamma: 0.99

dataset:
  name: FedProxFEMNIST
  datadir: null  # default dir
  transform: none  # none for static transform (only normalization, no augmentation)

model:
  name: cnn_femmist_tiny

seed: 0
```

The (optional) strategy section specifies the grid-search strategy. The algorithm section specifies the hyperparameters of the federated learning algorithm: name is the name of the algorithm, server specifies the hyperparameters of the server, and client specifies the hyperparameters of the clients. The dataset section specifies the dataset, and the model section specifies the named model to be used (see the candidate_models property of the dataset classes).
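Read as a grid search, the strategy.matrix section yields one experiment per combination of the listed values; under that assumption the example config above, with 12 algorithms and 4 client sample ratios, describes 48 runs. The `${{ matrix.* }}` placeholders resemble GitHub Actions matrix syntax. A sketch of the expansion with `itertools.product`; `expand_matrix` and `substitute` are hypothetical and not the CLI's actual code.

```python
import itertools
from typing import Any, Dict, List


def expand_matrix(matrix: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Cartesian product of all matrix entries, one dict per experiment."""
    keys = list(matrix)
    return [dict(zip(keys, combo)) for combo in itertools.product(*(matrix[k] for k in keys))]


def substitute(template: Any, combo: Dict[str, Any]) -> Any:
    """Replace '${{ matrix.<key> }}' placeholders in a nested config (hypothetical)."""
    if isinstance(template, dict):
        return {k: substitute(v, combo) for k, v in template.items()}
    if isinstance(template, list):
        return [substitute(v, combo) for v in template]
    if isinstance(template, str) and template.startswith("${{ matrix."):
        key = template[len("${{ matrix."):-len(" }}")]
        return combo[key]
    return template


if __name__ == "__main__":
    matrix = {
        "algorithm": ["FedAvg", "FedProx", "SCAFFOLD"],
        "clients_sample_ratio": [0.1, 0.3],
    }
    template = {
        "algorithm": {
            "name": "${{ matrix.algorithm }}",
            "server": {"clients_sample_ratio": "${{ matrix.clients_sample_ratio }}"},
        }
    }
    runs = [substitute(template, combo) for combo in expand_matrix(matrix)]
    print(len(runs))  # 6
    print(runs[0])
```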

Customization

One can implement custom federated learning algorithms, datasets, and optimizers, and register them with the corresponding registration functions.

For example, in the test-files/custom_conf.yml file, algorithm.name and dataset.name refer to the custom algorithm Custom defined in test-files/custom_alg.py and the custom dataset CustomFEMNIST defined in test-files/custom_dataset.py, respectively. One can run the following command to start the simulation experiment:

fl-sim test-files/custom_conf.yml

in the root directory of this repository. If algorithm.name and dataset.name are changed to absolute paths, the command can be run from any directory.

Custom Federated Learning Algorithms

In the test-files/custom_alg.py file, we implement a custom federated learning algorithm Custom by subclassing the four classes ServerConfig, ClientConfig, Server, and Client, and use the register_algorithm decorator to register the algorithm. For example, the ServerConfig subclass is defined as follows:

@register_algorithm()
@add_docstring(server_config_kw_doc, "append")
class CustomServerConfig(ServerConfig):
    ...

Custom Datasets

In the test-files/custom_dataset.py file, we implement a custom dataset CustomFEMNIST via subclassing the FEMNIST class and use the register_dataset decorator to register the dataset.

Custom Optimizers

One can implement custom optimizers via subclassing the torch.optim.Optimizer class and use the register_optimizer decorator to register the optimizer.