KarhouTam / FL-bench

Benchmark of federated learning. Dedicated to the community. 🤗
GNU General Public License v3.0

Datasets are not normalized correctly #88

Closed wittenator closed 1 week ago

wittenator commented 1 month ago

Describe the bug: While playing around with the benchmark, I saw that at least some datasets are not correctly normalized. Is this intended behaviour?

To Reproduce

python generate_data.py -a 100.0 -cn 10 -d cifar10
python main.py method=fedavg

Printing the range of the image pixels after the dataloader reveals that the range is [-3, 1060]. I think a transform that rescales the pixels from [0, 255] to [0, 1] before transforms.Normalize is currently missing. But since the datasets are built so that the dataset data is wrapped in a tensor directly, I wanted to raise an issue here instead of hackily patching it on my side.
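To make the reported range plausible, here is a minimal sketch of the arithmetic that a Normalize step performs, applied once to raw [0, 255] values and once to values rescaled to [0, 1]. The `MEAN` and `STD` values are illustrative assumptions chosen for round numbers, not the statistics FL-bench uses, and `normalize` is a hypothetical helper, not a library call.

```python
# Illustrative per-channel statistics (assumed, not the values FL-bench uses).
MEAN, STD = 0.5, 0.25


def normalize(x, mean=MEAN, std=STD):
    """Same arithmetic as a Normalize transform for one value: (x - mean) / std."""
    return (x - mean) / std


raw_pixels = [0.0, 128.0, 255.0]                 # uint8 range cast to float, no rescaling
scaled_pixels = [p / 255.0 for p in raw_pixels]  # rescaled to [0, 1] first

unscaled_range = (min(map(normalize, raw_pixels)), max(map(normalize, raw_pixels)))
scaled_range = (min(map(normalize, scaled_pixels)), max(map(normalize, scaled_pixels)))

print(unscaled_range)  # (-2.0, 1018.0)
print(scaled_range)    # (-2.0, 2.0)
```

With statistics computed on [0, 1] data, skipping the rescale leaves the lower bound near a small negative number and pushes the upper bound to roughly a thousand, which has the same shape as the [-3, 1060] observation.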

Expected behavior: After the transforms, the pixel values should be roughly standardized (about zero mean and unit variance per channel). This probably affects more than just CIFAR-10 and CIFAR-100.

KarhouTam commented 1 month ago

Thanks for pointing it out. I will check my code later. Also, if you can show how you would fix it, that would be great.

wittenator commented 1 month ago

What a clean fix would look like is a good question. Currently the datasets have lines like this:

train_part = torchvision.datasets.CIFAR100(root, True, download=True)
test_part = torchvision.datasets.CIFAR100(root, False, download=True)
train_data = torch.Tensor(train_part.data).permute([0, -1, 1, 2]).float()
test_data = torch.Tensor(test_part.data).permute([0, -1, 1, 2]).float()

Was there a reason to extract the data from the dataset directly and not use the dataset itself?

KarhouTam commented 1 month ago

Was there a reason to extract the data from the dataset directly and not use the dataset itself?

torchvision doesn't support loading the full dataset (train set + test set) into one variable, so I need to load and concatenate them.

As for the unreasonable pixel data range ([-3, 1060] as you mentioned), it is probably because I need to rescale the images to [0, 1] before applying Normalize().

wittenator commented 1 month ago

Would it be an option to wrap both of them in a torch.utils.data.ConcatDataset instead of extracting them by hand? This way we don't have to assume anything about the pixel values and can defer the decision completely to the transforms at hand.
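The idea can be illustrated without downloading anything. The sketch below is a pure-Python stand-in for the concatenation idea, not torch.utils.data.ConcatDataset itself: `ConcatView`, its `transform` argument, and the toy data are all hypothetical. The real class takes no transform, so there the transform would be given to each wrapped dataset instead.

```python
from bisect import bisect_right


class ConcatView:
    """Pure-Python stand-in for the idea of concatenating datasets lazily."""

    def __init__(self, parts, transform=None):
        self.parts = list(parts)
        self.transform = transform
        # Cumulative sizes map a global index to (part, local index).
        self.cumulative = []
        total = 0
        for part in self.parts:
            total += len(part)
            self.cumulative.append(total)

    def __len__(self):
        return self.cumulative[-1] if self.cumulative else 0

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        part_idx = bisect_right(self.cumulative, index)
        start = self.cumulative[part_idx - 1] if part_idx > 0 else 0
        sample, target = self.parts[part_idx][index - start]
        # The raw sample is untouched until the transform decides how to scale it.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target


# Toy "train" and "test" parts holding one raw pixel value and a label each.
train_part = [(0, 3), (128, 5)]
test_part = [(255, 7)]
full = ConcatView([train_part, test_part], transform=lambda px: px / 255.0)

print(len(full))  # 3
print(full[2])    # (1.0, 7)
```

The point of the design is that no code outside the transform has to know whether the stored values are uint8-like or already floats.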

KarhouTam commented 1 month ago

Seems reasonable. But some datasets aren't from torchvision, like DomainNet, TinyImageNet, ... I prefer a general solution.

So far, my idea is to use super().__init__() and perform the rescaling in __init__() of BaseDataset.

Pass all variables like data, targets, classes to super().__init__() as arguments and move super().__init__(data=..., targets=..., ...) to the end of each dataset's __init__(). Wdyt?
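A minimal sketch of that structure, under stated assumptions: `ToyCIFAR`, the `max_pixel` parameter, and the list-based data are hypothetical, and dividing by a fixed maximum is only one possible rescaling rule. This is not the FL-bench implementation.

```python
class BaseDataset:
    """Hypothetical base class: rescaling happens once, here."""

    def __init__(self, data, targets, classes, max_pixel=255.0):
        # Rescale raw pixel values into [0, 1] before any Normalize step.
        self.data = [[px / max_pixel for px in image] for image in data]
        self.targets = list(targets)
        self.classes = list(classes)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]


class ToyCIFAR(BaseDataset):
    """Hypothetical subclass standing in for a concrete dataset."""

    def __init__(self):
        # Load and concatenate the train and test parts first ...
        train_images, train_targets = [[0, 255]], [0]
        test_images, test_targets = [[51, 204]], [1]
        data = train_images + test_images
        targets = train_targets + test_targets
        # ... then hand everything to the base class as the last step.
        super().__init__(data=data, targets=targets, classes=["cat", "dog"])


dataset = ToyCIFAR()
print(dataset[1])  # ([0.2, 0.8], 1)
```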

wittenator commented 1 month ago

That's definitely one way, although for DomainNet you would need a different solution anyway, since the images are fetched on the fly. Btw, I did a dirty hack where I normalized the pixel values to see how that impacts performance, and now I can't get even FedAvg, or pretty much any of the other methods, to converge on non-IID datasets (alpha=0.1). Do you have a set of parameters that worked well for your benchmarks? Now that the data range of the input is different I assume that things like learning rate and such have to be adapted accordingly, but even with an SGD with lr=0.0001 things looked wild.

[image]

KarhouTam commented 1 month ago

FYI, first, FedAvg cannot handle most datasets with Dir(0.1) except easy datasets like MNIST, FashionMNIST, EMNIST... I have to admit that CIFAR-10 is pretty difficult in federated learning scenarios, especially with strong non-IID settings, e.g., Dir(0.1).

Do you have a set of parameters that worked well for your benchmarks?

No, there is no set of standard hyperparameters. But empirically, using a more complicated model and prolonging its training helps.

Now that the data range of the input is different I assume that things like learning rate and such have to be adapted accordingly, but even with an SGD with lr=0.0001 things looked wild.

First, this learning curve is pretty common with FedAvg in strong non-IID settings. In federated learning, momentum used in parameter aggregation can help training stability more than a small learning rate (in my opinion).
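One common reading of "momentum used in parameter aggregation" is server-side momentum in the style of FedAvgM. The sketch below assumes that reading. The function name, the `beta` and `server_lr` parameters, and the plain-list weights are hypothetical, and this is not how FL-bench implements it.

```python
def aggregate_with_server_momentum(global_w, client_ws, velocity, beta=0.9, server_lr=1.0):
    """One aggregation round: average the clients, then apply momentum on the server."""
    n = len(client_ws)
    # Plain FedAvg average (equal client weights assumed for simplicity).
    avg = [sum(ws[i] for ws in client_ws) / n for i in range(len(global_w))]
    # Treat the move from the old global model to the average as a pseudo-gradient.
    delta = [a - g for a, g in zip(avg, global_w)]
    new_velocity = [beta * v + d for v, d in zip(velocity, delta)]
    new_global = [g + server_lr * v for g, v in zip(global_w, new_velocity)]
    return new_global, new_velocity


global_w, velocity = [0.0, 0.0], [0.0, 0.0]
clients = [[1.0, 2.0], [3.0, 4.0]]
global_w, velocity = aggregate_with_server_momentum(global_w, clients, velocity, beta=0.5)
print(global_w, velocity)  # [2.0, 3.0] [2.0, 3.0]
```

With `beta=0` this reduces to plain averaging. A nonzero `beta` carries part of the previous round's update forward, which smooths round-to-round jumps caused by differing client subsets.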

wittenator commented 1 month ago

Hmm, but rescaling by searching for the minimum and maximum in the specific dataset may also lead to errors, right? Imagine a dataset that does not have a pure white or black pixel. Rescaling according to the dataset min/max then rescales the dataset statistics as well, so the Normalize transform itself is no longer accurate. Maybe it makes sense to make use of torchvision.transforms.functional.to_tensor(pic) (which handles all that since it looks for the datatype of the data) ?
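A small worked example of that concern, using made-up pixel values for a dataset that contains neither pure black nor pure white. The helper names are hypothetical.

```python
def minmax_rescale(pixels):
    """Rescale using the minimum and maximum found in this particular dataset."""
    lo, hi = min(pixels), max(pixels)
    return [(p - lo) / (hi - lo) for p in pixels]


def fixed_rescale(pixels, max_value=255.0):
    """Rescale by the fixed maximum of the uint8 range."""
    return [p / max_value for p in pixels]


def mean(values):
    return sum(values) / len(values)


pixels = [51.0, 102.0, 204.0]  # no 0 and no 255 in this toy dataset

by_minmax = minmax_rescale(pixels)
by_fixed = fixed_rescale(pixels)

print(by_fixed)                         # [0.2, 0.4, 0.8]
print(by_minmax[0], by_minmax[-1])      # 0.0 1.0
print(mean(by_minmax), mean(by_fixed))  # the two means differ
```

Min/max rescaling stretches the data to fill [0, 1], so its mean and spread no longer match statistics that were computed on pixels divided by 255.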

KarhouTam commented 1 month ago

Yeah, you're right. Maybe just simply dividing the whole image by 255.0 is the answer.

KarhouTam commented 1 month ago

Maybe it makes sense to make use of torchvision.transforms.functional.to_tensor(pic) (which handles all that since it looks for the datatype of the data) ?

Because some datasets are already loaded as torch.Tensor (like MNIST), transforms.ToTensor() (the same as transforms.functional.to_tensor()) may not work as expected.

wittenator commented 1 month ago

What is weird is that after the dataset fix, plain FedAvg performs really badly even on IID datasets. Usually I see much smoother convergence curves for FedAvg on IID datasets (even for CIFAR-10). Over the last week I worked through pretty much the whole library, and I can't come up with a reason for this. Do you have some intuition for this?

[image]

Here are my parameters:

{                                                                                                                                               
  "mode": "parallel",                                                                                                                           
  "common": {                                                                                                                                   
    "dataset": "cifar10",                                                                                                                       
    "seed": 42,                                                                                                                                 
    "model": "res18",                                                                                                                           
    "join_ratio": 1.0,                                                                                                                          
    "global_epoch": 100,                                                                                                                        
    "local_epoch": 5,                                                                                                                           
    "finetune_epoch": 0,                                                                                                                        
    "batch_size": 32,                                                                                                                           
    "test_interval": 10,                                                                                                                        
    "straggler_ratio": 0,                                                                                                                       
    "straggler_min_local_epoch": 0,                                                                                                             
    "external_model_params_file": null,                                                                                                         
    "buffers": "global",                                                                                                                        
    "optimizer": {                                                                                                                              
      "lr": 0.01,                                                                                                                               
      "dampening": 0,                                                                                                                           
      "weight_decay": 0.0001,                                                                                                                   
      "momentum": 0.9,                                                                                                                          
      "nesterov": false,                                                                                                                        
      "name": "sgd"                                                                                                                             
    },                                                                                                                                          
    "eval_test": true,                                                                                                                          
    "eval_val": false,                                                                                                                          
    "eval_train": false,                                                                                                                        
    "verbose_gap": 10,                                                                                                                          
    "visible": "tensorboard",                                                                                                                   
    "use_cuda": true,                                                                                                                           
    "save_log": true,                                                                                                                           
    "save_model": false,                                                                                                                        
    "save_fig": true,                                                                                                                           
    "save_metrics": true,                                                                                                                       
    "delete_useless_run": true                                                                                                                  
  },                                                                                                                                            
  "parallel": {                                                                                                                                 
    "ray_cluster_addr": null,                                                                                                                   
    "num_gpus": 1.0,                                                                                                                            
    "num_cpus": 20.0,                                                                                                                           
    "num_workers": 5                                                                                                                            
  },                                                                                                                                            
  "dataset": {                                                                                                                                  
    "client_num": 10,                                                                                                                           
    "test_ratio": 0.25,                                                                                                                         
    "val_ratio": 0.0,                                                                                                                           
    "seed": 42,                                                                                                                                 
    "split": "sample",                                                                                                                          
    "IID_ratio": 0.0,                                                                                                                           
    "monitor_window_name_suffix": "cifar10-10clients-0%IID-Dir(100.0)-seed42",                                                                  
    "alpha": 100.0,                                                                                                                             
    "least_samples": 40                                                                                                                         
  }                                                                                                                                             
}

KarhouTam commented 1 month ago

Does that mean that before I made these changes, FedAvg at least worked as expected in IID settings 😂?

Maybe I need to revert these changes and do more evaluations first...

I will evaluate my code later.

wittenator commented 1 month ago

I did some more digging and found that if you use the iid=1.0 parameter while generating the data, everything works flawlessly, but if you use the alpha=100.0 setting (which results in almost the same distribution, as seen in the images), things suddenly fall apart. I added some central server evaluation. The upper line is with the iid parameter, the bottom two lines are with alpha=100.0.

[image]

[images: class_distribution, class_distribution_dirichlet]
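To see why alpha=100.0 yields almost the same class distribution as an IID split, one can draw Dirichlet proportions directly. The sketch below uses only the standard library. `sample_dirichlet` and `spread` are hypothetical helpers, and this is not the partitioning code in FL-bench.

```python
import random


def sample_dirichlet(alpha, k, rng):
    """Draw one sample from a symmetric Dirichlet(alpha) over k categories."""
    # Normalized independent Gamma(alpha, 1) draws follow a Dirichlet distribution.
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]


def spread(proportions):
    """Gap between the largest and smallest share."""
    return max(proportions) - min(proportions)


rng = random.Random(42)
near_iid = sample_dirichlet(100.0, 10, rng)  # close to 0.1 for every category
skewed = sample_dirichlet(0.1, 10, rng)      # a few categories dominate

print([round(p, 3) for p in near_iid])
print([round(p, 3) for p in skewed])
print(spread(near_iid) < spread(skewed))
```

For large alpha the shares concentrate around 1/k, so any remaining gap between an alpha=100.0 run and an iid=1.0 run is unlikely to come from label skew alone.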

KarhouTam commented 1 month ago

Well... That's a bit weird. Have you tested other algorithms with these settings?

KarhouTam commented 1 month ago

I've evaluated FedAvg with both -a 100 and --iid 1, using 100 clients.

Setting

Two runs share the same experiment settings.

{                                                                                                                                   
  "mode": "parallel",                                                                                                               
  "common": {                                                                                                                       
    "dataset": "cifar10",                                                                                                           
    "seed": 42,                                                                                                                     
    "model": "lenet5",                                                                                                              
    "join_ratio": 0.1,                                                                                                              
    "global_epoch": 100,                                                                                                            
    "local_epoch": 5,                                                                                                               
    "finetune_epoch": 0,                                                                                                            
    "batch_size": 32,                                                                                                               
    "test_interval": 100,                                                                                                           
    "straggler_ratio": 0,                                                                                                           
    "straggler_min_local_epoch": 0,                                                                                                 
    "external_model_params_file": null,                                                                                             
    "buffers": "local",                                                                                                             
    "optimizer": {                                                                                                                  
      "lr": 0.01,                                                                                                                   
      "dampening": 0,                                                                                                               
      "weight_decay": 0,                                                                                                            
      "momentum": 0,                                                                                                                
      "nesterov": false,                                                                                                            
      "name": "sgd"                                                                                                                 
    },                                                                                                                              
    "eval_test": true,                                                                                                              
    "eval_val": false,                                                                                                              
    "eval_train": false,                                                                                                            
    "verbose_gap": 25,                                                                                                              
    "visible": "tensorboard",                                                                                                       
    "use_cuda": true,                                                                                                               
    "save_log": true,                                                                                                               
    "save_model": false,                                                                                                            
    "save_fig": true,                                                                                                               
    "save_metrics": true,                                                                                                           
    "delete_useless_run": true                                                                                                      
  },                                                                                                                                
  "parallel": {                                                                                                                     
    "ray_cluster_addr": null,                                                                                                       
    "num_gpus": 1.0,                                                                                                                
    "num_cpus": 16.0,                                                                                                               
    "num_workers": 2                                                                                                                
  },                                                                                                                                
  "dataset": {                                                                                                                      
    "client_num": 100,                                                                                                              
    "test_ratio": 0.25,                                                                                                             
    "val_ratio": 0.0,                                                                                                               
    "seed": 42,                                                                                                                     
    "split": "sample",                                                                                                              
    "IID_ratio": 0.0,                                                                                                               
    "monitor_window_name_suffix": "cifar10-100clients-0%IID-Dir(100.0)-seed42",                                                     
    "alpha": 100.0,                                                                                                                 
    "least_samples": 40                                                                                                             
  }                                                                                                                                 
}  

Results

image

We can see that the normalization did not sabotage the training.

wittenator commented 1 month ago

Hmmm, that is curious, could you check that again with res18 instead of lenet5 and go for global buffer instead of local? This is closer to a server FL setup instead of a personalization task then. I will try the same config as yours on my setup as well. Thanks for testing that already!

KarhouTam commented 1 month ago

Hmmm, that is curious, could you check that again with res18 instead of lenet5 and go for global buffer instead of local? This is closer to a server FL setup instead of a personalization task then.

Arguments

{
  "mode": "parallel",
  "common": {
    "dataset": "cifar10",
    "seed": 42,
    "model": "res18",
    "join_ratio": 0.1,
    "global_epoch": 100,
    "local_epoch": 5,
    "finetune_epoch": 0,
    "batch_size": 32,
    "test_interval": 100,
    "straggler_ratio": 0,
    "straggler_min_local_epoch": 0,
    "external_model_params_file": null,
    "buffers": "global",
    "optimizer": {
      "lr": 0.01,
      "dampening": 0,
      "weight_decay": 0,
      "momentum": 0,
      "nesterov": false,
      "name": "sgd"
    },
    "eval_test": true,
    "eval_val": false,
    "eval_train": false,
    "verbose_gap": 25,
    "visible": "tensorboard",
    "use_cuda": true,
    "save_log": true,
    "save_model": false,
    "save_fig": true,
    "save_metrics": true,
    "delete_useless_run": true
  },
  "parallel": {
    "ray_cluster_addr": null,
    "num_gpus": 1.0,
    "num_cpus": 16.0,
    "num_workers": 2
  },
  "dataset": {
    "client_num": 100,
    "test_ratio": 0.25,
    "val_ratio": 0.0,
    "seed": 42,
    "split": "sample",
    "IID_ratio": 1.0,
    "monitor_window_name_suffix": "cifar10-100clients-100%IID-seed42"
  }
}

Results

image

Basically, these curves can only tell you the approximate trend of training and whether it has collapsed, because each point on the curve is computed only from the clients that joined the training in that round, not from all of them.
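A rough sketch of why such a curve is noisy: with join_ratio=0.1 and 100 clients, each point averages over only 10 clients. The per-client accuracies below are synthetic, and `sample_clients` and `mean_accuracy` are hypothetical helpers rather than FL-bench functions.

```python
import random


def sample_clients(client_num, join_ratio, rng):
    """Pick the subset of clients that participates in one round."""
    k = max(1, round(client_num * join_ratio))
    return rng.sample(range(client_num), k)


def mean_accuracy(accuracies, client_ids):
    """Unweighted mean accuracy over the given clients."""
    ids = list(client_ids)
    return sum(accuracies[i] for i in ids) / len(ids)


rng = random.Random(42)
# Synthetic per-client accuracies, for illustration only.
accuracies = [rng.uniform(0.3, 0.7) for _ in range(100)]

selected = sample_clients(100, 0.1, rng)
partial = mean_accuracy(accuracies, selected)  # what one curve point reflects
full = mean_accuracy(accuracies, range(100))   # what an all-client evaluation would give

print(len(selected), round(partial, 3), round(full, 3))
```

Re-running with a different seed moves `partial` much more than `full`, which is the bias being described.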

wittenator commented 1 month ago

Ohh that's interesting, I will check again whether I can reproduce these graphs. I am aware of the bias of the evaluation (in comparison to a central evaluation); I actually added code for a central evaluation as well, which I will include in a future PR. Weirdly enough, I saw the collapse in both the client-evaluated graphs and the central one. Maybe something went wrong on my end, I'll report more once I find something :)

wittenator commented 1 month ago

Btw, do you have an intuition why the two runs of yours have an almost 20% performance gap (even factoring in slight deviations in training data)?

KarhouTam commented 1 month ago

Hard to tell. I did more testing and found that it may not be a data issue. It may be related to the model architecture.

Config

{                                                                                                                                                                                   
  "mode": "parallel",                                                                                                                                                               
  "common": {                                                                                                                                                                       
    "dataset": "cifar10",                                                                                                                                                           
    "seed": 42,                                                                                                                                                                     
    "model": "lenet5",                                                                                                                                                              
    "join_ratio": 0.1,                                                                                                                                                              
    "global_epoch": 100,                                                                                                                                                            
    "local_epoch": 5,                                                                                                                                                               
    "finetune_epoch": 0,                                                                                                                                                            
    "batch_size": 32,                                                                                                                                                               
    "test_interval": 100,                                                                                                                                                           
    "straggler_ratio": 0,                                                                                                                                                           
    "straggler_min_local_epoch": 0,                                                                                                                                                 
    "external_model_params_file": null,                                                                                                                                             
    "buffers": "local",                                                                                                                                                             
    "optimizer": {                                                                                                                                                                  
      "lr": 0.01,                                                                                                                                                                   
      "dampening": 0,                                                                                                                                                               
      "weight_decay": 0,                                                                                                                                                            
      "momentum": 0,                                                                                                                                                                
      "nesterov": false,                                                                                                                                                            
      "name": "sgd"                                                                                                                                                                 
    },                                                                                                                                                                              
    "eval_test": true,                                                                                                                                                              
    "eval_val": false,                                                                                                                                                              
    "eval_train": false,                                                                                                                                                            
    "verbose_gap": 25,                                                                                                                                                              
    "visible": "tensorboard",                                                                                                                                                       
    "use_cuda": true,                                                                                                                                                               
    "save_log": true,                                                                                                                                                               
    "save_model": false,                                                                                                                                                            
    "save_fig": true,                                                                                                                                                               
    "save_metrics": true,                                                                                                                                                           
    "delete_useless_run": true                                                                                                                                                      
  },
  "parallel": {
    "ray_cluster_addr": null,
    "num_gpus": 1.0,
    "num_cpus": 16.0,
    "num_workers": 2
  },
  "dataset": {
    "client_num": 100,
    "test_ratio": 0.25,
    "val_ratio": 0.0,
    "seed": 42,
    "split": "sample",
    "IID_ratio": 1.0,
    "monitor_window_name_suffix": "cifar10-100clients-100%IID-seed42"
  }
}

Results

IID

{
    "100": {
        "all_clients": {
            "test": {
                "loss": "1.2750 -> 0.0000",
                "accuracy": "56.50% -> 0.00%"
            }
        }
    }
}

Dir(1e+20), exactly IID (even more IID than the 100%IID in FL-bench 😂)

{
    "100": {
        "all_clients": {
            "test": {
                "loss": "1.3090 -> 0.0000",
                "accuracy": "56.26% -> 0.00%"
            }
        }
    }
}

Curves image

Distributions

100%IID image

Dir(1e+20) image

wittenator commented 4 weeks ago

I wrote a small benchmark script in Flower as a reference and trained a ResNet-18 model in almost the same setting (same local epochs and data distribution, alpha=100.0). There, FedAvg is very smooth and converges to a high accuracy: image

One difference between Flower and FL-bench that I have seen so far is that Flower uses Adam as the optimizer by default and resets the optimizer on each global epoch. Afaik the averaging and everything else is the same. I'll try resetting the optimizer on each global round in FL-bench. Let's see if this changes anything.

wittenator commented 3 weeks ago

OK, I think I've got it. For non-personalization tasks, not resetting the optimizer state reeeaaally messes with the optimization on the clients and with how well the models average. I added an option to reset the optimizer state whenever a client receives a new global model, and added centralized evaluation, resulting in these plots: image

I've made a PR with all the necessary additions. One thing that is still curious: the process looks very clean with Adam as the optimizer, but it is still pretty bad with SGD for multiple parameter combinations.
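To illustrate the optimizer reset described above, here is a minimal pure-Python sketch. It is not the code from the PR: the `MomentumSGD` class, `local_train`, `fedavg`, and the `reset_optimizer` flag are all hypothetical names, and a toy momentum optimizer on a quadratic loss stands in for a real PyTorch optimizer. The point is only that state accumulated against an old global model (here, the velocity) carries over into the next round unless it is cleared.

```python
from dataclasses import dataclass, field


@dataclass
class MomentumSGD:
    """Toy SGD with momentum on a list of floats (stand-in for a real optimizer)."""
    lr: float = 0.1
    momentum: float = 0.9
    velocity: list = field(default_factory=list)  # the optimizer state

    def step(self, params, grads):
        if not self.velocity:  # lazily initialise state on first use
            self.velocity = [0.0] * len(params)
        self.velocity = [self.momentum * v + g
                         for v, g in zip(self.velocity, grads)]
        return [p - self.lr * v for p, v in zip(params, self.velocity)]

    def reset_state(self):
        self.velocity = []


def local_train(global_params, optimizer, target, steps=5, reset_optimizer=True):
    """One client's local round on the toy loss 0.5 * (w - target)^2."""
    if reset_optimizer:
        # Drop state that was accumulated against the previous global model.
        optimizer.reset_state()
    params = list(global_params)
    for _ in range(steps):
        grads = [p - t for p, t in zip(params, target)]
        params = optimizer.step(params, grads)
    return params


def fedavg(client_params, client_sizes):
    """Average client parameters, weighted by local dataset size."""
    total = sum(client_sizes)
    dim = len(client_params[0])
    return [sum(n * p[i] for p, n in zip(client_params, client_sizes)) / total
            for i in range(dim)]
```

With `reset_optimizer=True`, a client that starts twice from the same global model produces the same update both times; with `reset_optimizer=False`, the second update also depends on the leftover velocity from the earlier round. In PyTorch, a comparable effect is usually achieved by constructing a fresh optimizer for the round.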

wittenator commented 2 weeks ago

@KarhouTam Have you had time to look at the SGD + FedAvg problem yet, by any chance?

KarhouTam commented 2 weeks ago

@wittenator Maybe this weekend or early next week. I'm still busy with other things this week.

KarhouTam commented 1 week ago

Hi, @wittenator. Now we can start talking about this issue.

wittenator commented 1 week ago

@KarhouTam Let's discuss this in #110. I am currently running FedAvg on the current main, both with and without optimizer resetting, and I'll add a Flower run with the same hyperparameters.