FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research.

Add support for stateful clients #289

Open Lando-L opened 1 year ago

Lando-L commented 1 year ago

At the moment I don't see how to implement a fedjax.FederatedAlgorithm with stateful clients, which would be necessary to implement personalised federated algorithms. It would be great to include an example similar to, e.g., the one in TensorflowFederated.

stheertha commented 1 year ago

Thanks for using FedJAX and filing this issue! Can you please give an example of what you would like to store as part of the client state, and link to the TensorflowFederated example that you are interested in?

Lando-L commented 1 year ago

I would like the clients to store local model parameters which are used during training but are not sent to the server. This would be useful for implementing personalisation strategies such as Federated Learning with Personalization Layers or Adaptive Personalized Federated Learning. The TensorflowFederated example can be found here: https://github.com/tensorflow/federated/tree/main/tensorflow_federated/examples/stateful_clients.
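
To illustrate what I mean, a minimal sketch (the layer names and the dict-of-arrays params layout are just assumptions for illustration, not an existing fedjax API):

```python
# Hypothetical sketch of "personalization layers": a params dict split into a
# shared part (sent to the server) and a personal part (kept on the client).
PERSONAL_LAYERS = frozenset({'head'})  # layers that never leave the client

def split_params(params):
  """Splits a params dict into (shared, personal) sub-dicts."""
  shared = {k: v for k, v in params.items() if k not in PERSONAL_LAYERS}
  personal = {k: v for k, v in params.items() if k in PERSONAL_LAYERS}
  return shared, personal

def merge_params(shared, personal):
  """Recombines server-provided shared params with client-local ones."""
  return {**shared, **personal}
```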

jaehunro commented 1 year ago

We just checked in an example of federated averaging with stateful clients. Similar to the example you linked, it keeps track of the number of steps per client across rounds.

Some important notes about this example:

  1. It's implemented using fedjax.for_each_client, which comes with some restrictions (e.g. every client state must be a PyTree with the same structure because of restrictions around jax.jit and jax.pmap). Implementing this without fedjax.for_each_client should be very straightforward since it can be replaced with a native for loop. Generally, if you expect to be doing very different and custom routines for each client, I would suggest not using fedjax.for_each_client and just writing a for loop (like in our simple fed_avg example); see the sketch after this list.
  2. client_states is kept as part of ServerState since we want to make sure ServerState encapsulates all the information at a given round (e.g. for checkpointing to restart failed experiments).
  3. We're storing client_states in memory in this example, but this is not recommended when the number of clients or the size of each client's state is very large. This is especially true if you're using accelerators and storing model parameters as client state, since model params are JAX device arrays that live on device.
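
To make point 1 concrete, here is a minimal, self-contained sketch of one round with a plain for loop. The names ClientState, ServerState, and client_update are placeholders for this sketch, not fedjax APIs, and the "model" is a bare array so the sketch stays runnable:

```python
from typing import Dict, NamedTuple

import jax.numpy as jnp

class ClientState(NamedTuple):
  num_steps: jnp.ndarray  # total local steps taken across all rounds

class ServerState(NamedTuple):
  params: jnp.ndarray
  # Per note 2: client states live inside ServerState so that checkpointing
  # ServerState captures everything needed to restart a failed experiment.
  client_states: Dict[str, ClientState]

def client_update(params, client_state, batches):
  """Stands in for local training; here it only counts steps."""
  num_steps = client_state.num_steps
  for _ in batches:
    num_steps = num_steps + 1
  delta = jnp.zeros_like(params)  # stand-in for a real model update
  return delta, ClientState(num_steps)

def run_round(server_state: ServerState, clients) -> ServerState:
  """Runs one round; clients is an iterable of (client_id, batches) pairs."""
  deltas = []
  client_states = dict(server_state.client_states)
  for client_id, batches in clients:
    # Look up this client's state from previous rounds, or initialize it.
    state = client_states.get(client_id, ClientState(jnp.array(0)))
    delta, state = client_update(server_state.params, state, batches)
    deltas.append(delta)
    client_states[client_id] = state
  new_params = server_state.params - jnp.mean(jnp.stack(deltas), axis=0)
  return ServerState(new_params, client_states)
```
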
Lando-L commented 1 year ago

Thank you for adding the example. I'll try to implement the two personalised federated learning algorithms using this as a template. Would these algorithms be of interest as potential contributions to the fedjax library?

stheertha commented 1 year ago

Thanks for offering to contribute! We are happy to accept contributions of published algorithms. For any algorithm, please consider adding relevant tests and experimental scripts to reproduce numbers on at least one dataset, e.g. https://github.com/google/fedjax/blob/main/experiments/fed_avg/run_fed_avg.py

Lando-L commented 1 year ago

I am happy to contribute. I have already implemented the algorithms and written tests analogous to the ones for Federated Averaging. Unfortunately, writing the experimental scripts will require a few more changes, e.g. in https://github.com/google/fedjax/blob/main/fedjax/training/federated_experiment.py. The evaluation of the personalised models can only be performed per client, in contrast to Federated Averaging, which can be evaluated using a single centralised model. I can offer my contribution here as well, but this might be better handled in a separate pull request.
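
Roughly, the per-client evaluation I have in mind looks like this (assuming each client's personalised params are stored under server_state.client_states, which is my own layout for this sketch; fedjax.evaluate_model itself evaluates one set of params over a stream of batches):

```python
import fedjax

def evaluate_per_client(model, server_state, test_fd, client_ids,
                        batch_size=128):
  """Evaluates each client's personalised params on its own test data."""
  metrics = {}
  for client_id in client_ids:
    # Each client is evaluated with its own personalised parameters,
    # not a single centralised model.
    params = server_state.client_states[client_id].params
    batches = test_fd.get_client(client_id).padded_batch(batch_size=batch_size)
    metrics[client_id] = fedjax.evaluate_model(model, params, batches)
  return metrics
```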

jaehunro commented 1 year ago

Rather than modifying fedjax.training.federated_experiment, I would suggest writing your own experiment loop, since your experiment setup is pretty unique and federated_experiment is pretty straightforward. You should be able to reuse a lot of the underlying modules like fedjax.core.client_samplers and fedjax.training.checkpoint.
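
A sketch of such a loop, reusing those building blocks (the algorithm object, initial params, and hyperparameters are placeholders; the sampler and checkpoint helpers are used the same way federated_experiment.py uses them):

```python
import fedjax

def run_experiment(algorithm, init_params, train_fd, num_rounds, root_dir,
                   num_clients_per_round=10, seed=0):
  sampler = fedjax.client_samplers.UniformGetClientSampler(
      train_fd, num_clients=num_clients_per_round, seed=seed)
  # Resume from the latest checkpoint if the experiment was interrupted.
  latest = fedjax.training.load_latest_checkpoint(root_dir)
  if latest is not None:
    state, last_round = latest
    start_round = last_round + 1
  else:
    state, start_round = algorithm.init(init_params), 1
  for round_num in range(start_round, num_rounds + 1):
    sampler.set_round_num(round_num)
    clients = sampler.sample()
    state, _ = algorithm.apply(state, clients)
    fedjax.training.save_checkpoint(root_dir, state, round_num, keep=1)
  return state
```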

You could start by adding a directory fedjax/experiments/YOUR_ALGO and putting your experiment scripts there for at least one dataset. But I do think that should be added along with the algorithm implementation so we can verify it trains properly, reproduces paper results, etc.

Lando-L commented 1 year ago

I opened pull request #297 for the Adaptive Personalized Federated Learning algorithm. I could add the other algorithm in a similar fashion after getting your feedback on this PR.