Open Lando-L opened 1 year ago
Thanks for using FedJAX and filing this issue! Can you please give an example of what you would like to store as part of client states and give the link to the TensorflowFederated example that you are interested in.
I would like the clients to store local model parameters which are used during training but are not sent to the server. This would be useful for implementing personalisation strategies such as Federated Learning with Personalization Layers or Adaptive Personalized Federated Learning. The TensorflowFederated example can be found here: https://github.com/tensorflow/federated/tree/main/tensorflow_federated/examples/stateful_clients.
We just checked in an example of federated averaging with stateful clients. Similar to the example you linked, it keeps track of the number of steps per client across rounds.
Some important notes about this example:
- It uses `fedjax.for_each_client`, which comes with some restrictions (e.g. every client state must be a PyTree with the same structure because of restrictions around `jax.jit` and `jax.pmap`). Implementing this without `fedjax.for_each_client` should be very straightforward since it can be replaced with a native for loop. Generally, if you expect to be doing very different and custom routines for each client, I would suggest not using `fedjax.for_each_client` but just writing a for loop (like in our simple fed_avg example).
- `client_states` is kept as part of `ServerState` since we want to make sure `ServerState` encapsulates all the information at a given round (e.g. for checkpointing to restart failed experiments).
- We keep `client_states`
in memory in this example, but this is not recommended when the number of clients or the size of each client's state is very large. This is especially true if you're using accelerators and storing model parameters as client state, since model params are JAX device arrays that live on device.

Thank you for adding the example. I'll try to implement the two personalised federated learning algorithms using this as a template. Would these algorithms be of interest as potential contributions to the fedjax library?
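The native for-loop alternative to `fedjax.for_each_client` described in the notes above can be sketched without any library at all. The sketch below is a toy, not fedjax API: the names (`ClientState`, `ServerState`, `run_round`), the scalar "model", and the loss are all hypothetical, chosen only to show per-client state carried across rounds and kept inside the server state for checkpointing.

```python
# Library-free sketch of federated averaging with stateful clients.
# All names and the toy loss are hypothetical; nothing here is fedjax API.
from dataclasses import dataclass, field

@dataclass
class ClientState:
    num_steps: int = 0  # carried across rounds, never sent to the server

@dataclass
class ServerState:
    params: float  # stand-in for model parameters
    # Kept inside ServerState so a checkpoint of it captures the full round.
    client_states: dict = field(default_factory=dict)

def client_update(params, data, state):
    """One client's local training: a toy gradient step per example."""
    for x in data:
        params -= 0.1 * (params - x)  # pretend loss: 0.5 * (params - x)**2
        state.num_steps += 1
    return params, state

def run_round(server_state, clients):
    deltas = []
    for cid, data in clients.items():  # plain for loop instead of for_each_client
        state = server_state.client_states.get(cid, ClientState())
        new_params, new_state = client_update(server_state.params, data, state)
        server_state.client_states[cid] = new_state
        deltas.append(new_params - server_state.params)
    server_state.params += sum(deltas) / len(deltas)  # average of client updates
    return server_state

server = ServerState(params=0.0)
clients = {"a": [1.0, 1.0], "b": [3.0]}
for _ in range(2):
    server = run_round(server, clients)
print(server.client_states["a"].num_steps)  # → 4: two local steps per round
```

Because each client is just an iteration of a Python loop, the client states may have completely different structures, which is exactly the freedom `jax.jit`/`jax.pmap` take away inside `fedjax.for_each_client`.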
Thanks for offering to contribute! We are happy to accept contributions of published algorithms. For any algorithm, please consider adding relevant tests and experimental scripts to reproduce numbers on at least one dataset, e.g. https://github.com/google/fedjax/blob/main/experiments/fed_avg/run_fed_avg.py
I am happy to contribute. I have already implemented the algorithms and written tests analogous to the ones for Federated Averaging. Unfortunately, writing the experimental scripts will require a few more changes, e.g. in https://github.com/google/fedjax/blob/main/fedjax/training/federated_experiment.py: the evaluation of the personalised models can only be performed per client, in contrast to Federated Averaging, which can be evaluated using a single centralised model. I can offer my contribution here as well, but this might be better handled in a separate pull request.
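The per-client evaluation described above, where each client is scored with its own personalised parameters rather than one central model, can be sketched in a few lines. Everything here is illustrative: the sign-agreement "classifier" and all names are hypothetical, not fedjax API; the only point is that metrics are computed per client and aggregated weighted by example counts.

```python
# Toy sketch of per-client evaluation for personalised models.
# The metric (sign-agreement accuracy) and all names are hypothetical.

def evaluate_per_client(personalized_params, eval_data):
    """personalized_params: {client_id: params}; eval_data: {client_id: [(x, y), ...]}."""
    total_correct, total_examples = 0, 0
    for cid, examples in eval_data.items():
        params = personalized_params[cid]  # each client uses its own model
        correct = sum(1 for x, y in examples if (x * params > 0) == (y > 0))
        total_correct += correct
        total_examples += len(examples)
    return total_correct / total_examples  # example-weighted accuracy

params = {"a": 1.0, "b": -1.0}
data = {"a": [(1.0, 1.0), (-1.0, -1.0)], "b": [(1.0, -1.0)]}
print(evaluate_per_client(params, data))  # → 1.0
```

A centralised evaluation loop only needs one `params` for all clients, which is why the existing `federated_experiment` helpers don't fit this setting directly.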
Rather than modifying fedjax.training.federated_experiment, I would suggest writing your own experiment loop, since your experiment setup is pretty unique and federated_experiment is pretty straightforward. You should be able to reuse a lot of the underlying modules like fedjax.core.client_samplers and fedjax.training.checkpoint.
You could start by just adding a directory fedjax/experiments/YOUR_ALGO and putting your experiment scripts there for at least one dataset. But I do think those should be added along with the algorithm implementation so we can verify it trains properly, reproduces paper results, etc.
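A standalone experiment loop in the spirit of the suggestion above is mostly three pieces: sample clients, apply one algorithm round, checkpoint periodically. The sketch below stands in for the fedjax utilities with plain Python (pickle for checkpoints, `random` for sampling); every name is an assumption, not the fedjax API.

```python
# Hypothetical skeleton of a custom experiment loop; all names are assumptions.
import os
import pickle
import random
import tempfile

def sample_clients(all_client_ids, num_clients, round_num, seed=0):
    # Stand-in for a fedjax.core.client_samplers sampler: deterministic per round.
    rng = random.Random(seed * 1_000_003 + round_num)
    return rng.sample(all_client_ids, num_clients)

def save_checkpoint(root_dir, round_num, state):
    # Stand-in for fedjax.training checkpointing utilities.
    with open(os.path.join(root_dir, f"checkpoint_{round_num:08d}"), "wb") as f:
        pickle.dump(state, f)

def run_experiment(algorithm_step, init_state, all_client_ids, *,
                   num_rounds, clients_per_round, checkpoint_every, root_dir):
    state = init_state
    for round_num in range(1, num_rounds + 1):
        client_ids = sample_clients(all_client_ids, clients_per_round, round_num)
        state = algorithm_step(state, client_ids)
        if round_num % checkpoint_every == 0:
            save_checkpoint(root_dir, round_num, state)
    return state

# Toy usage: the "algorithm" just counts how often each client was sampled.
def count_step(state, client_ids):
    for cid in client_ids:
        state[cid] = state.get(cid, 0) + 1
    return state

with tempfile.TemporaryDirectory() as root:
    final = run_experiment(count_step, {}, list("abcdef"),
                           num_rounds=10, clients_per_round=2,
                           checkpoint_every=5, root_dir=root)
    assert len(os.listdir(root)) == 2  # checkpoints written at rounds 5 and 10
```

Swapping `count_step` for a real algorithm round (and the pickle helpers for the library's checkpointing) gives a loop that can also run per-client evaluation wherever it fits, without touching `federated_experiment`.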
I opened a pull request #297 for the Adaptive Personalized Federated Learning Algorithm. I could add the other algorithm in a similar fashion after getting your feedback on this PR.
At the moment I don't see how to implement a `fedjax.FederatedAlgorithm` with stateful clients, which would be necessary to implement personalised federated algorithms. It would be great to include an example similar to the one in TensorflowFederated.