Open wittenator opened 5 days ago
I can't find the model architecture used in the Flower run. My idea is to start by comparing the model architectures.
BTW, I've trained resnet18 with and without publicly available pretrained weights and found that the model without pretrained weights performs worse (cifar10 -a 10 -tr 0.25 -cn 100):
[Plot: without pretrained weights]
[Plot: with pretrained weights]
Well, at least this shows that pretrained weights really help training...
And it seems that with more clients (e.g., 100), the training process is more stable?
I haven't tested it with 10 clients yet. Maybe tonight?
Config
method: fedavg
dataset:
  name: cifar10
model:
  name: lenet5
  use_torchvision_pretrained_weights: false
  external_model_weights_path: null
lr_scheduler:
  name: null
  step_size: 10
  gamma: 0.1
  T_max: 10
  eta_min: 0
  factor: 0.3334
  total_iters: 5
  mode: min
  patience: 10
  threshold: 0.0001
  threshold_mode: rel
  cooldown: 0
  min_lr: 0
  eps: 1.0e-08
  last_epoch: -1
optimizer:
  name: sgd
  lr: 0.01
  dampening: 0
  weight_decay: 0
  momentum: 0
  alpha: 0.99
  nesterov: false
  betas:
    - 0.9
    - 0.999
  amsgrad: false
mode: parallel
parallel:
  ray_cluster_addr: null
  num_cpus: null
  num_gpus: null
  num_workers: 2
common:
  seed: 42
  join_ratio: 0.1
  global_epoch: 100
  local_epoch: 5
  finetune_epoch: 0
  batch_size: 32
  test_interval: 100
  straggler_ratio: 0
  straggler_min_local_epoch: 0
  buffers: global
  eval_test: true
  eval_val: false
  eval_train: false
  verbose_gap: 10
  visible: visdom
  use_cuda: true
  save_log: true
  save_model: false
  save_fig: true
  save_metrics: true
  delete_useless_run: true
fedprox:
  mu: 0.01
pfedsim:
  warmup_round: 0.5
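With `join_ratio: 0.1` and 100 clients (`-cn 100`), only about 10 clients train in each round. Here is a minimal sketch of per-round client sampling. It assumes clients are drawn uniformly without replacement and that the per-round count is `max(1, int(client_num * join_ratio))`. FL-bench's exact rounding and sampler may differ, and `sample_clients` is a hypothetical helper.

```python
import random


def sample_clients(client_num: int, join_ratio: float, seed: int) -> list[int]:
    """Draw this round's participants uniformly without replacement (hypothetical helper)."""
    # Assumed rounding rule: truncate, but always keep at least one client.
    num_joined = max(1, int(client_num * join_ratio))
    rng = random.Random(seed)
    return sorted(rng.sample(range(client_num), num_joined))


# join_ratio 0.1 with 100 clients vs. join_ratio 1 with 10 clients
print(len(sample_clients(100, 0.1, seed=42)))  # 10
print(len(sample_clients(10, 1.0, seed=42)))   # 10
```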
BTW, I've left some comments and requests on your PR #104. That PR is close to being merged. Maybe we can push it further, wdyt?
The model is defined here: https://github.com/wittenator/flower/blob/edf72b2b280257132d14448b323a1a0d3e3102be/baselines/fedprox/pyproject.toml#L133 But I took care to set up the Flower experiment the same way FL-Bench is set up. Is this actually your config? It uses lenet5 and no momentum. The plot with pretrained weights is even weirder, since the preLocalTraining curve is above the postLocalTraining curve, right? That would mean local training makes the clients worse at predicting on their own datasets (besides converging to a worse accuracy overall). What I find curious is that with optimizer resetting, Adam somehow fixes the bug present in this setup. I'm actually out of ideas about where this may stem from. I'll try to finish the PR, but time is currently tight.
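Adam (and SGD with nonzero momentum) keeps per-parameter state between steps. Whether a client re-creates its optimizer each round therefore changes what the first local steps of a new round do.

Below is a minimal sketch on a single scalar weight. It uses a hand-written Adam update rather than `torch.optim.Adam`. `Adam` and `local_train` are hypothetical stand-ins and the gradients are made up. The sketch only illustrates the mechanism; it makes no claim about the cause of the bug.

```python
import math


class Adam:
    """Hand-written Adam for one scalar weight (illustrative, not torch.optim.Adam)."""

    def __init__(self, lr=0.01, betas=(0.9, 0.999), eps=1e-8):
        self.lr, self.betas, self.eps = lr, betas, eps
        self.m = 0.0  # first-moment estimate
        self.v = 0.0  # second-moment estimate
        self.t = 0    # step counter used for bias correction

    def step(self, w: float, grad: float) -> float:
        b1, b2 = self.betas
        self.t += 1
        self.m = b1 * self.m + (1 - b1) * grad
        self.v = b2 * self.v + (1 - b2) * grad * grad
        m_hat = self.m / (1 - b1 ** self.t)
        v_hat = self.v / (1 - b2 ** self.t)
        return w - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


def local_train(w, grads, optimizer):
    """One local round: apply a fixed gradient sequence (stand-in for minibatches)."""
    for g in grads:
        w = optimizer.step(w, g)
    return w


persistent = Adam()
# Round 1 is identical for both variants.
w_global = local_train(0.0, [1.0, 1.0, 1.0], persistent)

# Round 2 without reset: the moments from round 1 are still in the optimizer,
# so the first steps keep moving in the round-1 direction.
w_no_reset = local_train(w_global, [-1.0, -1.0, -1.0], persistent)
# Round 2 with reset: a fresh optimizer starts from m = v = 0 and t = 0.
w_reset = local_train(w_global, [-1.0, -1.0, -1.0], Adam())
print(w_global, w_no_reset, w_reset)
```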
Sorry, my fault.
Here is the right one (resnet18, cifar10, dir(10), 100 clients, with pretrained weights):
It seems something is wrong...
I'm going to go through the framework and the FL training workflow. Maybe I can find the reason.
Dataset: python generate_data.py -d cifar10 -cn 10 --iid 1 -tr 0.25
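For comparison with the Dirichlet runs, an IID split amounts to shuffling all indices and dealing them out evenly. Here is a minimal sketch of that idea, assuming equal-size shares. It is not FL-bench's `generate_data.py` logic, and it does not model the `-tr 0.25` split. `iid_partition` is a hypothetical helper.

```python
import random


def iid_partition(num_samples: int, client_num: int, seed: int = 42) -> list[list[int]]:
    """Shuffle all sample indices and split them into near-equal contiguous chunks."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    # Spread the remainder so chunk sizes differ by at most one.
    base, extra = divmod(num_samples, client_num)
    partition, start = [], 0
    for client_id in range(client_num):
        size = base + (1 if client_id < extra else 0)
        partition.append(indices[start:start + size])
        start += size
    return partition


parts = iid_partition(50_000, 10)  # e.g. 50,000 samples split among 10 clients
print([len(p) for p in parts])     # ten chunks of 5000
```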
I changed some code in ResNet:
from torch import Tensor, nn
from torchvision import models

# DecoupledModel and NUM_CLASSES are defined elsewhere in FL-bench.


class ResNet(DecoupledModel):
    archs = {
        "18": (models.resnet18, models.ResNet18_Weights.DEFAULT),
        "34": (models.resnet34, models.ResNet34_Weights.DEFAULT),
        "50": (models.resnet50, models.ResNet50_Weights.DEFAULT),
        "101": (models.resnet101, models.ResNet101_Weights.DEFAULT),
        "152": (models.resnet152, models.ResNet152_Weights.DEFAULT),
    }

    def __init__(self, version, dataset, pretrained):
        super().__init__()
        # NOTE: If you don't want pretrained parameters, set `pretrained` to False.
        # torchvision raises a ValueError when weights are requested together with
        # num_classes != 1000, so this variant only works with pretrained=False.
        resnet: models.ResNet = self.archs[version][0](
            weights=self.archs[version][1] if pretrained else None,
            num_classes=NUM_CLASSES[dataset],
        )
        # Keep torchvision's own `fc` head inside `base` ...
        self.base = resnet
        # ... instead of a separate classifier layer.
        # self.classifier = nn.Linear(self.base.fc.in_features, NUM_CLASSES[dataset])
        self.classifier = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        # If the input is grayscale, repeat it to 3 channels.
        if x.shape[1] == 1:
            x = x.broadcast_to(x.shape[0], 3, *x.shape[2:])
        return super().forward(x)
Config
method: fedavg
dataset:
  # [mnist, cifar10, cifar100, emnist, fmnist, femnist, medmnist,
  # medmnistA, medmnistC, covid19, celeba, synthetic, svhn,
  # tiny_imagenet, cinic10, domain]
  name: cifar10
model:
  name: res18
  # Whether to use torchvision's integrated model weights.
  # Has no effect if the model is lenet5, 2nn or fedavgcnn.
  use_torchvision_pretrained_weights: false
  # Path of the model parameters `.pt` file, relative to the FL-bench directory.
  # This feature is enabled only when `unique_model=False`,
  # which is pre-defined and fixed by each FL method.
  external_model_weights_path: null
# The learning rate scheduler used for client local training.
# Can be null if no lr_scheduler is needed.
lr_scheduler:
  name: null # [null, step, cosine, constant, plateau]
  step_size: 10 # step
  gamma: 0.1 # [step, plateau]
  T_max: 10 # cosine
  eta_min: 0 # cosine
  factor: 0.3334 # [constant, plateau]
  total_iters: 5 # constant
  mode: min # plateau
  patience: 10 # plateau
  threshold: 1.0e-4 # plateau
  threshold_mode: rel # plateau
  cooldown: 0 # plateau
  min_lr: 0 # plateau
  eps: 1.0e-8 # plateau
  last_epoch: -1
# The optimizer used for client local training.
optimizer:
  name: sgd # [sgd, adam, adamw, rmsprop, adagrad]
  lr: 0.01
  dampening: 0 # for SGD
  weight_decay: 0
  momentum: 0 # for [SGD, RMSprop]
  alpha: 0.99 # for RMSprop
  nesterov: false # for SGD
  betas: [0.9, 0.999] # for [Adam, AdamW]
  amsgrad: false # for [Adam, AdamW]
mode: parallel # [serial, parallel]
# It's fine to keep these configs; if mode is 'serial', they are ignored.
parallel:
  # The IP address of the selected Ray cluster.
  # Defaults to null, which means that if there is no existing Ray cluster,
  # Ray will create a new cluster at the beginning of the experiment
  # and destroy it at the end.
  # More details can be found in https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html.
  ray_cluster_addr: null # [null, auto, local]
  # The amount of computational resources allocated to your Ray cluster.
  # Defaults to null for all.
  num_cpus: null
  num_gpus: null
  # Should be set larger than 1, otherwise training mode falls back to `serial`.
  # Setting a larger `num_workers` can further boost efficiency,
  # but also leaves each worker with fewer computational resources.
  num_workers: 2
common:
  seed: 42 # Random seed of the run.
  join_ratio: 1 # Ratio of (clients per round) / (total number of clients).
  global_epoch: 100 # Number of global epochs, also called communication rounds.
  local_epoch: 5 # Number of epochs of client local training.
  finetune_epoch: 0 # Number of epochs clients fine-tune their models before testing.
  batch_size: 32 # Data batch size for client local training.
  test_interval: 100 # Interval (in rounds) between tests on all test clients.
  # The ratio of stragglers (in [0, 1]).
  # Stragglers do not perform full-epoch local training like normal clients.
  # Their local epoch is randomly selected from the range [straggler_min_local_epoch, local_epoch).
  straggler_ratio: 0
  straggler_min_local_epoch: 0
  # How to deal with the parameter buffers (in model.buffers()) of each client model.
  # global (default): buffers are aggregated like other model parameters.
  # local: clients' buffers are isolated.
  # drop: clients drop their buffers after training is done.
  buffers: global # [local, global, drop]
  # Set eval_<...> to true to evaluate on the <...> sets held by
  # this round's joined clients before and after local training.
  eval_test: true
  eval_val: false
  eval_train: false
  verbose_gap: 10 # Interval (in rounds) for displaying clients' training performance in the terminal.
  visible: tensorboard # [null, visdom, tensorboard]
  use_cuda: true # Whether to use CUDA for training.
  save_log: true # Whether to save log files in out/<method>/<start_time>.
  save_model: false # Whether to save model weights (*.pt) in out/<method>/<start_time>.
  save_fig: true # Whether to save the learning curve figure (*.png) in out/<method>/<start_time>.
  save_metrics: true # Whether to save metrics (*.csv) in out/<method>/<start_time>.
  # Whether to delete output files after the user presses `Ctrl + C`,
  # which indicates that the run is removable.
  delete_useless_run: true
# You can also set specific arguments for advanced FL methods.
# FL-bench accesses FL method arguments via args.<method>.<arg>.
# Follow the keys set in `get_hyperparams()` in class <method>Server, src/server/<method>.py.
# FL-bench ignores these arguments if the selected method does not support them,
# e.g., if you are running FedProx, the pfedsim arguments are ignored.
fedprox:
  mu: 0.01
pfedsim:
  warmup_round: 0.5
fedap:
  version: f
As for why afterLocalTraining performs worse, more tests are needed... Maybe it's the buffers?
Or could you modify the Flower run's code to also plot the afterLocalTraining curve?
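The `buffers` option decides whether BatchNorm running statistics (`running_mean`, `running_var`, `num_batches_tracked`) are averaged like weights or kept per client.

Below is a minimal sketch of weighted FedAvg over plain-Python state dicts. It assumes sample-count weights and models only the `global` and `local` modes. `fedavg`, the key names and the numbers are hypothetical. In real PyTorch code, `num_batches_tracked` is an integer tensor and would need a cast before averaging.

```python
def fedavg(client_states, weights, buffer_keys, buffers="global"):
    """Weighted average of client state dicts (name -> list of floats).

    buffers="global": buffers such as BatchNorm running stats are averaged like parameters.
    buffers="local":  buffers are left out of the aggregate and stay on each client.
    """
    if buffers not in ("global", "local"):
        raise ValueError(f"mode {buffers!r} is not modeled in this sketch")
    total = float(sum(weights))
    keys = [k for k in client_states[0] if buffers == "global" or k not in buffer_keys]
    return {
        k: [
            sum(w * state[k][i] for state, w in zip(client_states, weights)) / total
            for i in range(len(client_states[0][k]))
        ]
        for k in keys
    }


client_a = {"conv.weight": [1.0, 2.0], "bn.running_mean": [0.0]}
client_b = {"conv.weight": [3.0, 4.0], "bn.running_mean": [1.0]}
sample_counts = [1, 3]  # hypothetical per-client dataset sizes
bn_buffers = {"bn.running_mean"}

print(fedavg([client_a, client_b], sample_counts, bn_buffers, "global"))
# {'conv.weight': [2.5, 3.5], 'bn.running_mean': [0.75]}
print(fedavg([client_a, client_b], sample_counts, bn_buffers, "local"))
# {'conv.weight': [2.5, 3.5]}
```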
Ohhh that looks much better!! But hmmm, do you have an explanation for the fix? Did the other classifier layer not get registered or something? I'll plot the afterLocalTraining as well tomorrow.
Maybe the key is model parameter initialization, since I didn't change any other code and the training procedure remains the same. I'll do more tests on it.
I just tried to reproduce this fix with Dir(100.0) instead of IID 1.0. IID seems to work fine, but Dir(100.0) behaves the same as before. So is this perhaps a data/partitioning problem?
Ah yes, I just tested this, and the old version of the model performs just as well with IID 1. I think that this is a heavy pointer into the direction that the Dirichlet sampler or one of the steps afterwards has a bug. Maybe it makes sense to export the sampling from the Flower partitioner and import it into FL-Bench?
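A quick way to check whether a partition (IID or Dir(α)) looks as expected is to compare each client's label proportions against the global ones. With a large α such as 100, the gap should be small, close to the IID case.

Here is a minimal sketch. It assumes `data_indices` is a list of per-client index lists (the structure of `partition["data_indices"]`) and `labels` is the full label sequence. `label_skew` is a hypothetical helper.

```python
from collections import Counter


def label_skew(data_indices, labels):
    """Per client: the largest absolute gap between its label proportions
    and the global label proportions (0.0 means perfectly IID labels)."""
    n = len(labels)
    global_p = {c: cnt / n for c, cnt in Counter(labels).items()}
    skews = []
    for indices in data_indices:
        counts = Counter(labels[i] for i in indices)
        size = max(len(indices), 1)
        skews.append(max(abs(counts.get(c, 0) / size - p) for c, p in global_p.items()))
    return skews


labels = [0, 1] * 50  # sample i has label i % 2
iid_like = [list(range(0, 50)), list(range(50, 100))]  # both labels on each client
skewed = [list(range(0, 100, 2)), list(range(1, 100, 2))]  # one label per client
print(label_skew(iid_like, labels))  # [0.0, 0.0]
print(label_skew(skewed, labels))    # [0.5, 0.5]
```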
Run that test? Sure, but some format conversion may be needed.
I think that this is a heavy pointer into the direction that the Dirichlet sampler or one of the steps afterwards has a bug.
I'll check it asap.
Yeah, I actually found some bugs in the Dirichlet partitioning method (the indentation of the min_size code is wrong...).
I followed the Flower implementation and changed the key loop to:
import numpy as np

# `partition`, `client_num`, `label_set`, `indices_4_labels`, `alpha` and
# `least_samples` come from the enclosing partitioning function;
# `min_size` is assumed to start below `least_samples` (e.g. 0).
min_size = 0
while min_size < least_samples:
    # Initialize data indices for each client
    partition["data_indices"] = [[] for _ in range(client_num)]
    # Iterate over each label in the label set
    for label in label_set:
        # Shuffle the indices associated with the current label
        np.random.shuffle(indices_4_labels[label])
        # Generate a Dirichlet distribution for splitting data among clients
        distribution = np.random.dirichlet(np.repeat(alpha, client_num))
        # Calculate split indices based on the generated distribution
        cumulative_indices = np.cumsum(distribution) * len(indices_4_labels[label])
        split_indices_position = cumulative_indices.astype(int)[:-1]
        # Split the indices for the current label
        split_indices = np.split(indices_4_labels[label], split_indices_position)
        # Assign split indices to each client
        for client_id in range(client_num):
            partition["data_indices"][client_id].extend(split_indices[client_id])
    # Update the minimum size of the data across all clients
    # (only after every label has been distributed)
    min_size = min(len(idx) for idx in partition["data_indices"])
Maybe try this?
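To experiment outside FL-bench, the loop above can be wrapped into a self-contained function. This sketch makes the following assumptions:

* `labels` is a 1-D sequence of integer class labels.
* `least_samples` is the minimum number of samples each client must receive.

It replaces the global `np.random` state with a seeded `np.random.Generator`. The `max_tries` guard is an addition of this sketch (not FL-bench or Flower code) so the loop cannot run forever when `least_samples` is unreachable.

```python
import numpy as np


def dirichlet_partition(labels, client_num, alpha, least_samples, seed=42, max_tries=1000):
    """Split sample indices among clients with per-label Dir(alpha) proportions,
    retrying the whole assignment until every client holds >= least_samples."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    indices_4_labels = {label: np.where(labels == label)[0] for label in np.unique(labels)}
    for _ in range(max_tries):
        data_indices = [[] for _ in range(client_num)]
        for idx in indices_4_labels.values():
            rng.shuffle(idx)
            distribution = rng.dirichlet(np.repeat(alpha, client_num))
            # Cumulative proportions -> split points; the last point (== len(idx)) is dropped.
            split_positions = (np.cumsum(distribution) * len(idx)).astype(int)[:-1]
            for client_id, chunk in enumerate(np.split(idx, split_positions)):
                data_indices[client_id].extend(chunk.tolist())
        # Check the size constraint only after every label has been distributed.
        if min(len(ci) for ci in data_indices) >= least_samples:
            return data_indices
    raise RuntimeError("could not satisfy least_samples; lower it or raise alpha")


labels = [i % 10 for i in range(1000)]
parts = dirichlet_partition(labels, client_num=5, alpha=100.0, least_samples=10)
print([len(p) for p in parts])  # roughly 200 each for a large alpha
```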
Fixed by commit e6e0507965107d1acfc4854e5c74a326ac5afb3d
Did it fix the behavior for you? I did a quick test yesterday, but it didn't work any better.
It hasn't been tested yet, but the partitioner code now does almost the same thing as Flower's.
Just for reference, this is the same config as above with the new Dirichlet sampler:
Have you tested FL-bench's FedAvg (and res18) with Flower's Dirichlet partitioner output?
If not, manually replacing the partition.pkl file is the fastest way.
Not yet, but I wrote a script that extracts the indices from a Flower partitioner and saves them to a valid partition.pkl file. I'll most probably test it tomorrow.
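Only the `data_indices` structure appears in this thread. A cautious way to swap in Flower's partition is therefore to load an existing partition.pkl, replace that one entry, and keep everything else as FL-bench generated it. This sketch makes the following assumptions:

* `data_indices` is a per-client list of index lists, as in the loop above.
* FL-bench reads nothing else that is derived from those indices. If it stores further per-client splits, those would need regenerating too.

`patch_partition`, the file names and `other_key` are hypothetical.

```python
import os
import pickle
import tempfile


def patch_partition(src_path, dst_path, data_indices):
    """Copy a partition.pkl, replacing only its "data_indices" entry."""
    with open(src_path, "rb") as f:
        partition = pickle.load(f)
    # Assumption: "data_indices" holds one list of sample indices per client;
    # every other key is carried over untouched.
    if len(partition["data_indices"]) != len(data_indices):
        raise ValueError("number of clients differs from the existing partition")
    partition["data_indices"] = [[int(i) for i in idx] for idx in data_indices]
    with open(dst_path, "wb") as f:
        pickle.dump(partition, f)


# Usage with throwaway files standing in for FL-bench's and Flower's outputs.
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "partition.pkl")
    dst = os.path.join(tmp, "partition_flower.pkl")
    with open(src, "wb") as f:
        pickle.dump({"data_indices": [[0, 1], [2, 3]], "other_key": "kept"}, f)
    flower_indices = [[3, 0], [1, 2]]  # e.g. indices exported from a Flower partitioner
    patch_partition(src, dst, flower_indices)
    with open(dst, "rb") as f:
        patched = pickle.load(f)
print(patched)
```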
Describe the bug
Currently, FedAvg does not perform well, even in homogeneous settings, for models with batch normalization layers such as resnet18. The accuracy-over-epochs curve is highly erratic and does not reach the accuracy of comparable implementations such as Flower. Since almost all other methods in FL-Bench are derived from the FedAvg code, this may affect them as well.
Current main code without optimizer resetting (grey) and with optimizer resetting (blue):
Flower implementation with the same hyperparameters (code for the Flower implementation: https://github.com/wittenator/flower/tree/rework_fedprox_baseline/baselines/fedprox | it implements FedProx, but is equivalent to FedAvg for mu=0):
To Reproduce
Commands used:
Config used:
Expected behavior
For IID data, the accuracy-over-epochs curve should be a smooth, strictly increasing function that approaches a limit of around 70-80% accuracy for resnet18 on cifar10.
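One way to compare runs against this expectation numerically, rather than by eye, is to count round-to-round accuracy drops. Here is a minimal sketch. `curve_drops` and the tolerance are hypothetical and the accuracy lists are made up. Real values would come from the per-round test accuracies in the saved metrics; their CSV layout is not assumed here.

```python
def curve_drops(accuracies, tolerance=0.0):
    """Count round-to-round accuracy drops larger than `tolerance` and report the largest."""
    drops = [prev - cur for prev, cur in zip(accuracies, accuracies[1:]) if prev - cur > tolerance]
    return len(drops), max(drops, default=0.0)


print(curve_drops([10.0, 25.0, 40.0, 52.0]))            # (0, 0.0)
print(curve_drops([10.0, 30.0, 25.0, 40.0, 20.0], 1.0))  # (2, 20.0)
```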