omarfoq / FedEM

Official code for "Federated Multi-Task Learning under a Mixture of Distributions" (NeurIPS'21)
Apache License 2.0
deep-learning federated-learning machine-learning personalized-federated-learning pytorch

Federated Multi-Task Learning under a Mixture of Distributions

This repository is the official implementation of Federated Multi-Task Learning under a Mixture of Distributions.

The increasing size of data generated by smartphones and IoT devices motivated the development of Federated Learning (FL), a framework for on-device collaborative training of machine learning models. Early efforts in FL focused on learning a single global model with good average performance across clients, but the global model may be arbitrarily bad for a given client, due to the inherent heterogeneity of local data distributions. Federated multi-task learning (MTL) approaches can learn personalized models by formulating a suitable penalized optimization problem. The penalization term can capture complex relations among personalized models, but eschews clear statistical assumptions about local data distributions. In this work, we propose to study federated MTL under the flexible assumption that each local data distribution is a mixture of unknown underlying distributions.

This assumption encompasses most of the existing personalized FL approaches and leads to federated EM-like algorithms for both client-server and fully decentralized settings. Moreover, it provides a principled way to serve personalized models to clients not seen at training time. The algorithms' convergence is analyzed through a novel federated surrogate optimization framework, which can be of general interest. Experimental results on FL benchmarks show that in most cases our approach provides models with higher accuracy and fairness than state-of-the-art methods.

Requirements

To install the requirements:

pip install -r requirements.txt

Usage

We provide code to simulate federated training of machine learning models. The core objects are Aggregator and Client; different federated learning algorithms can be implemented by revising the local update method Client.step() and/or the aggregation protocol defined in Aggregator.mix() and Aggregator.update_client().

In addition to the trivial baseline consisting of training models locally without any collaboration, this repository supports several federated learning algorithms, including FedAvg (with and without local adaptation), FedProx, Clustered FL, pFedMe and FedEM.

Datasets

We provide five federated benchmark datasets spanning a wide range of machine learning tasks: image classification (CIFAR10 and CIFAR100), handwritten character recognition (EMNIST and FEMNIST), and language modelling (Shakespeare), in addition to a synthetic dataset.

The Shakespeare dataset (resp. FEMNIST) was naturally partitioned by assigning all lines from the same character (resp. all images from the same writer) to the same client. We created federated versions of CIFAR10 and EMNIST by distributing samples with the same label across the clients according to a symmetric Dirichlet distribution with parameter 0.4. For CIFAR100, we exploited the availability of "coarse" and "fine" labels, using a two-stage Pachinko allocation method to assign 600 samples to each of the 100 clients.

The following table summarizes the datasets and models:

Dataset | Task | Model
FEMNIST | Handwritten character recognition | 2-layer CNN + 2-layer FFN
EMNIST | Handwritten character recognition | 2-layer CNN + 2-layer FFN
CIFAR10 | Image classification | MobileNet-v2
CIFAR100 | Image classification | MobileNet-v2
Shakespeare | Next character prediction | Stacked LSTM
Synthetic dataset | Binary classification | Linear model

See the README.md file of the respective dataset, i.e., data/$DATASET, for instructions on generating the data.

Training

To run an experiment on one dataset with a specific federated learning method, specify the name of the dataset (experiment) and the method, and configure all other hyper-parameters (the hyper-parameter values used in the paper are listed in its appendix). For example:

python run_experiment.py cifar10 FedAvg \
    --n_learners 1 \
    --n_rounds 200 \
    --bz 128 \
    --lr 0.01 \
    --lr_scheduler multi_step \
    --log_freq 5 \
    --device cuda \
    --optimizer sgd \
    --seed 1234 \
    --logs_root ./logs \
    --verbose 1

The training and test accuracy and loss are saved in the specified log directory (--logs_root).

We provide example scripts to run the paper's experiments in the scripts/ directory.

Evaluation

We give instructions to run experiments on the CIFAR-10 dataset as an example (the same holds for the other datasets). First go to ./data/cifar10 and follow the instructions in README.md to download and partition the dataset.

All experiments generate TensorBoard log files (logs/cifar10) that you can explore with TensorBoard.

Average performance of personalized models

Run the following scripts; they generate TensorBoard logs that you can use to make plots or to retrieve the values presented in Table 2.

# run FedAvg
echo "Run FedAvg"
python run_experiment.py cifar10 FedAvg --n_learners 1 --n_rounds 200 --bz 128 --lr 0.01 \
 --lr_scheduler multi_step --log_freq 5 --device cuda --optimizer sgd --seed 1234 --verbose 1

# run FedAvg + local adaptation
echo "Run FedAvg + local adaptation"
python run_experiment.py cifar10 FedAvg --n_learners 1 --locally_tune_clients --n_rounds 201 --bz 128 \
 --lr 0.001 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1

# run training using local data only
echo "Run Local"
python run_experiment.py cifar10 local --n_learners 1 --n_rounds 201 --bz 128 --lr 0.03 \
 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1

# run Clustered FL
echo "Run Clustered FL"
python run_experiment.py cifar10 clustered --n_learners 1 --n_rounds 201 --bz 128 --lr 0.003 \
 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1

# run FedProx
echo "Run FedProx"
python run_experiment.py cifar10 FedProx --n_learners 1 --n_rounds 201 --bz 128 --lr 0.01 --mu 1.0 \
 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer prox_sgd --seed 1234 --verbose 1

# run pFedMe
echo "Run pFedMe"
python run_experiment.py cifar10 pFedMe --n_learners 1 --n_rounds 201 --bz 128 --lr 0.001 --mu 1.0 \
 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer prox_sgd --seed 1234 --verbose 1

# run FedEM
echo "Run FedEM"
python run_experiment.py cifar10 FedEM --n_learners 3 --n_rounds 201 --bz 128 --lr 0.03 \
 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1

Similar scripts for the other datasets are provided in papers_experiments/.

Effect of M

Run the following script to get the logs corresponding to different choices of the parameter M, the number of mixture components, which is set with --n_learners (here we give an example with M=4):

# run FedEM

echo "Run FedEM | M=4"
python run_experiment.py cifar10 FedEM \
    --n_learners 4 \
    --n_rounds 201 \
    --bz 128 \
    --lr 0.03 \
    --lr_scheduler multi_step \
    --log_freq 10 \
    --device cuda \
    --optimizer sgd \
    --logs_root logs/cifar10/FedEM_4 \
    --seed 1234 \
    --verbose 1

Generalization to unseen clients

Run the same script as in the previous section, making sure that --test-clients-frac is non-zero when you call generate_data.py.

Clients sampling

The code can use only a fraction of the available clients at each round; you can set this fraction with the --sampling_rate argument when running run_experiment.py (default is 0).

Fully-decentralized federated learning

To simulate fully decentralized training, specify --decentralized when you run run_experiment.py.

Results

The performance of each personalized model (which is the same for all clients in the case of FedAvg and FedProx) is evaluated on the local test dataset (unseen at training). The following table shows the weighted average accuracy, with weights proportional to local dataset sizes; FedAvg+ denotes FedAvg followed by local adaptation. FedEM obtains the highest accuracy on FEMNIST, EMNIST, CIFAR10 and CIFAR100, while on Shakespeare FedAvg scores higher in this table.

Dataset | Local | FedAvg | FedAvg+ | FedEM (Ours)
FEMNIST | 71.0 | 78.6 | 75.3 | 79.9
EMNIST | 71.9 | 82.6 | 83.1 | 83.5
CIFAR10 | 70.2 | 78.2 | 82.3 | 84.3
CIFAR100 | 31.5 | 40.9 | 39.0 | 44.1
Shakespeare | 32.0 | 46.7 | 40.0 | 43.7

We can also visualise the evolution of the training loss, training accuracy, test loss and test accuracy for the CIFAR-10 dataset.

Similar plots can be built for other experiments using the make_plot function in utils/plots.py.

Citation

If you use our code or wish to refer to our results, please use the following BibTeX entry:

@article{marfoq2021federated,
  title={Federated multi-task learning under a mixture of distributions},
  author={Marfoq, Othmane and Neglia, Giovanni and Bellet, Aur{\'e}lien and Kameni, Laetitia and Vidal, Richard},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  pages={15434--15447},
  year={2021}
}