WwZzz / easyFL

An experimental platform for federated learning.
Apache License 2.0
519 stars 88 forks source link

使用fashion_mnist数据集进行fedmgda+训练有bug #59

Open szzzhy opened 6 months ago

szzzhy commented 6 months ago

首先赞美大佬给予好用的联邦学习框架赐福!

我在用fedmgda+算法进行fashion_mnist任务训练时出现了error,具体如下:

捕获

其中fedmgda+算法是根据大佬的教程复制粘贴过去的,没有有什么改动。数据分布是每个client只有一类数据,如图:

![Uploading 捕获.PNG…]()

另外其他参数设置是 option_batch_size_10 = {'learning_rate': 0.01, 'num_steps': 1, 'num_rounds': 500, 'gpu': 1, 'batch_size': 10, 'proportion':0.1, 'seed': 0}

经过之前一系列测试,是经过标准化(gi.normalize())函数后出现了nan值,应该是标准化除以0了。

希望大佬早日修好bug,在做算法实验了所以比较急。

最后再次赞美大佬!

szzzhy commented 6 months ago

数据分布设置如下也出下了同样的error: task_config_dirichlet = {'benchmark': femnist, 'partitioner': { 'name': 'DirichletPartitioner', 'para': { 'num_clients': 100, 'alpha': 0.1}} }

大概在300多轮出现的 image

WwZzz commented 6 months ago

测试了下发现是除以0导致的。为了不影响正常训练,可以把gi.normalize()那里归一化的方式替换成以下形式

        for i in range(len(grads)):
            gi_norm = 0.0
            for p in grads[i].parameters():
                gi_norm += (p**2).sum()
            grads[i] = grads[i]/(torch.sqrt(gi_norm) + 1e-8)

image

修改后我这里在提到的第一个设置下运行500轮无报错。