
FedOpt strategies: Central model gives NaN after second aggregation #1060

Open · sancarlim opened this issue 2 years ago

sancarlim commented 2 years ago

Describe the bug

When I use any of the FedOpt strategies (FedAdam, FedYogi, FedAdagrad), training seems very unstable: the model outputs NaNs after the second or third aggregation.

Steps/Code to Reproduce

Strategy:

strategy = fl.server.strategy.FedYogi(
    fraction_fit=fc / ac,
    fraction_eval=0.2,  # not used - no federated evaluation
    min_fit_clients=fc,
    min_eval_clients=2,  # not used
    min_available_clients=ac,
    eval_fn=get_eval_fn(model),
    on_fit_config_fn=fit_config,
    on_evaluate_config_fn=evaluate_config,
    initial_parameters=fl.common.weights_to_parameters(model_weights),
)
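
For context, the strategy is then handed to fl.server.start_server; this call is taken from the traceback below (rounds is defined elsewhere in server_advanced.py):

import flwr as fl

# Launch the Flower server with the FedYogi strategy defined above
# (call reproduced from the traceback further down in this issue)
fl.server.start_server("0.0.0.0:8080", config={"num_rounds": rounds}, strategy=strategy)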

def get_eval_fn(model):
    """Return an evaluation function for server-side evaluation."""
    _, testset, _ = utils.load_isic_by_patient_server()
    testloader = DataLoader(testset, batch_size=16, num_workers=4, worker_init_fn=utils.seed_worker, shuffle=False)

    # The `evaluate` function will be called after every round
    def evaluate(
        weights: fl.common.Weights,
    ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
        # Update model with the latest parameters
        set_parameters(model, weights)
        loss, auc, accuracy, f1 = utils.val(model, testloader, criterion=nn.BCEWithLogitsLoss())
        return float(loss), {"accuracy": float(accuracy), "auc": float(auc)}

    return evaluate
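
set_parameters is not shown at this point in the thread; a minimal sketch of the usual implementation (loading every state_dict entry; the BatchNorm-filtering variant that comes up later in this issue replaces it) would be:

from collections import OrderedDict
from typing import List

import numpy as np
import torch

def set_parameters(model, parameters: List[np.ndarray]) -> None:
    # Pair each state_dict key with the corresponding NumPy array and load it
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)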

def fit_config(rnd: int):
    """Return training configuration dict for each round.
    Keep batch size fixed at 32, perform two rounds of training with one
    local epoch, increase to two local epochs afterwards.
    """
    config = {
        "batch_size": 32,
        "local_epochs": 1 if rnd < 2 else 2,
    }
    return config

def evaluate_config(rnd: int):
    """Return evaluation configuration dict for each round.
    Perform five local evaluation steps on each client (i.e., use five
    batches) during rounds one to three, then increase to ten local
    evaluation steps.
    """
    val_steps = 5 if rnd < 4 else 10
    return {"val_steps": val_steps}

I have also tried the default eval_fn and the default fit/evaluate configs; the behavior changes, but it still ends up producing NaNs.

Dataset: ISIC 2020 (https://www.kaggle.com/c/siim-isic-melanoma-classification/data). I have tested with fc=ac=3 and fc=ac=2; in the latter case, one client has a training set of ~2k images and the other ~10.5k. Model: EfficientNetB2.

Code: https://github.com/sandracl72/flower

server_advanced.py (--nowandb)
client_isic.py --partition 0 (--nowandb)
client_isic.py --partition 1 (--nowandb)

Expected Results

The server aggregates the weights of all clients and training proceeds without NaNs.

Actual Results

INFO flower 2022-02-11 09:28:34,874 | app.py:109 | Flower server running (10 rounds)
SSL is disabled
INFO flower 2022-02-11 09:28:34,875 | server.py:118 | Initializing global parameters
INFO flower 2022-02-11 09:28:34,875 | server.py:301 | Using initial parameters provided by strategy
INFO flower 2022-02-11 09:28:34,875 | server.py:120 | Evaluating initial parameters
INFO flower 2022-02-11 09:28:54,419 | server.py:123 | initial parameters (loss, other metrics): 0.6916685566558676, {'accuracy': 0.5218716861081655, 'auc': 0.48442367381213036}
INFO flower 2022-02-11 09:28:54,419 | server.py:133 | FL starting
DEBUG flower 2022-02-11 09:28:54,419 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-11 09:30:50,949 | server.py:261 | fit_round received 2 results and 0 failures
INFO flower 2022-02-11 09:31:09,522 | server.py:148 | fit progress: (1, 0.32495464879074293, {'accuracy': 0.901643690349947, 'auc': 0.8418079867528361}, 135.10243083909154)
INFO flower 2022-02-11 09:31:09,522 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-02-11 09:31:09,522 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-11 09:33:03,801 | server.py:261 | fit_round received 2 results and 0 failures
Traceback (most recent call last):
  File "server_advanced.py", line 130, in <module>
    fl.server.start_server("0.0.0.0:8080", config={"num_rounds": rounds}, strategy=strategy)
  File "/workspace/flower/src/py/flwr/server/app.py", line 111, in start_server
    hist = _fl(
  File "/workspace/flower/src/py/flwr/server/app.py", line 148, in _fl
    hist = server.fit(num_rounds=config["num_rounds"])
  File "/workspace/flower/src/py/flwr/server/server.py", line 145, in fit
    res_cen = self.strategy.evaluate(parameters=self.parameters)
  File "/workspace/flower/src/py/flwr/server/strategy/fedavg.py", line 178, in evaluate
    eval_res = self.eval_fn(weights)
  File "server_advanced.py", line 47, in evaluate
    loss, auc, accuracy, f1 = utils.val(model, testloader, criterion = nn.BCEWithLogitsLoss())
  File "/workspace/flower/utils.py", line 468, in val
    val_accuracy = accuracy_score(val_gt2, torch.round(pred2))
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/validation.py", line 72, in inner_f
    return f(**kwargs)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 187, in accuracy_score
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 83, in _check_targets
    type_pred = type_of_target(y_pred)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/multiclass.py", line 287, in type_of_target
    _assert_all_finite(y)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/validation.py", line 96, in _assert_all_finite
    raise ValueError(
ValueError: Input contains NaN, infinity or a value too large for dtype('float32').

pedropgusmao commented 2 years ago

Hi @sandracl72, thanks for this. I'm guessing you have already tried FedAvg and it works, right? Could you run a few tests for me, please? In your FedYogi, could you please print the outputs of the following:

sancarlim commented 2 years ago

Thank you for your answer, @pedropgusmao. Yes, I've run several experiments using FedAvg and never had this issue before.

Here you can access the output log: https://drive.google.com/file/d/1TWg7EgCHFbDMID7SaW78B1vNnfemWyjZ/view?usp=sharing

Thanks!

pedropgusmao commented 2 years ago

Hi @sandracl72, thanks for this. From the log file I don't see any NaN in the prints, only in: File "/workspace/flower/src/py/flwr/server/strategy/fedavg.py", line 178, in evaluate. Could you remove the previous prints and instead print the parameters right at the beginning of that function, please? https://github.com/adap/flower/blob/5fb4f9c1cd0070495049f45f382407e9d95166cd/src/py/flwr/server/strategy/fedavg.py#L173

sancarlim commented 2 years ago

Hi @pedropgusmao, I printed sum([np.isnan(w).sum() for w in weights]) before the evaluation to detect whether the parameters being evaluated contain NaNs, but apparently they don't. This is the output log:

INFO flower 2022-02-16 08:30:10,677 | app.py:109 | Flower server running (10 rounds)
SSL is disabled
INFO flower 2022-02-16 08:30:10,678 | server.py:118 | Initializing global parameters
INFO flower 2022-02-16 08:30:10,678 | server.py:301 | Using initial parameters provided by strategy
INFO flower 2022-02-16 08:30:10,678 | server.py:120 | Evaluating initial parameters
INFO flower 2022-02-16 08:30:10,678 | fedavg.py:175 | Evaluate model parameters using eval fcn
INFO flower 2022-02-16 08:30:10,797 | fedavg.py:182 | Number of weights with NaN value: 0
INFO flower 2022-02-16 08:30:30,064 | server.py:123 | initial parameters (loss, other metrics): 0.6916685566558676, {'accuracy': 0.5218716861081655, 'auc': 0.48442367381213036}
INFO flower 2022-02-16 08:30:30,064 | server.py:133 | FL starting
DEBUG flower 2022-02-16 08:32:52,575 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-16 08:33:44,737 | server.py:261 | fit_round received 2 results and 0 failures
INFO flower 2022-02-16 08:33:45,329 | fedavg.py:175 | Evaluate model parameters using eval fcn
INFO flower 2022-02-16 08:33:45,438 | fedavg.py:182 | Number of weights with NaN value: 0
INFO flower 2022-02-16 08:34:03,231 | server.py:148 | fit progress: (1, 0.3569583234853719, {'accuracy': 0.8939554612937434, 'auc': 0.7939201167759478}, 213.1666322145611)
INFO flower 2022-02-16 08:34:03,232 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-02-16 08:34:03,233 | server.py:252 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-02-16 08:34:52,459 | server.py:261 | fit_round received 2 results and 0 failures
INFO flower 2022-02-16 08:34:53,040 | fedavg.py:175 | Evaluate model parameters using eval fcn
INFO flower 2022-02-16 08:34:53,155 | fedavg.py:182 | Number of weights with NaN value: 0
Traceback (most recent call last):
  File "server_advanced.py", line 135, in <module>
    fl.server.start_server("0.0.0.0:8080", config={"num_rounds": rounds}, strategy=strategy)
  File "/workspace/flower/src/py/flwr/server/app.py", line 111, in start_server
    hist = _fl(
  File "/workspace/flower/src/py/flwr/server/app.py", line 148, in _fl
    hist = server.fit(num_rounds=config["num_rounds"])
  File "/workspace/flower/src/py/flwr/server/server.py", line 145, in fit
    res_cen = self.strategy.evaluate(parameters=self.parameters)
  File "/workspace/flower/src/py/flwr/server/strategy/fedavg.py", line 183, in evaluate
    eval_res = self.eval_fn(weights)
  File "server_advanced.py", line 51, in evaluate
    loss, auc, accuracy, f1 = utils.val(model, testloader, criterion = nn.BCEWithLogitsLoss())
  File "/workspace/flower/utils.py", line 468, in val
    val_accuracy = accuracy_score(val_gt2, torch.round(pred2))
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/validation.py", line 72, in inner_f
    return f(**kwargs)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 187, in accuracy_score
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/metrics/_classification.py", line 83, in _check_targets
    type_pred = type_of_target(y_pred)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/multiclass.py", line 287, in type_of_target
    _assert_all_finite(y)
  File "/opt/conda/lib/python3.8/site-packages/sklearn/utils/validation.py", line 96, in _assert_all_finite
    raise ValueError(
ValueError: Input contains NaN, infinity or a value too large for dtype('float32').
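
Judging by the fedavg.py:175 and fedavg.py:182 lines above, the instrumentation behind them presumably looked roughly like this (a reconstruction, not the exact patch):

from logging import INFO

import numpy as np
from flwr.common.logger import log

# Inside FedAvg.evaluate, just before eval_res = self.eval_fn(weights)
log(INFO, "Evaluate model parameters using eval fcn")
num_nan = sum(int(np.isnan(w).sum()) for w in weights)
log(INFO, "Number of weights with NaN value: %s", num_nan)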

So it seems the NaNs appear in the forward pass after the second aggregation.

Here you can see the histograms of the weights and gradients of the aggregated model: https://wandb.ai/eyeforai/dai-healthcare/reports/FedOpt-Debug--VmlldzoxNTY4ODQ1?accessToken=4kqdpvadvojfpd8my9iflqsw0zk7d9d0xativ7eu5ad69t17ovi7fm91v2co0oy4

pedropgusmao commented 2 years ago

Hi @sandracl72, I might be wrong, but could it be that evaluation itself is doing something weird? From the log I see INFO flower 2022-02-16 08:34:53,155 | fedavg.py:182 | Number of weights with NaN value: 0 right before evaluation actually starts. Could you also use the same isnan check inside fedyogi.py, please? https://github.com/adap/flower/blob/d80c8c2738b79badbc820f940efcd7fb4fff9503/src/py/flwr/server/strategy/fedyogi.py#L161 Thanks!
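
That check can be packaged as a small helper and called at the end of aggregate_fit; a sketch only, since the variable holding the aggregated weights in fedyogi.py may be named differently at this commit:

import numpy as np

def count_nans(weights) -> int:
    # Count NaN entries across all layer arrays
    return sum(int(np.isnan(w).sum()) for w in weights)

# e.g., at the end of FedYogi.aggregate_fit, before returning
# (the variable name below is illustrative):
#     log(INFO, "NaNs after FedYogi aggregation: %s", count_nans(aggregated_weights))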

pedropgusmao commented 2 years ago

@danieljanes this might be related to BatchNorm. Do you remember if we had any example with aggregation and batch norm?
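
One plausible mechanism, assuming the adaptive server optimizer also updates the BatchNorm running statistics: if a server-side update pushes a running_var below zero, the next forward pass normalizes by the square root of a negative number and yields NaNs. A toy demonstration of that failure mode (an assumption, not something verified in this thread):

import torch
import torch.nn as nn

# BatchNorm in eval mode normalizes with (x - running_mean) / sqrt(running_var + eps)
bn = nn.BatchNorm1d(3).eval()
bn.running_var.fill_(-0.1)  # a server-side adaptive update could overshoot below zero
x = torch.randn(4, 3)
print(bn(x))  # all NaN: square root of a negative variance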

sancarlim commented 2 years ago

@danieljanes the weights aren't NaN before or after the aggregation, but after the forward pass (in the second evaluation round) the outputs are NaN. We thought it might be related to BatchNorm, and I have now found that if I don't aggregate the "bn" weights, it works fine.

Changes:

def set_parameters(self, parameters: List[np.ndarray]) -> None:
    # Set model parameters from a list of NumPy ndarrays, skipping every
    # state_dict entry whose name contains 'bn' so BatchNorm state stays local
    keys = [k for k in self.model.state_dict().keys() if 'bn' not in k]
    params_dict = zip(keys, parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    self.model.load_state_dict(state_dict, strict=False)

def get_parameters(self) -> List[np.ndarray]:
    # Return all model parameters except the BatchNorm entries
    return [val.cpu().numpy() for name, val in self.model.state_dict().items() if 'bn' not in name]
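
A note on this workaround: the 'bn' substring test depends on the model's layer naming (EfficientNet implementations differ), so a more robust variant, sketched here rather than taken from the thread, collects the keys to exclude by module type instead:

import numpy as np
import torch.nn as nn

def batchnorm_keys(model: nn.Module) -> set:
    # Collect the state_dict keys (weight, bias, running stats) that belong
    # to any BatchNorm layer, regardless of how the layer is named
    keys = set()
    for module_name, module in model.named_modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            for k in module.state_dict().keys():
                keys.add(f"{module_name}.{k}" if module_name else k)
    return keys

Keeping BatchNorm state local like this is essentially what FedBN (Li et al., ICLR 2021) proposes for non-IID clients.
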
pedropgusmao commented 2 years ago

This is being investigated, but for now, I'd recommend:

danieljanes commented 2 years ago

Now that the 0.18 release is done, there's more room to investigate this. @sandracl72, would it be possible to have a repo that reproduces the error in a minimal way with Flower 0.18? If @pedropgusmao doesn't already have one, that is.

sancarlim commented 2 years ago

Sure, here you have a minimal example using Flower 0.18, with only 50 images: https://github.com/sandracl72/flower_fedopt_debug.git

You have to run it with the --nowandb arg; I've left wandb in, in case you want to log some metrics to your own project. With wandb.watch(model, log="all") you can track the weights and grads, which could be useful for debugging. You can run everything directly with run.sh.

Thanks !