This repository is the official implementation of Federated Multi-Task Learning under a Mixture of Distributions.
The increasing size of data generated by smartphones and IoT devices motivated the development of Federated Learning (FL), a framework for on-device collaborative training of machine learning models. First efforts in FL focused on learning a single global model with good average performance across clients, but the global model may be arbitrarily bad for a given client, due to the inherent heterogeneity of local data distributions. Federated multi-task learning (MTL) approaches can learn personalized models by formulating an opportune penalized optimization problem. The penalization term can capture complex relations among personalized models, but eschews clear statistical assumptions about local data distributions. In this work, we propose to study federated MTL under the flexible assumption that each local data distribution is a mixture of unknown underlying distributions.
This assumption encompasses most of the existing personalized FL approaches and leads to federated EM-like algorithms for both client-server and fully decentralized settings. Moreover, it provides a principled way to serve personalized models to clients not seen at training time. The algorithms' convergence is analyzed through a novel federated surrogate optimization framework, which can be of general interest. Experimental results on FL benchmarks show that in most cases our approach provides models with higher accuracy and fairness than state-of-the-art methods.
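To make the EM-like structure concrete, the sketch below shows one update of a single client's mixture weights. It assumes that the per-sample loss of each of the M components is available and plays the role of a negative log-likelihood. The function name and data layout are hypothetical, and the update of the component models themselves (minimizing the responsibility-weighted loss) is omitted.

```python
import math


def em_update_weights(pi, losses):
    """One EM update of a client's mixture weights (illustrative sketch).

    pi:     list of M mixture weights for this client (non-negative, summing to 1).
    losses: losses[i][m] is the loss of component m on local sample i,
            interpreted as -log p_m(y_i | x_i).
    Returns (responsibilities, new_pi).
    """
    n, m_count = len(losses), len(pi)
    resp = []
    for row in losses:
        # E-step: q(z_i = m) is proportional to pi_m * exp(-loss_im).
        # Work in log space and subtract the maximum for numerical stability.
        scores = [math.log(pi[m]) - row[m] if pi[m] > 0 else -math.inf
                  for m in range(m_count)]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        resp.append([w / total for w in weights])
    # M-step for the weights: the new pi_m is the mean responsibility of component m.
    new_pi = [sum(r[m] for r in resp) / n for m in range(m_count)]
    return resp, new_pi


resp, new_pi = em_update_weights([0.5, 0.5], [[0.1, 2.0], [0.2, 1.5], [3.0, 0.3]])
print(new_pi)  # component 0 fits two of the three samples better, so its weight grows
```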
To install requirements:
pip install -r requirements.txt
We provide code to simulate federated training of machine learning models.
The core objects are Aggregator and Client. Different federated learning algorithms can be implemented by revising the local update method Client.step() and/or the aggregation protocol defined in Aggregator.mix() and Aggregator.update_client().
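To illustrate how these two objects interact, here is a toy, dependency-free sketch of a FedAvg-style round. The class and method names mirror the ones above, but the constructors, signatures, and the quadratic "loss" standing in for local data are hypothetical simplifications, not the repository's actual API.

```python
class Client:
    """Toy client whose local loss is 0.5 * ||w - target||^2 (a stand-in for real data)."""

    def __init__(self, target, n_samples, lr=0.1):
        self.target = list(target)
        self.n_samples = n_samples  # used as the aggregation weight
        self.lr = lr
        self.model = [0.0] * len(self.target)

    def step(self, n_local_steps=1):
        # Local update: gradient descent; the gradient of the toy loss is (w - target).
        for _ in range(n_local_steps):
            self.model = [w - self.lr * (w - t) for w, t in zip(self.model, self.target)]


class Aggregator:
    def __init__(self, clients):
        self.clients = clients
        self.global_model = [0.0] * len(clients[0].model)

    def mix(self, n_local_steps=5):
        # One communication round: local training, then a sample-size-weighted average.
        for client in self.clients:
            client.step(n_local_steps)
        total = sum(c.n_samples for c in self.clients)
        self.global_model = [
            sum(c.n_samples * c.model[j] for c in self.clients) / total
            for j in range(len(self.global_model))
        ]
        for client in self.clients:
            self.update_client(client)

    def update_client(self, client):
        # Broadcast: overwrite the client's local model with the aggregated one.
        client.model = list(self.global_model)


clients = [Client([1.0], n_samples=1), Client([3.0], n_samples=3)]
aggregator = Aggregator(clients)
for _ in range(50):
    aggregator.mix()
print(aggregator.global_model)  # approaches the weighted optimum [2.5]
```

Swapping in a different algorithm then amounts to changing step() (e.g. adding a proximal term) or the averaging rule in mix().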
In addition to the trivial baseline of training models locally without any collaboration, this repository supports several federated learning algorithms, including FedAvg (optionally followed by local adaptation), FedProx, Clustered FL, pFedMe, and FedEM.
We provide five federated benchmark datasets spanning a wide range of machine learning tasks: image classification (CIFAR10 and CIFAR100), handwritten character recognition (EMNIST and FEMNIST), and language modelling (Shakespeare), in addition to a synthetic dataset.
The Shakespeare dataset (resp. FEMNIST) is naturally partitioned by assigning all lines from the same character (resp. all images from the same writer) to the same client. We created federated versions of CIFAR10 and EMNIST by distributing samples with the same label across the clients according to a symmetric Dirichlet distribution with parameter 0.4. For CIFAR100, we exploited the availability of "coarse" and "fine" labels, using a two-stage Pachinko allocation method to assign 600 samples to each of the 100 clients.
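The label-wise Dirichlet split used for CIFAR10 and EMNIST can be sketched as follows. This is an illustration under our own assumptions (the helper names, the rounding of split points, and the use of Python's random module are hypothetical), not the repository's data-generation script. A symmetric Dirichlet sample is obtained by normalizing independent Gamma(alpha, 1) draws.

```python
import random


def dirichlet(alpha, k, rng):
    # Symmetric Dirichlet(alpha) sample via normalized Gamma(alpha, 1) draws.
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]


def dirichlet_partition(labels, n_clients, alpha=0.4, seed=1234):
    """Split sample indices across clients, label by label: the samples of each
    label are shared out according to proportions drawn from Dir(alpha)."""
    rng = random.Random(seed)
    by_label = {}
    for idx, y in enumerate(labels):
        by_label.setdefault(y, []).append(idx)
    clients = [[] for _ in range(n_clients)]
    for _, indices in sorted(by_label.items()):
        rng.shuffle(indices)
        props = dirichlet(alpha, n_clients, rng)
        # Turn cumulative proportions into split points over this label's samples.
        cuts, acc = [], 0.0
        for p in props[:-1]:
            acc += p
            cuts.append(int(round(acc * len(indices))))
        bounds = [0] + cuts + [len(indices)]
        for c in range(n_clients):
            clients[c].extend(indices[bounds[c]:bounds[c + 1]])
    return clients


labels = [i % 10 for i in range(1000)]  # toy dataset: 10 balanced labels
parts = dirichlet_partition(labels, n_clients=20)
print([len(p) for p in parts])  # client sizes are typically very uneven
```

Smaller values of alpha concentrate each label on fewer clients, which yields more heterogeneous local distributions.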
The following table summarizes the datasets and models:
Dataset | Task | Model |
---|---|---|
FEMNIST | Handwritten character recognition | 2-layer CNN + 2-layer FFN |
EMNIST | Handwritten character recognition | 2-layer CNN + 2-layer FFN |
CIFAR10 | Image classification | MobileNet-v2 |
CIFAR100 | Image classification | MobileNet-v2 |
Shakespeare | Next character prediction | Stacked LSTM |
Synthetic dataset | Binary classification | Linear model |
See the README.md file of the respective dataset, i.e., data/$DATASET, for instructions on generating data.
To run an experiment on one dataset with a specific federated learning method, specify the name of the dataset (experiment) and the method, and configure all other hyper-parameters (all hyper-parameter values are listed in the appendix of the paper):
python3 run_experiment.py cifar10 FedAvg \
--n_learners 1 \
--n_rounds 200 \
--bz 128 \
--lr 0.01 \
--lr_scheduler multi_step \
--log_freq 5 \
--device cuda \
--optimizer sgd \
--seed 1234 \
--logs_root ./logs \
--verbose 1
The test and training accuracy and loss will be saved in the specified log path.
We provide example scripts to run the paper's experiments under the scripts/ directory.
We give instructions to run experiments on the CIFAR-10 dataset as an example (the same holds for the other datasets). First, go to ./data/cifar10 and follow the instructions in README.md to download and partition the dataset.
All experiments generate TensorBoard log files (logs/cifar10) that you can interact with using TensorBoard.
Run the following scripts; they generate TensorBoard logs that you can use to make plots or to retrieve the values presented in Table 2.
# run FedAvg
echo "Run FedAvg"
python run_experiment.py cifar10 FedAvg --n_learners 1 --n_rounds 200 --bz 128 --lr 0.01 \
--lr_scheduler multi_step --log_freq 5 --device cuda --optimizer sgd --seed 1234 --verbose 1
# run FedAvg + local adaptation
echo "Run FedAvg + local adaptation"
python run_experiment.py cifar10 FedAvg --n_learners 1 --locally_tune_clients --n_rounds 201 --bz 128 \
--lr 0.001 --lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1
# run training using local data only
echo "Run Local"
python run_experiment.py cifar10 local --n_learners 1 --n_rounds 201 --bz 128 --lr 0.03 \
--lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1
# run Clustered FL
echo "Run Clustered FL"
python run_experiment.py cifar10 clustered --n_learners 1 --n_rounds 201 --bz 128 --lr 0.003 \
--lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1
# run FedProx
echo "Run FedProx"
python run_experiment.py cifar10 FedProx --n_learners 1 --n_rounds 201 --bz 128 --lr 0.01 --mu 1.0 \
--lr_scheduler multi_step --log_freq 10 --device cuda --optimizer prox_sgd --seed 1234 --verbose 1
# run pFedMe
echo "Run pFedMe"
python run_experiment.py cifar10 pFedMe --n_learners 1 --n_rounds 201 --bz 128 --lr 0.001 --mu 1.0 \
--lr_scheduler multi_step --log_freq 10 --device cuda --optimizer prox_sgd --seed 1234 --verbose 1
# run FedEM
echo "Run FedEM"
python run_experiment.py cifar10 FedEM --n_learners 3 --n_rounds 201 --bz 128 --lr 0.03 \
--lr_scheduler multi_step --log_freq 10 --device cuda --optimizer sgd --seed 1234 --verbose 1
Similar scripts for the other datasets are provided in papers_experiments/.
Run the following script to get the logs corresponding to different choices of the parameter M, the number of mixture components, which is set with --n_learners (here we give an example with M=4):
# run FedEM
echo "Run FedEM | M=4"
python run_experiment.py cifar10 FedEM \
--n_learners 4 \
--n_rounds 201 \
--bz 128 --lr 0.03 \
--lr_scheduler multi_step \
--log_freq 10 \
--device cuda \
--optimizer sgd \
--logs_root logs/cifar10/FedEM_4 \
--seed 1234 \
--verbose 1
You need to run the same script as in the previous section. Make sure that --test-clients-frac is non-zero when you call generate_data.py, so that some clients are held out from training.
Our code makes it possible to use only a fraction of the available clients at each round; you can specify this fraction when running run_experiment.py using the argument --sampling_rate (default is 0).
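Per-round client sampling can be pictured as below. How run_experiment.py turns --sampling_rate into a number of participating clients is not documented here, so the rounding rule (floor, with at least one client) and the function name are assumptions.

```python
import random


def sample_clients(n_clients, sampling_rate, rng):
    """Return the sorted indices of the clients participating in one round."""
    # Assumption: keep at least one client and draw uniformly without replacement.
    n_selected = max(1, int(sampling_rate * n_clients))
    return sorted(rng.sample(range(n_clients), n_selected))


rng = random.Random(1234)
for round_id in range(3):
    print(round_id, sample_clients(n_clients=10, sampling_rate=0.3, rng=rng))
```

Drawing a fresh subset every round keeps the per-round communication cost proportional to the sampling rate, while every client still participates on average.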
To simulate fully decentralized training, specify --decentralized when you run run_experiment.py.
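In the fully decentralized setting there is no server; clients instead average their models with their neighbours. The topology and mixing weights used by --decentralized are not specified here, so the sketch below assumes a ring graph with uniform weights 1/3. This matrix is doubly stochastic, which preserves the network-wide average. The function names are hypothetical.

```python
def ring_mixing_matrix(n):
    """Doubly stochastic mixing matrix for a ring of n >= 3 nodes: each node
    keeps 1/3 of its own model and takes 1/3 from each of its two neighbours."""
    if n < 3:
        raise ValueError("a ring needs at least 3 nodes")
    w = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in (i - 1, i, i + 1):
            w[i][j % n] = 1.0 / 3.0
    return w


def gossip_step(models, w):
    """One averaging step: node i's new model is sum_j W[i][j] * models[j]."""
    n, d = len(models), len(models[0])
    return [[sum(w[i][j] * models[j][k] for j in range(n)) for k in range(d)]
            for i in range(n)]


models = [[float(i)] for i in range(6)]  # node i starts from the scalar model i
w = ring_mixing_matrix(len(models))
for _ in range(100):
    models = gossip_step(models, w)
print(models)  # every node approaches the initial average 2.5
```

In a training run, each node would alternate local update steps with such a mixing step instead of sending its model to a central aggregator.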
The performance of each personalized model (which is the same for all clients in the case of FedAvg and FedProx) is evaluated on the client's local test dataset (unseen at training). The table below reports the average accuracy (%) across clients, weighted by local dataset size; FedAvg+ denotes FedAvg followed by local adaptation. FedEM obtains the best accuracy on FEMNIST, EMNIST, CIFAR10, and CIFAR100, while FedAvg performs best on Shakespeare.
Dataset | Local | FedAvg | FedAvg+ | FedEM (Ours) |
---|---|---|---|---|
FEMNIST | 71.0 | 78.6 | 75.3 | 79.9 |
EMNIST | 71.9 | 82.6 | 83.1 | 83.5 |
CIFAR10 | 70.2 | 78.2 | 82.3 | 84.3 |
CIFAR100 | 31.5 | 40.9 | 39.0 | 44.1 |
Shakespeare | 32.0 | 46.7 | 40.0 | 43.7 |
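The weighted average reported in the table can be computed as below. Which local dataset size serves as the weight (training or test split) is not stated here, so the caller supplies the sizes; the function name is hypothetical.

```python
def weighted_accuracy(accuracies, sizes):
    """Average of per-client accuracies, weighted by local dataset size."""
    if len(accuracies) != len(sizes):
        raise ValueError("need exactly one dataset size per client")
    total = sum(sizes)
    return sum(acc * n for acc, n in zip(accuracies, sizes)) / total


# Two clients: the larger client dominates the average.
print(weighted_accuracy([0.9, 0.6], [100, 300]))
```

Unlike a plain mean over clients (0.75 in this example), the weighted average gives each test sample equal influence, so clients with little data weigh less in the reported figure.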
We can also visualise the evolution of the training loss, training accuracy, test loss, and test accuracy for the CIFAR-10 dataset.
Similar plots can be built for other experiments using the make_plot function in utils/plots.py.
If you use our code or wish to refer to our results, please use the following BibTeX entry:
@article{marfoq2021federated,
title={Federated multi-task learning under a mixture of distributions},
author={Marfoq, Othmane and Neglia, Giovanni and Bellet, Aur{\'e}lien and Kameni, Laetitia and Vidal, Richard},
journal={Advances in Neural Information Processing Systems},
volume={34},
pages={15434--15447},
year={2021}
}