NVIDIA / NVFlare

NVIDIA Federated Learning Application Runtime Environment
https://nvidia.github.io/NVFlare/
Apache License 2.0
648 stars, 181 forks

[BUG] FedOpt algorithm not working as expected in cifar10 example #1718

Closed: LeandroDiL closed this issue 1 year ago.

LeandroDiL commented 1 year ago

Describe the bug: The FedOpt algorithm does not work as expected in the cifar10 example when I change the model from the pre-existing ModerateCNN to another model such as MobileNetV2 or ResNet18. With FedOpt, the accuracy of the global model does not increase, or increases very slowly, while the other algorithms work fine even after changing the model.

To Reproduce

  1. Add the new model to 'cifar10_nets.py':

        import torch.nn as nn
        from torchvision import models

        class MyModel(nn.Module):
            def __init__(self):
                super(MyModel, self).__init__()
                # Pretrained MobileNetV2 with a new 10-class head for CIFAR-10
                model = models.mobilenet_v2(weights='DEFAULT')
                model.classifier = nn.Sequential(
                    nn.Dropout(0.4),
                    nn.Linear(1280, 10),
                )
                self.model = model

            def forward(self, x):
                return self.model(x)

  2. In 'cifar10_learner.py', import MyModel and use it in place of the existing model.
  3. Launch the example with ./run_simulator.sh cifar10_fedopt 0.1 8 8
  4. See the results in tensorboard with tensorboard --logdir=/tmp/nvflare/sim_cifar10 under the section 'val_acc_global_model'

Expected behavior: Based on the algorithm proposed in Reddi, Sashank, et al. "Adaptive federated optimization." arXiv preprint arXiv:2003.00295 (2020), I expect FedOpt with the SGD optimizer, lr = 1.0, and no scheduler to match the performance of FedAvg, and I expect better results when changing the optimizer and adding a scheduler.
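For reference, the server-side update behind this expectation can be written out for the SGD case (a sketch following Reddi et al.; the symbols are introduced here, not taken from NVFlare). Let $x_t$ be the global model at round $t$, $x_t^i$ client $i$'s model after local training, $p_i$ aggregation weights that sum to 1 (assumed, e.g. proportional to local sample counts), and $\eta$ the server learning rate. The averaged client change $\Delta_t$ is used as a negative pseudo-gradient:

$$\Delta_t = \sum_{i \in S_t} p_i \left( x_t^i - x_t \right), \qquad x_{t+1} = x_t - \eta \left( -\Delta_t \right) = x_t + \eta \, \Delta_t$$

$$\eta = 1: \quad x_{t+1} = x_t + \sum_{i \in S_t} p_i \, x_t^i - x_t = \sum_{i \in S_t} p_i \, x_t^i$$

The last line is exactly the FedAvg global model. With server momentum $\mu > 0$ the step also carries a contribution from earlier rounds' $\Delta$, so the equivalence holds only for $\mu = 0$.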

Screenshots: [Screenshot from 2023-04-28 10-09-15] Purple = FedAvg, Pink = FedOpt.


Thanks in advance!

YuanTingHsieh commented 1 year ago

Thank you for trying out and raising the issue!

It would be nice if you could share the figures from your other experiments so that others can benefit.

@holgerroth, can you help answer this question? Thanks.

LeandroDiL commented 1 year ago

This is a graph from TensorBoard that also includes the other experiments: [Screenshot from 2023-05-01 09-33-08] Pink = SCAFFOLD, Dark Grey = FedProx, Yellow = FedAvg, Purple = FedOpt. All the experiments were run with the same model and configuration: 20 rounds of FL and 4 local epochs for each of the 4 clients in each experiment. This FedOpt run is worse than the one I posted before because it used a different scheduler.

Thanks for the support!

holgerroth commented 1 year ago

Interesting. Just to confirm: are you using momentum on the server when using FedOpt (see here)? That could explain the different behavior compared to FedAvg.
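A minimal sketch of why server momentum breaks the equivalence with FedAvg, assuming plain Python lists as stand-ins for model tensors and PyTorch's SGD momentum convention (buffer = momentum * buffer + grad, with the buffer initialized to the first gradient and no dampening). The per-round averaged client deltas are fixed, hypothetical numbers; in a real run they depend on the current global model. `server_sgd` is a hypothetical helper, not an NVFlare API.

```python
def server_sgd(global_w, round_deltas, lr=1.0, momentum=0.0):
    """Apply a FedOpt-style server SGD step once per round.

    global_w: list of floats standing in for the global model.
    round_deltas: one averaged client delta (list of floats) per round.
    Momentum follows PyTorch's SGD convention (no dampening):
    buf = momentum * buf + grad, w = w - lr * buf, with buf = grad on the first step.
    """
    w = list(global_w)
    buf = None
    for delta in round_deltas:
        grad = [-d for d in delta]  # pseudo-gradient = negative averaged change
        if buf is None:
            buf = grad
        else:
            buf = [momentum * b + g for b, g in zip(buf, grad)]
        w = [wi - lr * bi for wi, bi in zip(w, buf)]
    return w


deltas = [[0.5, -0.2], [0.1, 0.1]]  # hypothetical per-round averaged deltas
plain = server_sgd([0.0, 0.0], deltas, momentum=0.0)  # sum of deltas, as FedAvg would give
heavy = server_sgd([0.0, 0.0], deltas, momentum=0.9)  # 2nd step re-applies 0.9 * 1st delta
```

With momentum=0.0 the result is approximately [0.6, -0.1]; with momentum=0.9 the second step becomes delta2 + 0.9 * delta1, giving approximately [1.05, -0.28].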

LeandroDiL commented 1 year ago

Yes, @holgerroth. I tried momentum with different values and also tried without it. The results changed, and some values did better than others, but they were still poor: at best I reached 0.5 accuracy, which is low compared with the other algorithms. I also noticed that with other models, such as SimpleCNN or other models built from scratch, it works fine; the problems appear when using a pretrained CNN. Hope this helps!

holgerroth commented 1 year ago

That's interesting. So, the problem only comes up when using the pretrained CNN? FedOpt seems to be more sensitive to this initialization.

Have you tried reducing the local aggregation_epochs?

LeandroDiL commented 1 year ago

Yes. Even with fewer local epochs per client, the behavior stays the same (obviously worse because of the fewer epochs). I also tried MobileNetV2 and ResNet18 with the same settings described before but without the parameter weights='DEFAULT', and the result is a flat 0.1 (in the val_acc_global_model section).

holgerroth commented 1 year ago

Hi @LeandroDiL, do you have any updates on this topic?

siomvas commented 1 year ago

Hi @holgerroth, I can confirm there is problematic behaviour when using anything other than the ModerateCNN and SimpleCNN. Global model validation metrics get stuck at 0.1 from the first round of aggregation.

holgerroth commented 1 year ago

I see. Can you specify which models and alpha setting you are using? Do the same models work fine with FedAvg and the same alpha setting on CIFAR-10?

siomvas commented 1 year ago

Yes, this is with alpha 0.6. FedAvg and FedProx work fine. It's about a dozen models, from a ResNet-20 to a couple of Transformers, and all of them break under FedOpt except ModerateCNN and, to some extent, SimpleCNN. SimpleCNN underperforms, but at least it converges. The rest get stuck in terms of global model validation accuracy, but locally they do learn (local validation accuracy increases between aggregations). All were trained from scratch.

holgerroth commented 1 year ago

Ok. Have you tried different learning rates and momentum for the fedopt optimizer, maybe even some optimizers other than SGD? lr 1 and momentum 0 should behave identically to FedAvg with SGD optimizer.

siomvas commented 1 year ago

I have not tried multiple settings, but with lr 1, no momentum, and no scheduler, which should be identical to FedAvg, accuracy gets stuck around 0.1. There are no errors in the log, and the same script runs fine with FedAvg/FedProx.

Edit:

What is interesting is that the training loss and the local model accuracies are correct (yellow line is FedAvg, orange is the FedOpt equivalent):

[Images: training loss and local model accuracy curves]

Meanwhile the global model malfunctions, but somehow that is not propagated to the next round's local models?

[Image: global model validation metrics]

YuanTingHsieh commented 1 year ago

@siomvas, thanks for the additional information. @holgerroth is going on vacation now; he will reply when he is back.

holgerroth commented 1 year ago

Hi @LeandroDiL, @siomvas, I'm looking into this issue now. Just to confirm, have you also changed the model configuration in config_fed_server.json when running these experiments? Please attach your job configurations and code if possible.

holgerroth commented 1 year ago

Okay, I was able to reproduce the behavior. It has to do with the batch norm layers of these more complex models. When updating the global model using SGD, the batch norm parameters are actually not included in self.model.named_parameters(), and therefore the optimizer doesn't update them.

The FedOpt paper also uses group norm instead of batch norm to avoid these kinds of issues: "We train a modified ResNet-18 on both datasets, where the batch normalization layers are replaced by group normalization layers (Wu & He, 2018). We use two groups in each group normalization layer. As shown by Hsieh et al. (2019), group normalization can lead to significant gains in accuracy over batch normalization in federated settings."
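To see why group normalization sidesteps the problem, it helps to write out what it computes (the standard definition from Wu & He, 2018; notation introduced here). For an activation $x_{n,c,h,w}$ whose $C$ channels are split into $G$ groups $S_g$ of $C/G$ channels each (the paper above uses $G = 2$), sample $n$ and group $g$ are normalized with their own statistics:

$$\mu_{n,g} = \frac{1}{|S_g|\,H W} \sum_{c \in S_g} \sum_{h,w} x_{n,c,h,w}, \qquad \sigma^2_{n,g} = \frac{1}{|S_g|\,H W} \sum_{c \in S_g} \sum_{h,w} \left( x_{n,c,h,w} - \mu_{n,g} \right)^2$$

$$y_{n,c,h,w} = \gamma_c \, \frac{x_{n,c,h,w} - \mu_{n,g}}{\sqrt{\sigma^2_{n,g} + \epsilon}} + \beta_c, \qquad c \in S_g$$

The statistics depend only on the current sample, so there are no running means or variances to store or aggregate; the only state is the learnable per-channel $\gamma_c$ and $\beta_c$, which a parameter-based server optimizer does update. Batch norm instead normalizes with batch statistics during training and with running estimates at evaluation time, and those estimates are extra state outside the trainable parameters.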

I provided a workaround for this issue by updating the batch norm parameters using FedAvg and only updating the trainable parameters using the FedOpt optimizer for the global model: https://github.com/NVIDIA/NVFlare/pull/1851
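The split described above can be sketched as follows. This is a hypothetical, framework-free illustration, not the code from the PR: dicts of Python lists stand in for a PyTorch state_dict, `trainable` stands in for the names returned by model.named_parameters(), and the server optimizer is assumed to be plain SGD without momentum.

```python
def update_global(global_state, avg_state, trainable, server_lr=1.0):
    """Split update: server SGD (no momentum) for trainable parameters,
    plain FedAvg for everything else (e.g. BN running statistics).

    global_state / avg_state: dict name -> list of floats (current global
    model and weighted average of the client models).
    trainable: names that model.named_parameters() would return.
    """
    new_state = {}
    for name, current in global_state.items():
        avg = avg_state[name]
        if name in trainable:
            # Pseudo-gradient is the negative averaged client change.
            pseudo_grad = [-(a - w) for a, w in zip(avg, current)]
            new_state[name] = [w - server_lr * g for w, g in zip(current, pseudo_grad)]
        else:
            # Buffers are not touched by the optimizer, so copy the FedAvg result.
            new_state[name] = list(avg)
    return new_state


global_state = {"bn1.weight": [1.0], "bn1.running_mean": [0.0]}
client_avg = {"bn1.weight": [2.0], "bn1.running_mean": [0.4]}
print(update_global(global_state, client_avg, {"bn1.weight"}, server_lr=0.5))
# {'bn1.weight': [1.5], 'bn1.running_mean': [0.4]}
```

With server_lr=1.0 the trainable branch also reduces to copying the average, which is why this split can match FedAvg at lr=1 and momentum=0.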

siomvas commented 1 year ago

I pinpointed this issue while trying to use SCAFFOLD, which (conveniently) actually breaks with an error, so I could see where the problem was; I will open a new bug report for that. This is what I found:

The issue is not with batch norm itself, but with the running stats:

    >>> [k for k,v in mobile.named_parameters()][:5]
    ['conv1.weight', 'bn1.weight', 'bn1.bias', 'layers.0.conv1.weight', 'layers.0.bn1.weight']
    >>> [k for k in mobile.state_dict()][:7]
    ['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layers.0.conv1.weight']

The weight and bias of BN are getting averaged, but the running stats are not, which produces a nonsensical layer, since the two are learned in tandem on the client side. (Note that num_batches_tracked is not used in any calculation in the default setting, where BN uses momentum instead.)
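A minimal, framework-free sketch of this failure mode, assuming lists of floats in place of tensors and server SGD with lr = 1 (so a parameter step equals the average). `params_only_update` and `bn_eval` are hypothetical helpers for illustration; the server's BN running stats start at their fresh-initialization values (mean 0, variance 1) while the clients' stats have moved.

```python
import math


def params_only_update(global_state, avg_state, trainable):
    """Move only trainable parameters to the client average (server SGD, lr=1);
    buffers such as running_mean / running_var keep their old global values."""
    return {
        name: list(avg_state[name]) if name in trainable else list(value)
        for name, value in global_state.items()
    }


def bn_eval(x, mean, var, weight, bias, eps=1e-5):
    """BatchNorm in eval mode: normalize x with the stored running statistics."""
    return weight * (x - mean) / math.sqrt(var + eps) + bias


# Fresh BN layer on the server vs. what the clients learned (hypothetical numbers).
global_state = {"bn1.weight": [1.0], "bn1.bias": [0.0],
                "bn1.running_mean": [0.0], "bn1.running_var": [1.0]}
client_avg = {"bn1.weight": [1.0], "bn1.bias": [0.0],
              "bn1.running_mean": [2.5], "bn1.running_var": [4.0]}
trainable = {"bn1.weight", "bn1.bias"}  # what named_parameters() would list

new_global = params_only_update(global_state, client_avg, trainable)
stale = bn_eval(2.5, new_global["bn1.running_mean"][0], new_global["bn1.running_var"][0], 1.0, 0.0)
fresh = bn_eval(2.5, client_avg["bn1.running_mean"][0], client_avg["bn1.running_var"][0], 1.0, 0.0)
```

An activation that the clients' statistics map to 0 comes out near 2.5 in the global model; repeated across every BN layer, this kind of distortion is one plausible reading of the global accuracy stuck near 0.1 reported above.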

This also applies to other architectural elements: Swin has a relative_position_index buffer, an integer tensor, that caused the same issue, and Swin does not converge with FedOpt even though it uses LayerNorm.

Correct me if I'm wrong, but it seems that with #1851 the weights and biases are still "FedOpted" while the running stats are averaged, so this should not be considered expected behavior, as there will be a mismatch.

A quick test of FedAdam with #1851, using the hyperparameters proposed in the FedOpt paper (client lr=0.03, server lr=0.01), shows that it converges, but it is unclear how the mismatch in the affected layers affects model performance.

To investigate further, I tried combining FedOpt with FedBN, implemented via the task filter mechanism (adding Exclude_vars for the BN parameters). However, there currently seems to be another bug: FedOpt does not respect task filters. Here is a fix that could be added to #1851:

        for name, param in self.model.named_parameters():
            param.grad = torch.tensor(-1.0 * model_diff[name]).to(self.device)
            updated_params.append(name)

should be

        for name, param in self.model.named_parameters():
            # Skip variables removed by task filters (e.g. Exclude_vars for FedBN)
            if name in model_diff:
                param.grad = torch.tensor(-1.0 * model_diff[name]).to(self.device)
                updated_params.append(name)
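The interaction between a FedBN-style filter and the fixed loop can be sketched without PyTorch. Both helpers are hypothetical: `drop_matching_vars` only mimics what a name-based exclusion filter does, the regex `bn\d*\.` assumes BN layers are named like `bn1` (as in the model listed above), and floats stand in for tensors. With the original loop, `model_diff[name]` would raise a KeyError for the filtered-out BN parameters; the fixed loop simply skips them.

```python
import re


def drop_matching_vars(model_diff, pattern):
    """Drop every variable whose name matches `pattern` (e.g. BN layers for FedBN)."""
    regex = re.compile(pattern)
    return {k: v for k, v in model_diff.items() if not regex.search(k)}


def set_pseudo_grads(param_names, model_diff):
    """Mirror of the fixed loop: only parameters present in the (possibly
    filtered) diff receive a pseudo-gradient; the rest are skipped."""
    grads, updated = {}, []
    for name in param_names:
        if name in model_diff:
            grads[name] = -1.0 * model_diff[name]
            updated.append(name)
    return grads, updated


params = ["conv1.weight", "bn1.weight", "bn1.bias"]  # as named_parameters() would list
diff = {"conv1.weight": 0.2, "bn1.weight": 0.1, "bn1.bias": -0.1, "bn1.running_mean": 0.3}
filtered = drop_matching_vars(diff, r"bn\d*\.")      # {"conv1.weight": 0.2}
grads, updated = set_pseudo_grads(params, filtered)  # only conv1.weight gets a gradient
```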

I believe this should remain open, not as a bug but as a documented issue.

holgerroth commented 1 year ago

Hi @siomvas, thanks for your test and the additional info. Yes, the desired behavior of batch norm layers with FedOpt is somewhat unclear. That's why many avoid batch norm in FL settings, as the FedOpt paper does, and why I called #1851 a "workaround": it uses FedOpt to optimize the global trainable parameters but FedAvg to update any other layers, such as batch norm statistics. It remains to be seen whether this approach also works with Swin architectures.

I know it's inconvenient as most of the pretrained torchvision models use batch norm but I would recommend looking into models that use group norm instead.

Thanks for pointing out the issue with using filters, I added that fix to the PR. I also updated the doc string to document the behavior when using batch norm. It's acceptable to me as we can match the performance of FedAvg using this workaround and the equivalent SGD settings (lr=1, momentum=0).

BitCalSaul commented 10 months ago

Hi @siomvas, I ran into a similar situation. I used Adam as the optimizer and Swin as the model for CIFAR-10. However, after the first epoch, the loss, acc1, and acc5 never improved. When I switched to a much smaller model, ResNet56 from timm, the results were very good, as expected. It's really strange that Swin didn't work in this case. The loss and evaluation curves are shown below.

[Images: loss and evaluation curves]