KarhouTam / FL-bench

Benchmark of federated learning. Dedicated to the community. 🤗
GNU General Public License v3.0
503 stars, 81 forks

FedAvg does not converge on homogeneous data #110

Open · wittenator opened this issue 1 month ago

wittenator commented 1 month ago

Describe the bug

Currently, FedAvg does not perform well even in homogeneous settings for models with BatchNorm layers, such as ResNet-18. The accuracy-over-epochs curve is highly erratic and does not reach the accuracy of comparable implementations such as Flower. Since almost all other methods in FL-Bench are derived from the FedAvg code, this may affect them as well.

Current main code without optimizer resetting (grey) and with optimizer resetting (blue):

[image: FL-Bench FedAvg accuracy curves]

Flower implementation with the same hyperparameters (code: https://github.com/wittenator/flower/tree/rework_fedprox_baseline/baselines/fedprox; it implements FedProx, which is equivalent to FedAvg for mu=0):

[image: Flower accuracy curves]
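For reference, the FedProx local objective (as defined in the FedProx paper) adds a proximal term to client k's loss F_k, pulling the local model w toward the current global model w^t; with mu = 0 the term vanishes and the objective is exactly FedAvg's, which is what makes the Flower baseline a valid FedAvg comparison. The notation below is ours, not taken from the linked code.

```latex
h_k(w;\, w^t) = F_k(w) + \frac{\mu}{2}\,\lVert w - w^t \rVert^2
\qquad\xrightarrow{\;\mu = 0\;}\qquad
h_k(w;\, w^t) = F_k(w)
```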

To Reproduce

Commands used:

python generate_data.py -d cifar10 -a 100.0 -cn 10
python main.py method=fedavg
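Here `-a 100.0` sets the Dirichlet concentration alpha (the thread later calls this split Dir(100.0)). A large alpha makes each label's share across clients nearly uniform, which is why the setting counts as homogeneous. A small NumPy illustration, assuming one Dirichlet draw per label as in the partitioner loop discussed further down; `label_shares` is a hypothetical helper, not FL-bench code:

```python
import numpy as np


def label_shares(alpha: float, client_num: int, seed: int = 0) -> np.ndarray:
    """Sample the fraction of one label's samples that each client receives."""
    rng = np.random.default_rng(seed)
    return rng.dirichlet(np.full(client_num, alpha))


near_iid = label_shares(alpha=100.0, client_num=10)
skewed = label_shares(alpha=0.1, client_num=10)
print(np.round(near_iid, 3))  # every share close to 1 / 10
print(np.round(skewed, 3))    # mass typically concentrated on a few clients
```

With alpha = 100 and 10 clients, each share has a standard deviation of roughly 0.01 around 0.1, so the per-client label mix should look almost IID.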

Config used:

method: fedavg

dataset:
  # [mnist, cifar10, cifar100, emnist, fmnist, femnist, medmnist,
  # medmnistA, medmnistC, covid19, celeba, synthetic, svhn,
  # tiny_imagenet, cinic10, domain]
  name: cifar10

model:
  name: res18

  # Whether to use torchvision integrated model weights.
  # Has no effect if model is lenet5, 2nn or fedavgcnn
  use_torchvision_pretrained_weights: false

  # The model parameters `.pt` file relative path to the directory of FL-bench.
  # This feature is enabled only when `unique_model=False`,
  # which is pre-defined and fixed by each FL method.
  external_model_weights_path: null

# The learning rate scheduler used for client local training.
# Can be null if no lr_scheduler is needed.
lr_scheduler:
  name: null # [null, step, cosine, constant, plateau]
  step_size: 10 # step
  gamma: 0.1 # [step, plateau]
  T_max: 10 # cosine
  eta_min: 0 # cosine
  factor: 0.3334 # [constant, plateau]
  total_iters: 5 # constant
  mode: min # plateau
  patience: 10 # plateau
  threshold: 1.0e-4 # plateau
  threshold_mode: rel # plateau
  cooldown: 0 # plateau
  min_lr: 0 # plateau
  eps: 1.0e-8 # plateau
  last_epoch: -1

# The optimizer used for client local training.
optimizer:
  name: sgd # [sgd, adam, adamw, rmsprop, adagrad]
  lr: 0.01
  dampening: 0 # for SGD
  weight_decay: 0
  momentum: 0.9 # for [SGD, RMSprop]
  alpha: 0.99 # for RMSprop
  nesterov: false # for SGD
  betas: [0.9, 0.999] # for [Adam, AdamW]
  amsgrad: false # for [Adam, AdamW]

mode: parallel # [serial, parallel]
# It's fine to keep these configs. If mode is 'serial', they will be ignored.
parallel:
  # The IP address of the selected ray cluster.
  # Default as null, which means if there is no existing ray cluster,
  # then Ray will create a new cluster at the beginning of the experiment
  # and destroy it at the end.
  # More details can be found in https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html.
  ray_cluster_addr: null # [null, auto, local]

  # The amount of computational resources you allocate for your Ray cluster.
  # Default as null for all.
  num_cpus: null
  num_gpus: 1

  # Should be set larger than 1, otherwise training mode falls back to `serial`.
  # Setting a larger `num_workers` can further boost efficiency,
  # but also leaves each worker with fewer computational resources.
  num_workers: 10

common:
  seed: 42 # Random seed of the run.
  join_ratio: 1.0 # Ratio for (client each round) / (client num in total).
  global_epoch: 50 # Number of global epochs, also called communication rounds.
  local_epoch: 5 # Number of epochs of client local training.
  finetune_epoch: 0 # Number of epochs of clients fine-tuning their models before test.
  batch_size: 32 # Data batch size for client local training.
  test_interval: 100 # Interval (in rounds) for testing on all test clients.

  # The ratio of stragglers (set in [0, 1]).
  # Stragglers do not perform full-epoch local training like normal clients.
  # Their local epoch would be randomly selected from range [straggler_min_local_epoch, local_epoch).
  straggler_ratio: 0
  straggler_min_local_epoch: 0

  # How to deal with parameter buffers (in model.buffers()) of each client model.
  # global (default): buffers will be aggregated like other model parameters.
  # local: clients' buffers are isolated.
  # drop: clients will drop their buffers after training done.
  buffers: global # [local, global, drop]

  # Set eval_<...> as true for performing evaluation on <...>sets held by
  # this round's joined clients before and after local training.
  eval_test: true
  eval_val: false
  eval_train: false

  verbose_gap: 10 # Interval (in rounds) for displaying clients' training performance on the terminal.
  visible: tensorboard # [null, visdom, tensorboard]
  use_cuda: true # Whether to use cuda for training.

  save_log: true # Whether to save log files in out/<method>/<start_time>.
  save_model: false # Whether to save model weights (*.pt) in out/<method>/<start_time>.
  save_fig: true # Whether to save learning curve figure (*.png) in out/<method>/<start_time>.
  save_metrics: true # Whether to save metrics (*.csv) in out/<method>/<start_time>.

  # Whether to delete output files after the user presses `Ctrl + C`,
  # which indicates that the run is removable.
  delete_useless_run: true

Expected behavior

For IID data, the accuracy-over-epochs curve should be smooth and (nearly) monotonically increasing, approaching a plateau of around 70-80% accuracy for ResNet-18 on CIFAR-10.

KarhouTam commented 1 month ago

I can't find the model architecture used in the Flower run. My idea is to start by comparing the model architectures.

BTW, I've trained ResNet-18 with and without public pretrained model weights and found that the model without pretrained weights performs worse (cifar10 -a 10 -tr 0.25 -cn 100):

Without pretrained weights:

[image]

With pretrained weights:

[image]

Well, at least this suggests pretrained weights really help training...

And it seems that with more clients (e.g., 100), training is more stable?

I haven't tested it with 10 clients yet. Maybe tonight?

Config

method: fedavg
dataset:
  name: cifar10
model:
  name: lenet5
  use_torchvision_pretrained_weights: false
  external_model_weights_path: null
lr_scheduler:
  name: null
  step_size: 10
  gamma: 0.1
  T_max: 10
  eta_min: 0
  factor: 0.3334
  total_iters: 5
  mode: min
  patience: 10
  threshold: 0.0001
  threshold_mode: rel
  cooldown: 0
  min_lr: 0
  eps: 1.0e-08
  last_epoch: -1
optimizer:
  name: sgd
  lr: 0.01
  dampening: 0
  weight_decay: 0
  momentum: 0
  alpha: 0.99
  nesterov: false
  betas:
  - 0.9
  - 0.999
  amsgrad: false
mode: parallel
parallel:
  ray_cluster_addr: null
  num_cpus: null
  num_gpus: null
  num_workers: 2
common:
  seed: 42
  join_ratio: 0.1
  global_epoch: 100
  local_epoch: 5
  finetune_epoch: 0
  batch_size: 32
  test_interval: 100
  straggler_ratio: 0
  straggler_min_local_epoch: 0
  buffers: global
  eval_test: true
  eval_val: false
  eval_train: false
  verbose_gap: 10
  visible: visdom
  use_cuda: true
  save_log: true
  save_model: false
  save_fig: true
  save_metrics: true
  delete_useless_run: true
fedprox:
  mu: 0.01
pfedsim:
  warmup_round: 0.5

KarhouTam commented 1 month ago

BTW, I've left some comments and requests on your PR #104. That PR is close to being merged. Maybe we can push it further, wdyt?

wittenator commented 1 month ago

The model is defined here: https://github.com/wittenator/flower/blob/edf72b2b280257132d14448b323a1a0d3e3102be/baselines/fedprox/pyproject.toml#L133

That said, I made sure to set up the Flower experiment the same way FL-Bench is set up. Is this actually your config, though, since it uses LeNet-5 and no momentum? The pretrained-weights plots are even weirder: the preLocalTraining curve is above the postLocalTraining curve, right? That would mean local training makes the clients worse at predicting on their own data (besides converging to a worse overall accuracy). What I find curious is that with optimizer resetting, Adam somehow works around the bug present in this setup. I am honestly out of ideas about where this may stem from. I'll try to finish the PR, but time is currently tight.
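"Optimizer resetting" here means building a fresh optimizer for every round instead of reusing one across rounds. The hedged NumPy sketch below shows why that can matter with SGD momentum (the original config uses `momentum: 0.9`): a reused optimizer carries the momentum buffer accumulated on last round's local model into the first step on the new global model. `MomentumSGD` and `first_step_of_round` are hypothetical names, and the update rule mirrors PyTorch's SGD with `dampening: 0`; this is not FL-bench's client code.

```python
import numpy as np


class MomentumSGD:
    """Minimal SGD with PyTorch-style momentum: buf = m * buf + grad; w -= lr * buf."""

    def __init__(self, lr: float, momentum: float):
        self.lr, self.momentum = lr, momentum
        self.buf = None  # momentum buffer, created on the first step

    def step(self, w: np.ndarray, grad: np.ndarray) -> np.ndarray:
        if self.buf is None:
            self.buf = grad.copy()
        else:
            self.buf = self.momentum * self.buf + grad
        return w - self.lr * self.buf


def first_step_of_round(opt: MomentumSGD, global_w: np.ndarray, grad: np.ndarray) -> np.ndarray:
    """The first local update a client applies after receiving the global model."""
    return opt.step(global_w, grad)


# Reused optimizer: pretend the previous round left a large momentum buffer behind.
stale = MomentumSGD(lr=0.01, momentum=0.9)
stale.buf = np.array([5.0, -5.0])

# Reset optimizer: a brand-new instance for this round.
fresh = MomentumSGD(lr=0.01, momentum=0.9)

global_w, grad = np.zeros(2), np.array([1.0, 1.0])
print(first_step_of_round(stale, global_w, grad))  # dominated by stale momentum
print(first_step_of_round(fresh, global_w, grad))  # driven only by this round's gradient
```

Adam keeps per-parameter state as well (its moment estimates), so the same reasoning applies to it.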

KarhouTam commented 4 weeks ago

Sorry, my mistake. Here is the correct one (ResNet-18, CIFAR-10, Dir(10), 100 clients, with pretrained weights):

[image: accuracy curves, ResNet-18, CIFAR-10, Dir(10), 100 clients, with pretrained weights]

It seems something is wrong...

I am going to walk through the framework and the FL training workflow. Maybe I can find the cause.

KarhouTam commented 4 weeks ago

Dataset: python generate_data.py -d cifar10 -cn 10 --iid 1 -tr 0.25

I changed some code in ResNet:

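from torch import Tensor, nn
from torchvision import models

# DecoupledModel and NUM_CLASSES are defined elsewhere in FL-bench (not shown).


class ResNet(DecoupledModel):
    archs = {
        "18": (models.resnet18, models.ResNet18_Weights.DEFAULT),
        "34": (models.resnet34, models.ResNet34_Weights.DEFAULT),
        "50": (models.resnet50, models.ResNet50_Weights.DEFAULT),
        "101": (models.resnet101, models.ResNet101_Weights.DEFAULT),
        "152": (models.resnet152, models.ResNet152_Weights.DEFAULT),
    }

    def __init__(self, version, dataset, pretrained):
        super().__init__()

        # NOTE: If you don't want pretrained parameters, set `pretrained` to False.
        # torchvision raises a ValueError when pretrained weights are combined with
        # a num_classes other than the weights' 1000 ImageNet classes, so this
        # variant only works with pretrained=False for datasets like CIFAR-10.
        resnet: models.ResNet = self.archs[version][0](
            weights=self.archs[version][1] if pretrained else None,
            num_classes=NUM_CLASSES[dataset],
        )
        # Use the whole torchvision ResNet (including its own fc output layer)
        # as the base, and turn the separate classifier into a no-op.
        self.base = resnet
        # self.classifier = nn.Linear(self.base.fc.in_features, NUM_CLASSES[dataset])
        self.classifier = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        # If the input is grayscale, broadcast it to 3 channels.
        if x.shape[1] == 1:
            x = x.broadcast_to(x.shape[0], 3, *x.shape[2:])
        return super().forward(x)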

Config

method: fedavg

dataset:
  # [mnist, cifar10, cifar100, emnist, fmnist, femnist, medmnist,
  # medmnistA, medmnistC, covid19, celeba, synthetic, svhn,
  # tiny_imagenet, cinic10, domain]
  name: cifar10

model:
  name: res18

  # Whether to use torchvision integrated model weights.
  # Has no effect if model is lenet5, 2nn or fedavgcnn
  use_torchvision_pretrained_weights: false

  # The model parameters `.pt` file relative path to the directory of FL-bench.
  # This feature is enabled only when `unique_model=False`,
  # which is pre-defined and fixed by each FL method.
  external_model_weights_path: null

# The learning rate scheduler used for client local training.
# Can be null if no lr_scheduler is needed.
lr_scheduler:
  name: null # [null, step, cosine, constant, plateau]
  step_size: 10 # step
  gamma: 0.1 # [step, plateau]
  T_max: 10 # cosine
  eta_min: 0 # cosine
  factor: 0.3334 # [constant, plateau]
  total_iters: 5 # constant
  mode: min # plateau
  patience: 10 # plateau
  threshold: 1.0e-4 # plateau
  threshold_mode: rel # plateau
  cooldown: 0 # plateau
  min_lr: 0 # plateau
  eps: 1.0e-8 # plateau
  last_epoch: -1

# The optimizer used for client local training.
optimizer:
  name: sgd # [sgd, adam, adamw, rmsprop, adagrad]
  lr: 0.01
  dampening: 0 # for SGD
  weight_decay: 0
  momentum: 0 # for [SGD, RMSprop]
  alpha: 0.99 # for RMSprop
  nesterov: false # for SGD
  betas: [0.9, 0.999] # for [Adam, AdamW]
  amsgrad: false # for [Adam, AdamW]

mode: parallel # [serial, parallel]
# It's fine to keep these configs. If mode is 'serial', they will be ignored.
parallel:
  # The IP address of the selected ray cluster.
  # Default as null, which means if there is no existing ray cluster,
  # then Ray will create a new cluster at the beginning of the experiment
  # and destroy it at the end.
  # More details can be found in https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html.
  ray_cluster_addr: null # [null, auto, local]

  # The amount of computational resources you allocate for your Ray cluster.
  # Default as null for all.
  num_cpus: null
  num_gpus: null

  # Should be set larger than 1, otherwise training mode falls back to `serial`.
  # Setting a larger `num_workers` can further boost efficiency,
  # but also leaves each worker with fewer computational resources.
  num_workers: 2

common:
  seed: 42 # Random seed of the run.
  join_ratio: 1 # Ratio for (client each round) / (client num in total).
  global_epoch: 100 # Number of global epochs, also called communication rounds.
  local_epoch: 5 # Number of epochs of client local training.
  finetune_epoch: 0 # Number of epochs of clients fine-tuning their models before test.
  batch_size: 32 # Data batch size for client local training.
  test_interval: 100 # Interval (in rounds) for testing on all test clients.

  # The ratio of stragglers (set in [0, 1]).
  # Stragglers do not perform full-epoch local training like normal clients.
  # Their local epoch would be randomly selected from range [straggler_min_local_epoch, local_epoch).
  straggler_ratio: 0
  straggler_min_local_epoch: 0

  # How to deal with parameter buffers (in model.buffers()) of each client model.
  # global (default): buffers will be aggregated like other model parameters.
  # local: clients' buffers are isolated.
  # drop: clients will drop their buffers after training done.
  buffers: global # [local, global, drop]

  # Set eval_<...> as true for performing evaluation on <...>sets held by
  # this round's joined clients before and after local training.
  eval_test: true
  eval_val: false
  eval_train: false

  verbose_gap: 10 # Interval (in rounds) for displaying clients' training performance on the terminal.
  visible: tensorboard # [null, visdom, tensorboard]
  use_cuda: true # Whether to use cuda for training.

  save_log: true # Whether to save log files in out/<method>/<start_time>.
  save_model: false # Whether to save model weights (*.pt) in out/<method>/<start_time>.
  save_fig: true # Whether to save learning curve figure (*.png) in out/<method>/<start_time>.
  save_metrics: true # Whether to save metrics (*.csv) in out/<method>/<start_time>.

  # Whether to delete output files after the user presses `Ctrl + C`,
  # which indicates that the run is removable.
  delete_useless_run: true

# You can set specific arguments for advanced FL methods also.
# FL-bench uses FL method arguments by args.<method>.<arg>.
# You need to follow the key set in `get_hyperparams()` in class <method>Server, src/server/<method>.py
# FL-bench will ignore these arguments if they are not supported by the selected method,
# e.g., if you are running FedProx, then pfedsim arguments will be ignored.
fedprox:
  mu: 0.01
pfedsim:
  warmup_round: 0.5
fedap:
  version: f

[image: accuracy curves for the modified ResNet-18 on the IID split]

As for why AfterLocalTraining performs worse, more tests are needed... Maybe it's the buffers?

Or could you modify the Flower run's code to also plot the afterLocalTraining curve?
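With `buffers: global`, the config comment says buffers are aggregated like other model parameters; for BatchNorm layers those buffers are `running_mean`, `running_var` and `num_batches_tracked`. A minimal NumPy sketch of sample-weighted FedAvg over client state dicts that treats buffers the same way as weights. The dict layout and the rounding of integer buffers are assumptions for illustration, not FL-bench's aggregation code.

```python
import numpy as np


def fedavg(states: list[dict[str, np.ndarray]], num_samples: list[int]) -> dict[str, np.ndarray]:
    """Sample-weighted average of client state dicts (parameters and buffers alike)."""
    weights = np.asarray(num_samples, dtype=np.float64)
    weights /= weights.sum()
    averaged = {}
    for key in states[0]:
        stacked = np.stack([s[key].astype(np.float64) for s in states])
        # Contract the leading client axis with the normalized weights.
        mean = np.tensordot(weights, stacked, axes=1)
        dtype = states[0][key].dtype
        if np.issubdtype(dtype, np.integer):
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) are rounded back.
            averaged[key] = np.rint(mean).astype(dtype)
        else:
            averaged[key] = mean.astype(dtype)
    return averaged


client_a = {"bn.weight": np.array([1.0, 1.0]), "bn.running_mean": np.array([0.0, 2.0]),
            "bn.num_batches_tracked": np.array(10)}
client_b = {"bn.weight": np.array([3.0, 1.0]), "bn.running_mean": np.array([4.0, 2.0]),
            "bn.num_batches_tracked": np.array(30)}
global_state = fedavg([client_a, client_b], num_samples=[100, 300])
print(global_state)
```

Because the running statistics are averaged exactly like the affine weights, the aggregated model's BatchNorm statistics are a mixture of every client's local statistics.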

wittenator commented 4 weeks ago

Ohhh that looks much better!! But hmmm, do you have an explanation for the fix? Did the other classifier layer not get registered or something? I'll plot the afterLocalTraining as well tomorrow.

KarhouTam commented 4 weeks ago

Maybe the key is model parameter initialization, since I didn't change any other code and the training procedure remains the same. I'll do more tests on it.

wittenator commented 4 weeks ago

I just tried to reproduce this fix with Dir(100.0) instead of IID 1.0. IID seems to work fine, but Dir(100.0) is the same as before. So is this perhaps a data/partitioning problem?

wittenator commented 4 weeks ago

Ah yes, I just tested this, and the old version of the model performs just as well with IID 1. I think this strongly suggests that the Dirichlet sampler, or one of the steps after it, has a bug. Maybe it makes sense to export the sampling from the Flower partitioner and import it into FL-Bench?
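One quick way to check whether the partitioning is at fault is to inspect each client's label histogram: under IID or Dir(100.0), every client should see roughly the same label mix. A hedged NumPy sketch; `targets` (one integer label per sample) and `data_indices` (one index list per client) are assumed inputs, and both helper names are hypothetical.

```python
import numpy as np


def label_histograms(targets: np.ndarray, data_indices: list, num_classes: int) -> np.ndarray:
    """Row c holds the fraction of client c's samples that belong to each label."""
    hist = np.zeros((len(data_indices), num_classes))
    for client_id, idx in enumerate(data_indices):
        counts = np.bincount(targets[np.asarray(idx, dtype=int)], minlength=num_classes)
        hist[client_id] = counts / max(len(idx), 1)
    return hist


def max_label_skew(hist: np.ndarray) -> float:
    """Largest absolute gap between any client's label share and the uniform share."""
    return float(np.abs(hist - 1.0 / hist.shape[1]).max())


# Toy example: a balanced 10-class dataset split IID across 10 clients.
rng = np.random.default_rng(42)
targets = np.repeat(np.arange(10), 500)
iid_indices = np.array_split(rng.permutation(len(targets)), 10)
print(max_label_skew(label_histograms(targets, iid_indices, num_classes=10)))
```

Running the same check on the indices FL-bench and Flower each produce for Dir(100.0) would show whether the two partitions really differ in label balance.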

KarhouTam commented 4 weeks ago

Run that test? Sure, but some format conversion will probably be needed.

Regarding "I think that this is a heavy pointer into the direction that the Dirichlet sampler or one of the steps afterwards has a bug": I'll check it ASAP.

KarhouTam commented 3 weeks ago

Yeah, I actually found some bugs in the Dirichlet partitioning method (the indentation of the min_size code was wrong...).

I followed the Flower implementation and changed the key loop to:

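    # Assumes numpy as np, client_num, alpha, least_samples, label_set,
    # indices_4_labels (label -> array of sample indices) and partition (dict)
    # are defined earlier in the partitioner.
    min_size = 0
    while min_size < least_samples: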
        # Initialize data indices for each client
        partition["data_indices"] = [[] for _ in range(client_num)]

        # Iterate over each label in the label set
        for label in label_set:
            # Shuffle the indices associated with the current label
            np.random.shuffle(indices_4_labels[label])

            # Generate a Dirichlet distribution for splitting data among clients
            distribution = np.random.dirichlet(np.repeat(alpha, client_num))

            # Calculate split indices based on the generated distribution
            cumulative_indices = np.cumsum(distribution) * len(indices_4_labels[label])
            split_indices_position = cumulative_indices.astype(int)[:-1]

            # Split the indices for the current label
            split_indices = np.split(indices_4_labels[label], split_indices_position)

            # Assign split indices to each client
            for client_id in range(client_num):
                partition["data_indices"][client_id].extend(split_indices[client_id])

        # Update the minimum size of the data across all clients
        min_size = min(len(idx) for idx in partition["data_indices"])

Maybe try this?
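For experimenting outside FL-bench, here is the loop above rewritten as a self-contained function. It is a sketch that assumes integer labels in a NumPy array; the function name, the `max_tries` guard and the final per-client shuffle are our additions (to our understanding, Flower's `DirichletPartitioner` offers a similar per-partition shuffle through its `shuffle` argument). Note that the size check runs only after every label has been distributed, which is the indentation detail mentioned above.

```python
import numpy as np


def dirichlet_partition(targets: np.ndarray, client_num: int, alpha: float,
                        least_samples: int, seed: int = 42, max_tries: int = 1000) -> list:
    """Split sample indices among clients with per-label Dirichlet(alpha) proportions."""
    rng = np.random.default_rng(seed)
    indices_4_labels = {label: np.where(targets == label)[0] for label in np.unique(targets)}
    for _ in range(max_tries):
        data_indices = [[] for _ in range(client_num)]
        for label_indices in indices_4_labels.values():
            shuffled = rng.permutation(label_indices)
            distribution = rng.dirichlet(np.repeat(alpha, client_num))
            # Cut points: cumulative shares scaled to this label's sample count.
            cut_points = (np.cumsum(distribution) * len(shuffled)).astype(int)[:-1]
            for client_id, chunk in enumerate(np.split(shuffled, cut_points)):
                data_indices[client_id].extend(chunk.tolist())
        # Check the size constraint only after ALL labels have been distributed.
        if min(len(idx) for idx in data_indices) >= least_samples:
            # Shuffle within each client so its indices are not grouped by label.
            return [rng.permutation(idx).tolist() for idx in data_indices]
    raise RuntimeError("could not satisfy least_samples; try a larger alpha or fewer clients")


targets = np.repeat(np.arange(10), 600)  # toy balanced labels
parts = dirichlet_partition(targets, client_num=10, alpha=100.0, least_samples=10)
print([len(p) for p in parts])
```

Each sample index ends up in exactly one client, and dropping the last cut point guarantees that rounding leftovers go to the last client instead of being lost.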

KarhouTam commented 3 weeks ago

Fixed by commit e6e0507965107d1acfc4854e5c74a326ac5afb3d

wittenator commented 3 weeks ago

Did it fix the behavior for you? I did a quick test yesterday, but it didn't work any better.

KarhouTam commented 3 weeks ago

Not tested yet. But the partitioner code now does almost the same thing as Flower's.

wittenator commented 3 weeks ago

Just for reference, this is the same config as above with the new Dirichlet sampler:

[image: accuracy curves with the new Dirichlet sampler]

KarhouTam commented 3 weeks ago

Have you tested FL-bench's FedAvg (and res18) with the output of Flower's Dirichlet partitioner?

If not, manually replacing the partition.pkl file is the fastest way.
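A hedged sketch of the "manually replace partition.pkl" route: load the pickled partition, swap in externally produced per-client indices, and write it back while keeping every other key. It assumes the file is a pickled dict with a `data_indices` entry (as in the partitioner loop in this thread) and that the new entries have the same per-client shape FL-bench already stores there; check a generated file before relying on it. `replace_data_indices` and the `meta` key in the demo are hypothetical.

```python
import pickle
import tempfile
from pathlib import Path


def replace_data_indices(pkl_path: Path, new_data_indices: list) -> None:
    """Overwrite partition["data_indices"] in a pickled partition file, keeping other keys."""
    with open(pkl_path, "rb") as f:
        partition = pickle.load(f)
    old = partition["data_indices"]
    if len(old) != len(new_data_indices):
        raise ValueError(f"expected {len(old)} clients, got {len(new_data_indices)}")
    partition["data_indices"] = new_data_indices
    with open(pkl_path, "wb") as f:
        pickle.dump(partition, f)


# Demo on a throwaway file instead of a real FL-bench partition.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "partition.pkl"
    with open(path, "wb") as f:
        pickle.dump({"data_indices": [[0, 1], [2, 3]], "meta": "kept"}, f)
    replace_data_indices(path, [[3, 0], [1, 2]])
    with open(path, "rb") as f:
        loaded = pickle.load(f)
print(loaded)
```

Checking the client count before overwriting catches the most likely mismatch, a partitioner run with a different `-cn`.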

wittenator commented 3 weeks ago

Not yet, but I wrote a script that extracts the indices from a Flower partitioner and saves them to a valid partition.pkl file. I'll most probably test it tomorrow.

wittenator commented 4 days ago

I had another look, and the problem is indeed the Dirichlet partitioner. I added the option to use external partitioners from flwr_datasets in #127 and tested the setup with alpha=100.0. Only then do the accuracy curves look as expected:

[image: accuracy curves using the flwr_datasets DirichletPartitioner]

Commands used:

python generate_data.py -d cifar10 -cn 10 --flower_partitioner_class "flwr_datasets.partitioner.DirichletPartitioner" --flower_partitioner_kwargs '{"alpha": 100.0, "partition_by": "label"}'
python main.py dataset.name=cifar10 model.name=res18  model.use_torchvision_pretrained_weights=false common.test_server_interval=1 common.test_test=true common.join_ratio=1.0 mode=parallel parallel.num_workers=10 parallel.num_gpus=1 common.monitor=tensorboard

KarhouTam commented 4 days ago

It's REALLY a puzzle. 😂 I actually checked the source code of Flower's Dirichlet partitioner and copied it almost verbatim into FL-bench's. I cannot find any differences between them. What's the magic???