yuetan031 / FedProto

[AAAI'22] FedProto: Federated Prototype Learning across Heterogeneous Clients

About the CIFAR-10 dataset #3

Open Tam-JQK opened 2 years ago

Tam-JQK commented 2 years ago

This is an excellent piece of work, and I like it! However, when I ran

    python federated_main.py --mode task_heter --dataset cifar10 --num_classes 10 --num_users 20 --ways 5 --shots 100 --stdev 2 --rounds 110 --train_shots_max 110 --ld 0.1

the results on the CIFAR-10 dataset were far from those reported in the paper. I think my parameter settings are wrong; I read the paper carefully but still could not reproduce the experimental results. Could you share your parameter settings? Best regards.

yuetan031 commented 2 years ago

Hi,

As described in the experimental section of the paper, a ResNet18 pre-trained on ImageNet is used as the initial model: “For CIFAR10, ResNet18 pre-trained on ImageNet (Krizhevsky, Sutskever, and Hinton 2017) is used as the initial model. The initial average test accuracy of the pre-trained network on CIFAR10 is 27.55%.”

Did you follow this setting when reproducing the results?

973891422 commented 2 years ago

Hi, I followed the settings in your code:

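            # Build the repo's ResNet18 without loading any weights
            resnet = resnet18(args, pretrained=False, num_classes=args.num_classes)
            # model_urls['resnet18'] points to the ImageNet-pretrained ResNet18 checkpoint
            initial_weight = model_zoo.load_url(model_urls['resnet18'])
            local_model = resnet
            initial_weight_1 = local_model.state_dict()
            # Keep the randomly initialised fc, conv1 and bn1 parameters;
            # every other layer takes the ImageNet-pretrained weights
            for key in initial_weight.keys():
                if key.startswith(('fc.', 'conv1', 'bn1')):
                    initial_weight[key] = initial_weight_1[key]

            local_model.load_state_dict(initial_weight)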

But the results are still not good, and I don't know where the problem is. Your work is outstanding, and I look forward to your guidance.

yuetan031 commented 2 years ago

Thanks for your attention. Have you checked the initial average test accuracy before training? If it is not around 27%, please double-check how the data is split.
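A minimal, framework-agnostic sketch of that check, assuming each client has its own list of (input, label) test pairs and that `predict` is any function mapping an input to a label. Both `predict` and `initial_average_accuracy` are hypothetical placeholders, not the repo's API. It reports the mean and standard deviation of per-client accuracy, mirroring the "mean of test acc / std of test acc" summary used in this thread:

```python
import statistics

def initial_average_accuracy(predict, client_test_sets):
    """Mean and (population) std of per-client test accuracy, to be computed
    before any federated training has happened."""
    accs = []
    for samples in client_test_sets:
        correct = sum(1 for x, y in samples if predict(x) == y)
        accs.append(correct / len(samples))
    # np.std defaults to the population std, so pstdev matches it
    return statistics.fmean(accs), statistics.pstdev(accs)

# Toy usage: a "model" that always predicts class 0, evaluated on two clients
clients = [
    [(None, 0), (None, 1), (None, 0), (None, 2)],  # 2 of 4 correct
    [(None, 1), (None, 1), (None, 0), (None, 3)],  # 1 of 4 correct
]
mean_acc, std_acc = initial_average_accuracy(lambda x: 0, clients)
print(f"mean of test acc is {mean_acc:.5f}, std of test acc is {std_acc:.5f}")
```

If the mean is far from the roughly 27% reported in the paper, the pre-trained weights and the data split are the first things to check.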

973891422 commented 2 years ago

Thanks for your reply, I will try it. I have another question: when ResNet18 is transferred to CIFAR10, the final fc layer is changed to output 10 dimensions. Can a randomly initialized fc layer really reach an accuracy of about 27%?

973891422 commented 2 years ago

I ran it: the initial test accuracy of the pre-trained ResNet18 on CIFAR10 is only about 5%, using exactly the data split and pre-trained model loading provided in your code. I can't figure out where the problem is. Could you give me some help? Thanks.

yuetan031 commented 2 years ago

It seems strange, because other people have successfully reproduced the results reported in the paper. Please make sure the hyperparameters and model architecture are the same as in the paper. Furthermore, the output dimension should be set to 10, because there are 10 classes in the CIFAR10 dataset.

eepLearning commented 2 years ago

Hi, your excellent work gives us several things to think about.

I have the same problem: I cannot reproduce the test accuracy on CIFAR-10, and this confuses me.

One thing I'm confused about is that the code outputs the test accuracy in two cases ("without proto" and "with proto"). Which of the two is the accuracy in the paper's performance table?

I cloned your code as-is, so the code I run cannot differ from the settings in the repository.

Regarding the "data splitting manner" you mentioned above: could that problem occur even though I haven't modified anything at all?

Also, judging from the output training log, the data split looks fine.

yuetan031 commented 2 years ago

Hi,

Regarding the first question, have you checked whether the initial test accuracy on CIFAR10 is about 27%? The pre-trained model weights should be downloaded and loaded properly before federated training.

Regarding the second question, the test accuracy can be obtained in two ways. We perform inference using the prototypes, so "with proto" is the number reported in the paper.

yuetan031 commented 2 years ago

I have rerun the code of this project, and the issue you reported does seem to exist. Could you please increase the batch size (--local_bs) from 4 to 32 and check whether the initial accuracy is around 27%?

eepLearning commented 2 years ago

Thank you so much for the quick and kind reply. As you suggested, changing batch_size to 32 roughly doubled the performance.

The final accuracy increased from 30% to 60%, and the initial accuracy increased from 5% to 12%.

First of all, I am amazed that batch_size can have such a big effect. Is this a data-splitting problem, or a training problem in the model?

The initial accuracy was measured using the global prototypes obtained in the first round. (Is the initial accuracy you mentioned the inference performance of the pre-trained network without prototypes?)

Thanks again for your reply.

eepLearning commented 2 years ago

I added a simple check of the initial accuracy as follows:

    train_loss, train_accuracy = [], []

    for round in tqdm(range(args.rounds)):
        local_weights, local_losses, local_protos = [], [], {}
        print(f'\n | Global Training Round : {round + 1} |\n')

        proto_loss = 0
        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx])
            w, loss, acc, protos = local_model.update_weights_het(args, idx, global_protos, model=copy.deepcopy(local_model_list[idx]), global_round=round)
            agg_protos = agg_func(protos)

            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss['total']))

            local_protos[idx] = agg_protos
            summary_writer.add_scalar('Train/Loss/user' + str(idx + 1), loss['total'], round)
            summary_writer.add_scalar('Train/Loss1/user' + str(idx + 1), loss['1'], round)
            summary_writer.add_scalar('Train/Loss2/user' + str(idx + 1), loss['2'], round)
            summary_writer.add_scalar('Train/Acc/user' + str(idx + 1), acc, round)
            proto_loss += loss['2']

        # update global weights
        local_weights_list = local_weights
        global_protos = proto_aggregation(local_protos)
        if round == 0:
            # local_model_list still holds the initial weights here; they are only
            # replaced by the round-1 weights in the loop below. global_protos,
            # however, already come from the first round of local training.
            print("Initial Accuracy")
            acc_list_l, acc_list_g = test_inference_new_het_lt(args, local_model_list, test_dataset, classes_list,
                                                               user_groups_lt, global_protos)
            print('For all users (with protos), mean of test acc is {:.5f}, std of test acc is {:.5f}'.format(
                np.mean(acc_list_g), np.std(acc_list_g)))
            print('For all users (w/o protos), mean of test acc is {:.5f}, std of test acc is {:.5f}'.format(
                np.mean(acc_list_l), np.std(acc_list_l)))
        for idx in idxs_users:
            local_model = copy.deepcopy(local_model_list[idx])
            local_model.load_state_dict(local_weights_list[idx], strict=True)
            local_model_list[idx] = local_model

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

yuetan031 commented 2 years ago

I think batch_size has such a large effect because the local dataset on each client is relatively small (only n classes with k samples per class). How exactly it affects the model may deserve further exploration.

Also, thanks for your modification of the code. I will follow your suggestion and update this project :)

TsingZ0 commented 2 years ago

I guess the reason may be that the Batch Normalization layers in ResNet-18 need a large batch size to estimate the batch statistics (μ and σ) accurately.
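One way to see the effect described above: in training mode, Batch Normalization normalizes each batch with that batch's own mean and variance, and those estimates are noisy for small batches. The stdlib-only toy below assumes i.i.d. Gaussian activations (not real ResNet features) and measures how much the per-batch mean and standard deviation fluctuate for batch sizes 4 and 32. The spread of the batch mean should be close to σ/√B, i.e. about 0.5 for B = 4 versus about 0.18 for B = 32 when σ = 1.

```python
import random
import statistics

def batch_stat_spread(batch_size, num_batches=2000, mu=0.0, sigma=1.0, seed=0):
    """Draw num_batches batches of i.i.d. N(mu, sigma^2) 'activations' and return
    how much the per-batch mean and per-batch std fluctuate across batches."""
    rng = random.Random(seed)
    batch_means, batch_stds = [], []
    for _ in range(num_batches):
        batch = [rng.gauss(mu, sigma) for _ in range(batch_size)]
        batch_means.append(statistics.fmean(batch))
        # BN normalizes with the biased (population) variance during training
        batch_stds.append(statistics.pstdev(batch))
    return statistics.stdev(batch_means), statistics.stdev(batch_stds)

for bs in (4, 32):
    mean_spread, std_spread = batch_stat_spread(bs)
    print(f"batch_size={bs:>2}: spread of batch mean={mean_spread:.3f}, "
          f"spread of batch std={std_spread:.3f}")
```

With batch size 4, every training step normalizes with much noisier statistics, which is one plausible contributor to the gap reported above; this toy does not show that it is the only cause.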

TsingZ0 commented 2 years ago

Regarding your answer that "with proto" is the test accuracy reported in the paper: I cannot find this test metric in the paper. Did I miss something?

yuetan031 commented 2 years ago

Have you checked the test_inference_new_het_lt function in ./lib/update.py? It returns acc_list_g and acc_list_l, which contain the accuracies obtained with and without prototypes respectively; these correspond to the two ways mentioned above.
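To make the two numbers concrete, here is a minimal sketch of two prediction rules, assuming "with protos" means assigning each test embedding to the nearest global prototype by squared Euclidean distance, and "w/o protos" means taking the argmax of the classifier logits. This illustrates those assumptions only; it is not the code of test_inference_new_het_lt, and the function names are hypothetical.

```python
import math

def predict_with_protos(embedding, global_protos):
    """'With protos': return the label whose global prototype is nearest to the
    embedding (squared Euclidean distance)."""
    best_label, best_dist = None, math.inf
    for label, proto in global_protos.items():
        dist = sum((e - p) ** 2 for e, p in zip(embedding, proto))
        if dist < best_dist:
            best_label, best_dist = label, dist
    return best_label

def predict_without_protos(logits):
    """'W/o protos': return the index of the largest classifier logit."""
    return max(range(len(logits)), key=lambda i: logits[i])

# Toy example: 2-D embeddings and prototypes for classes 3 and 7
global_protos = {3: [1.0, 0.0], 7: [0.0, 1.0]}
print(predict_with_protos([0.9, 0.2], global_protos))  # nearest prototype: class 3
print(predict_without_protos([0.1, 2.5, -0.3]))        # largest logit at index 1
```

Because the two rules use different parts of the model (the feature extractor plus prototypes versus the classifier head), they can disagree on the same input, which is why the code reports both accuracies.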

TsingZ0 commented 2 years ago

Thanks a lot! I have found it.

CityChan commented 2 years ago

I have the same problem: I cannot reproduce the CIFAR10 results. Have you solved it?

CityChan commented 2 years ago

I have the same problem as 973891422. Did you solve it? I don't think it is plausible that a model with a randomly initialized fc layer can reach 27%.

chenrxi commented 2 years ago

It seems that nobody can reproduce the CIFAR10 results?

whatement commented 2 years ago

Have you reproduced the CIFAR10 results successfully? I repeated the experiment several times on CIFAR10, and FedProto is not significantly better than FedAvg.

chenrxi commented 2 years ago

I tried it on CIFAR100 with my baseline but only got a 1% gain.

ZhangXG001 commented 1 year ago

When using CIFAR10 or other datasets, I do not understand why the test average accuracy for n = 3 is better than for n = 5 in the paper. n controls the heterogeneity, so I would expect that the smaller n is, the greater the heterogeneity and the worse the test average accuracy.

liuyuqinggg commented 1 year ago

I think the reason is that when n = 3, each model only needs to learn 3 labels, so the accuracy is higher, because there is no model exchange between clients. You can imagine that with n = 1 the accuracy would be even higher than with n = 3.

keil555 commented 11 months ago

TsingZ0's suggestion that the Batch Normalization layers in ResNet-18 need a large batch size to estimate the statistics (μ and σ) accurately may play a role. But I think this phenomenon is also due to the local update procedure. The following code is part of update_weights_het:

    for iter in range(self.args.train_ep):
        batch_loss = {'total': [], '1': [], '2': [], '3': []}
        agg_protos_label = {}
        for batch_idx, (images, label_g) in enumerate(self.trainloader):
            images, labels = images.to(self.device), label_g.to(self.device)
            # loss1: cross-entropy loss, loss2: proto distance loss
            model.zero_grad()
            log_probs, protos = model(images)
            # ... computing loss1 & loss2 (omitted in this excerpt) ...
            loss = loss1 + loss2 * args.ld
            loss.backward()
            optimizer.step()
            # Record this batch's prototypes per label; they were computed
            # with the weights from before the optimizer step above
            for i in range(len(labels)):
                if label_g[i].item() in agg_protos_label:
                    agg_protos_label[label_g[i].item()].append(protos[i, :])
                else:
                    agg_protos_label[label_g[i].item()] = [protos[i, :]]

After every batch, the local model is updated and the prototypes computed from that batch are appended to the list. Because the model changes after each step, prototypes from different batches are produced by slightly different models. Therefore, with a small batch size (such as 4), the local model is updated more often and the collected prototypes are more diverse. This may affect the algorithm's performance.
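The mechanism can be illustrated with a deliberately simplified toy in which all numbers are hypothetical: one class, a one-dimensional "feature extractor" f_w(x) = w·x with x = 1 for every sample, and each optimizer step modelled as a fixed increase of w. As in the loop above, features are recorded with the current weights and the weights change after every batch, so a smaller batch size means more updates per epoch and a wider spread of recorded features.

```python
import statistics

def collect_protos_during_training(num_samples, batch_size, w0=1.0, step=0.05):
    """Toy version of the loop above for a single class.

    Hypothetical assumptions: the feature of every sample is w * 1.0, and each
    optimizer step increases w by a constant `step`.
    """
    w = w0
    collected = []
    for start in range(0, num_samples, batch_size):
        n = min(batch_size, num_samples - start)
        collected.extend([w * 1.0] * n)  # protos[i, :] appended for this batch
        w += step                        # optimizer.step() after the batch
    averaged_proto = statistics.fmean(collected)  # per-label average of the list
    final_model_proto = w * 1.0                   # feature under the final weights
    return statistics.pstdev(collected), averaged_proto, final_model_proto

for bs in (4, 32):
    spread, averaged, final = collect_protos_during_training(100, bs)
    print(f"batch_size={bs:>2}: spread={spread:.3f}, "
          f"averaged proto={averaged:.3f}, final-model proto={final:.3f}")
```

In this toy, batch size 4 gives 25 updates per epoch versus 4 for batch size 32, so the averaged prototype lags further behind the final local model. Real drift is not constant per step; the sketch only shows the direction of the effect.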

MichaelLee-ceo commented 11 months ago

Did anyone successfully reproduce the CIFAR-10 results? I cloned the project and executed the command shown in the README:

    python federated_main.py --mode task_heter --dataset cifar10 --num_classes 10 --num_users 20 --ways 5 --shots 100 --stdev 2 --rounds 110 --train_shots_max 110 --ld 0.1 --local_bs 32

and I only get about 55% test accuracy.

keil555 commented 11 months ago

I got a similar result to MichaelLee-ceo (about 55%).

wangjie-123 commented 4 months ago

It seems that nobody can reproduce the CIFAR10 results?

chenzhigang00 commented 3 months ago

Modifying the hyperparameters might raise the test accuracy. For example, the following command reaches 66%. It changes the regularization weight λ (--ld) from 0.1 to 1, and also uses --ways 3 instead of 5:

    python federated_main.py --mode task_heter --dataset cifar10 --num_classes 10 --num_users 20 --ways 3 --shots 100 --stdev 2 --rounds 110 --train_shots_max 110 --ld 1 --local_bs 32
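For reference, the update_weights_het excerpt quoted in this thread computes the local objective as `loss = loss1 + loss2 * args.ld`, with loss1 the cross-entropy loss and loss2 the prototype distance loss, so --ld is the weight λ on the prototype term:

```latex
\mathcal{L} = \mathcal{L}_{1} + \lambda \, \mathcal{L}_{2}, \qquad \lambda = \texttt{ld}
```

Raising --ld from 0.1 to 1 therefore puts ten times more weight on pulling local features toward the global prototypes, relative to the cross-entropy term.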

sammens commented 1 week ago

It might be a little late, but I got a test accuracy of 74.14% on CIFAR10 with the following command:

    python federated_main.py --mode task_heter --dataset cifar10 --num_classes 10 --local_bs 32 --num_channels 3 --ld 0.1 --rounds 110 --ways 3 --num_users 20 --shot 100 --stdev 1 --lr 0.01

Remember to set the number of channels to 3 (--num_channels 3), since CIFAR10 images have three input channels (RGB).
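The command above passes `--shot` rather than `--shots`. Assuming federated_main.py parses its options with Python's argparse (an assumption; the parser below is a hypothetical stand-in that declares only a few flags seen in this thread, with placeholder defaults), this still works, because argparse accepts any unambiguous prefix of a long option by default (`allow_abbrev=True`), as long as no other option also starts with `--shot`:

```python
import argparse

# Hypothetical stand-in for the repo's parser; defaults are placeholders
parser = argparse.ArgumentParser()
parser.add_argument('--shots', type=int, default=100)
parser.add_argument('--train_shots_max', type=int, default=110)
parser.add_argument('--stdev', type=int, default=2)
parser.add_argument('--num_channels', type=int, default=1)

# '--shot' is an unambiguous prefix of '--shots', so argparse expands it
args = parser.parse_args(['--shot', '100', '--stdev', '1', '--num_channels', '3'])
print(args.shots, args.stdev, args.num_channels)  # prints: 100 1 3
```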