KarhouTam / FL-bench

Benchmark of federated learning. Dedicated to the community. 🤗
GNU General Public License v3.0
503 stars, 82 forks

Seeking Help with Reproducing Results on CIFAR10 through FedAvgCNN #87

Closed · Arnou1 closed this issue 2 months ago

Arnou1 commented 2 months ago

As in a similar question asked earlier, I am struggling to reproduce the CIFAR10 results reported in the original FedAvg paper.

I have tried many hyperparameter combinations, but the best result so far is around 60% test accuracy (FedAvg on IID data across 100 clients), far below what the paper reports. I trained for only 300 rounds, but judging from the trend in the screenshot below, I doubt training longer would bring any significant improvement.

I have attached my configuration below. Any suggestions for improving the test results would be greatly appreciated. Thanks.

==================== FedAvg ====================
Experiment Arguments:
{
  "mode": "serial",
  "common": {
    "dataset": "cifar10",
    "seed": 108,
    "model": "avgcnn",
    "join_ratio": 0.1,
    "global_epoch": 300,
    "local_epoch": 5,
    "finetune_epoch": 0,
    "batch_size": 32,
    "test_interval": 100,
    "straggler_ratio": 0,
    "straggler_min_local_epoch": 0,
    "external_model_params_file": null,
    "buffers": "local",
    "optimizer": {
      "lr": 0.001,
      "weight_decay": 0,
      "betas": [
        0.9,
        0.999
      ],
      "amsgrad": false,
      "name": "adam"
    },

[Screenshot of the training curve]
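To make the configuration concrete: with 100 clients and `join_ratio` 0.1, each round samples 10 clients. Each of them runs `local_epoch` = 5 epochs starting from the current global model, and the server then averages the returned weights in proportion to each client's dataset size, as in the FedAvg paper. Below is a minimal, framework-free sketch of one such round. It assumes parameters are flat lists of floats, and `fedavg_round` and the `local_train` callback are hypothetical names, not FL-bench's API.

```python
import random
from typing import Callable, List, Sequence

Params = List[float]


def fedavg_round(
    global_params: Params,
    client_sizes: Sequence[int],
    join_ratio: float,
    local_train: Callable[[int, Params], Params],
    rng: random.Random,
) -> Params:
    """Run one FedAvg round and return the new global parameters."""
    num_clients = len(client_sizes)
    # Sample a fraction of the clients (at least one) without replacement.
    num_selected = max(1, round(join_ratio * num_clients))
    selected = rng.sample(range(num_clients), num_selected)

    # Each selected client trains locally, starting from a copy of the global model.
    results = [(client_sizes[c], local_train(c, list(global_params))) for c in selected]

    # Average the returned parameters, weighting each client by its sample count.
    total = sum(size for size, _ in results)
    weighted_sum = [0.0] * len(global_params)
    for size, params in results:
        for j, value in enumerate(params):
            weighted_sum[j] += size * value
    return [s / total for s in weighted_sum]


# Usage with a toy "local training" step that adds 1.0 to every parameter
# (hypothetical stand-in for 5 local epochs of Adam).
sizes = [500] * 100  # e.g. 50,000 CIFAR10 training images split evenly
new_params = fedavg_round(
    [0.0, 0.0], sizes, 0.1, lambda c, p: [x + 1.0 for x in p], random.Random(108)
)
print(new_params)  # [1.0, 1.0]
```

Because every client in the usage example has the same size, the weighted average reduces to a plain mean. The weighting only matters when client datasets differ in size.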

KarhouTam commented 2 months ago

Actually, this confuses me too. I think FedAvgCNN's poor performance comes from its lack of normalization; normalization really helps model training. However, the FedAvgCNN described in the original paper doesn't seem to use any normalization.

So if exact reproduction is not your goal, you could try adding normalization layers to the model or simply switch to a different model.
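One way to act on this suggestion is sketched below in PyTorch. It shows a FedAvgCNN-style network (two 5x5 convolutions with max pooling, then two fully connected layers) with an optional `BatchNorm2d` after each convolution. The channel counts (32, 64) and the 512-unit hidden layer are assumptions for illustration, not a copy of FL-bench's `avgcnn`, and `NormalizedAvgCNN` is a hypothetical name.

```python
import torch
from torch import nn


class NormalizedAvgCNN(nn.Module):
    """Hypothetical FedAvgCNN-style model with optional BatchNorm after each conv."""

    def __init__(self, num_classes: int = 10, in_channels: int = 3, use_norm: bool = True) -> None:
        super().__init__()

        def norm(channels: int) -> nn.Module:
            # nn.Identity keeps the original, normalization-free architecture.
            return nn.BatchNorm2d(channels) if use_norm else nn.Identity()

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=5),  # 32x32 -> 28x28
            norm(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 28x28 -> 14x14
            nn.Conv2d(32, 64, kernel_size=5),  # 14x14 -> 10x10
            norm(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 10x10 -> 5x5
            nn.Flatten(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 5 * 5, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


model = NormalizedAvgCNN()
logits = model(torch.randn(8, 3, 32, 32))
print(logits.shape)  # torch.Size([8, 10])
```

Setting `use_norm=False` gives the normalization-free variant, so the two can be compared under the same federated setup. `nn.GroupNorm` is a drop-in alternative that keeps no running statistics, which avoids having to decide how BatchNorm's running mean and variance buffers are shared across clients.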

Arnou1 commented 2 months ago

I just added some data augmentation to the CIFAR10 dataset, and it improved the test accuracy by around 15%. Next, I will experiment with normalization layers. Thanks!
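The comment does not say which augmentation was used. A common choice for CIFAR10 is random cropping after 4-pixel zero padding, plus random horizontal flipping (in torchvision, `transforms.RandomCrop(32, padding=4)` and `transforms.RandomHorizontalFlip()`). The sketch below implements that pair directly on a batch tensor with plain PyTorch ops. It is only an assumption about what "some data augmentation" might look like, and `augment_batch` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def augment_batch(images: torch.Tensor, padding: int = 4, flip_prob: float = 0.5) -> torch.Tensor:
    """Random crop (after zero padding) and random horizontal flip, per image.

    `images` has shape (N, C, H, W); the output has the same shape.
    """
    n, _, h, w = images.shape
    # Zero-pad the last two dims: (left, right, top, bottom).
    padded = F.pad(images, (padding, padding, padding, padding))
    out = torch.empty_like(images)
    for i in range(n):
        # Pick a random H x W window inside the padded image.
        top = int(torch.randint(0, 2 * padding + 1, (1,)))
        left = int(torch.randint(0, 2 * padding + 1, (1,)))
        crop = padded[i, :, top:top + h, left:left + w]
        # Mirror left-right (width axis) with probability flip_prob.
        if torch.rand(1).item() < flip_prob:
            crop = torch.flip(crop, dims=[-1])
        out[i] = crop
    return out


batch = torch.rand(16, 3, 32, 32)
print(augment_batch(batch).shape)  # torch.Size([16, 3, 32, 32])
```

Apply it only to training batches; the test set should stay unaugmented so that accuracy numbers remain comparable.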