adap / flower

Flower: A Friendly Federated Learning Framework
https://flower.ai
Apache License 2.0

Saving Global Model Parameters #487

Open patselle opened 3 years ago

patselle commented 3 years ago

Hi,

I am currently trying out the Flower framework with PyTorch, and I am very impressed by how well it works. One thing is still unclear to me: after the federated learning process is over, I would like to save the new global model parameters on the clients, once the server has distributed them to all clients. How is that possible, and where should it be implemented, if it isn't already done?

And why are min_fit_clients and min_eval_clients in fedavg.py set to 2 and not 1? Is there a special reason?

Greetings

Patrick

danieljanes commented 3 years ago

Thanks for the feedback @patselle !

For saving the model, there are currently two ways to do it:

  1. Implement a custom strategy that saves the weights before returning them from aggregate_fit (see the example below)
  2. (hacky) save them via the eval_fn

I've put together a quick example on how to do this in the following draft PR:

https://github.com/adap/flower/pull/488/files#diff-206567616f04a829972d62974a49c3b5769e331dd544233f180182c088c18ebfR30

We'll also add a more robust (and documented) way of doing this soon, stay tuned.

Regarding the min_fit_clients and min_eval_clients defaults: they are set to 2 because setting them to 1 doesn't really do federated learning, it just selects one random client after another. You can of course set them to 1 for testing purposes; users usually don't rely on the defaults and customize those values for their workloads.

Here's the important part of #488 - a custom Strategy implementation which saves the weights before returning them from the aggregate_fit method:

from typing import Optional

import flwr as fl
import numpy as np


class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        rnd: int,
        results,
        failures,
    ) -> Optional[fl.common.Weights]:
        # Aggregate the client updates as usual (FedAvg), then save before returning
        weights = super().aggregate_fit(rnd, results, failures)
        if weights is not None:
            # Save aggregated weights as a list of NumPy ndarrays
            print(f"Saving round {rnd} weights...")
            np.savez(f"round-{rnd}-weights.npz", *weights)
        return weights

# [....]

strategy = SaveModelStrategy(
    fraction_fit=1.0,
    min_fit_clients=2,
    min_available_clients=2,
    eval_fn=get_eval_fn(testloader),
    on_fit_config_fn=fit_config,
)

fl.server.start_server(
    server_address=DEFAULT_SERVER_ADDRESS,
    config={"num_rounds": 3},
    strategy=strategy,
)
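
For completeness, a rough sketch of option 2 (saving via the eval_fn). Note that net, set_weights, and test are assumed helpers in the style of the PyTorch example, not Flower APIs, and the exact return signature of eval_fn depends on the Flower version:

def get_eval_fn(testloader):
    """Return a centralized evaluation function that also saves the latest weights."""

    def evaluate(weights):
        set_weights(net, weights)                # assumed helper: load weights into `net`
        loss, accuracy = test(net, testloader)   # assumed helper: centralized evaluation
        # Side effect: persist the latest global weights to disk
        np.savez("latest-global-weights.npz", *weights)
        return loss, {"accuracy": accuracy}

    return evaluate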
patselle commented 3 years ago

@danieljanes thank you very much!

danieljanes commented 3 years ago

Progress on the built-in model saving will be tracked in #357

patselle commented 3 years ago

Hi @danieljanes, sorry for writing you again, some questions and maybe suggestions.

First, I have added your code snippet, and the global model is now saved on the server.

I am, and maybe other federated learning applications are, interested in distributing the global model to the clients so that they can benefit from each other. At the moment this is done in the mnist.py evaluate method (PyTorch MNIST example), which is not that nice. I'll try to implement it via the strategy (let's see if I can do this).

At the beginning of the process the server takes the model parameters of an arbitrary client. Here you might run into the problem that different clients have different models (model versions), e.g. a client that did not participate in the previous FL process. Does Flower plan to take the initialization weights not from a client but from its own history/pool (e.g. the last global model)?

I have one last question: what is the advantage of converting the weights (NumPy arrays) into bytes, e.g. for communication between server and clients? And is there an advantage to saving the weights in that form?

Greetings and many thanks

Patrick

danieljanes commented 3 years ago

Thanks for your questions @patselle.

Regarding initial weights: it's true that the server takes the weights of one random client. This is a workaround that accounts for the fact that we don't want to force the user to put model initialization code on the server side. However, that doesn't mean we shouldn't provide this functionality; I think it could be a great addition to the Strategy interface. It could also be used to continue training from a previously stopped state, so I think it makes sense to add this.
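
For reference, a minimal sketch of what server-side parameter initialization looks like with the newer initial_parameters / ndarrays_to_parameters API (this assumes a PyTorch model net defined elsewhere; older Flower versions don't have this argument yet):

import flwr as fl

# Assumption: `net` is a torch.nn.Module defined elsewhere
ndarrays = [val.cpu().numpy() for _, val in net.state_dict().items()]

strategy = fl.server.strategy.FedAvg(
    # The server provides the initial global parameters itself,
    # instead of asking a random client for them
    initial_parameters=fl.common.ndarrays_to_parameters(ndarrays),
)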

I'm afraid I can't quite follow your question regarding distributing the global model to the clients. What happens right now is the following:

  1. The server starts and gets initial model parameters from one of the connected clients; those are then used as the new global model parameters
  2. Server selects a few clients (via the Strategy) and sends the global model parameters to those clients for training
  3. Clients receive the global model parameters, use them to update their local model. They train their local models, extract the (now updated) model parameters, and return those to the server
  4. Server waits until enough model updates are available, aggregates them, and replaces the global model parameters with the new aggregate
  5. Server evaluates the new global model parameters either via the Strategy.evaluate method (i.e., centralized evaluation) or on a sample of clients (i.e., federated evaluation, the server sends the global model parameters to a few sampled clients, they evaluate on their local data via Client.evaluate and return their results to the server)
  6. [Repeat step 2. - 5.]

So the server already sends global model parameters to the clients, both via Client.fit and Client.evaluate. Does this answer your question?
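
To make steps 2-5 concrete, here is a minimal NumPyClient sketch (exact method signatures vary slightly across Flower versions). Note that net, trainloader, testloader, set_parameters, train, and test are assumed helpers in the style of the PyTorch example, not Flower APIs:

import flwr as fl


class TorchClient(fl.client.NumPyClient):
    def get_parameters(self):
        # Step 1: the server may ask one connected client for initial parameters
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def fit(self, parameters, config):
        # Steps 2-3: receive global parameters, train locally, return the update
        set_parameters(net, parameters)           # assumed helper
        train(net, trainloader, epochs=1)         # assumed helper
        return self.get_parameters(), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        # Step 5 (federated evaluation): evaluate the received global parameters locally
        set_parameters(net, parameters)
        loss, accuracy = test(net, testloader)    # assumed helper
        return float(loss), len(testloader.dataset), {"accuracy": accuracy}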

patselle commented 3 years ago

Hello @danieljanes,

What I meant is that after step 5 or 6 all clients get the new model from the server (with or without federated evaluation) and save it locally. These clients can then use the model to make predictions on local data, for example. In other words, each client serves as a participant in federated learning and receives the new model to work with (not only the server).

Currently I have implemented it such that, during federated evaluation, the client saves the model at the same time, so that it is available locally.
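
A minimal sketch of that approach (same assumed net, set_parameters, test, and testloader helpers as above; the file name is arbitrary):

import torch


class SavingClient(fl.client.NumPyClient):
    def evaluate(self, parameters, config):
        # Load the received global parameters and persist them locally
        set_parameters(net, parameters)           # assumed helper
        torch.save(net.state_dict(), "global_model.pth")
        loss, accuracy = test(net, testloader)    # assumed helper
        return float(loss), len(testloader.dataset), {"accuracy": accuracy}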

Many thanks in advance

Patrick

danieljanes commented 3 years ago

Ah, I see what you mean. There are two ways to do it via the Strategy:

  1. Save it in Client.evaluate (as you suggested) - you could even implement your Strategy in a way that always samples all clients for evaluation (so that all clients receive the global model parameters), but use the config received by Client.evaluate to tell only a subset of clients to do the actual evaluation.
  2. Always select all clients for training (via a custom Strategy.configure_fit implementation) so that every client receives the global model parameters from the previous round - and then use the config to tell each client whether it should store the weights/train/return an update, or just store the weights and return no update (a rough sketch follows below).

Some background details: we currently don't do this because it increases the network traffic during training/evaluation. We are thinking about ways to do it in a less "hacky" way, for example, allow the client to request the latest global model from the server whenever the client wants to update its local model.
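
A rough sketch of option 2, assuming the newer configure_fit signature (server_round, parameters, client_manager); the "train" config key is a made-up convention that clients would have to check in Client.fit:

import flwr as fl


class BroadcastStrategy(fl.server.strategy.FedAvg):
    def configure_fit(self, server_round, parameters, client_manager):
        # Sample *all* available clients so every one receives the global parameters
        num_clients = client_manager.num_available()
        clients = client_manager.sample(num_clients=num_clients, min_num_clients=num_clients)

        # Use the config to tell each client whether to actually train this round;
        # "train" is an arbitrary key, checked in Client.fit on the client side
        fit_configurations = []
        for i, client in enumerate(clients):
            config = {"train": i < 2}  # e.g. only the first two sampled clients train
            fit_configurations.append((client, fl.common.FitIns(parameters, config)))
        return fit_configurations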

patselle commented 3 years ago

Thank you, I will try it ;)

danieljanes commented 3 years ago

Initial ideas for server-side parameter initialization: #499 @patselle, what do you think about the proposal?

patselle commented 3 years ago

That sounds really good ;)

jalilahmed commented 2 years ago

Thanks a lot in advance for Flower (a really good framework with a lot of potential). I am sorry if this is not exactly related to the issue, but wouldn't it be a good idea to provide functions in the Server class which the user can change, such as,

Mahdi-s commented 6 months ago

I'm running the following save-model strategy and I have noticed that sometimes rounds are skipped and not saved; in some cases 2-10 rounds are skipped. Has anyone else run into this issue?

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
import numpy as np
import torch
from flwr.common import FitRes, Parameters, Scalar
from flwr.server.client_proxy import ClientProxy

# `net` is the PyTorch model defined elsewhere in the script

class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate model weights using weighted average and store checkpoint."""
        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)

            # Convert `List[np.ndarray]` to PyTorch `state_dict`
            params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

            # new_server_round = str(int(server_round) + 200)
            # Save the model; note that rounds where aggregation returns None are skipped
            torch.save(net.state_dict(), f"review_5c_20bs_00001_bce_50/model_round_{server_round}.pth")

        return aggregated_parameters, aggregated_metrics
kalkite commented 6 months ago

@danieljanes, I encountered an error while attempting to obtain the global model weights in the AggregateCustomMetricStrategy class. After completing the first round, the Flower simulation crashed when I attempted to obtain the weights.

I am encountering the following error:

    scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters)
AttributeError: 'EvaluateRes' object has no attribute 'parameters'

Here is my code:

from collections import OrderedDict

import flwr as fl
import torch


class AggregateCustomMetricStrategy(fl.server.strategy.FedAvg):
    def aggregate_evaluate(
            self,
            server_round: int,
            results,
            failures,
    ):
        """Aggregate evaluation accuracy using weighted average."""

        if not results:
            return None, {}

        # Call aggregate_evaluate from base class (FedAvg) to aggregate loss and metrics
        aggregated_loss, aggregated_metrics = super().aggregate_evaluate(server_round, results, failures)

        # This is the call that raises the AttributeError above: `results` here are
        # (ClientProxy, EvaluateRes) pairs, and EvaluateRes has no `parameters` attribute
        weights, fit_res = super().aggregate_fit(server_round, results, failures)

        print("weights: ", weights)

        accuracies = [r.metrics["accuracy"] * r.num_examples for _, r in results]
        examples = [r.num_examples for _, r in results]

        # Aggregate and print custom metric
        aggregated_accuracy = sum(accuracies) / sum(examples)
        print(f"Round {server_round} accuracy aggregated from client results: {aggregated_accuracy}")
        # Return aggregated loss and metrics (i.e., aggregated accuracy)
        return aggregated_loss, {"accuracy": aggregated_accuracy}

def get_evaluate_server_fn(model, test_loader):
    def evaluate_fn(server_round, parameters, config):
        params_dict = zip(model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        model.load_state_dict(state_dict, strict=True)
        loss, accuracy = fed_test(model=model, test_loader=test_loader)
        return loss, {"accuracy": accuracy}

    return evaluate_fn