TsingZ0 / PFLlib

37 traditional FL (tFL) or personalized FL (pFL) algorithms, 3 scenarios, and 20 datasets.
GNU General Public License v2.0

A code problem in the SCAFFOLD implementation #160

Closed YrChenM closed 7 months ago

YrChenM commented 8 months ago

While actually running it, I found that the loss explodes after a few rounds and cannot converge, so there is probably a bug in the code. The two functions should be modified to

def update_yc(self):
    for ci, c, x, yi in zip(self.client_c, self.global_c, self.global_model.parameters(), self.model.parameters()):
        ci.data = ci - c + 1 / self.num_batches / self.local_epochs / self.learning_rate * (x - yi)

and

def delta_yc(self):
    delta_y = []
    delta_c = []
    for c, x, yi in zip(self.global_c, self.global_model.parameters(), self.model.parameters()):
        delta_y.append(yi - x)
        delta_c.append(- c + 1 / self.num_batches / self.local_epochs / self.learning_rate * (x - yi))
    return delta_y, delta_c

The key change is the update of delta_c: according to the original paper, the denominator should be the local learning rate times the number of local update steps, and that number is not num_batches (which only counts the batch iterations within a single local epoch) but num_batches * local_epochs. After this modification, the loss no longer seems to explode.
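For reference, the corresponding control-variate update in the SCAFFOLD paper (Karimireddy et al., 2020, Option II) is

    c_i^{+} = c_i - c + \frac{1}{K \eta_l} (x - y_i), \qquad K = \text{num\_batches} \times \text{local\_epochs}

where \eta_l is the local learning rate and K counts every local SGD step. Dividing by num_batches alone scales the correction term up by a factor of local_epochs, which is consistent with the divergence observed above.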

Also, regarding parameter aggregation: the original paper considers the setting where clients carry no weights, so client parameters are aggregated with equal weight. In practice I found that a weighted aggregation also works, and it fits the case where clients hold different numbers of samples, so the method can be modified to

def aggregate_parameters(self):
    # original version

    # for dy, dc in zip(self.delta_ys, self.delta_cs):
    #     for server_param, client_param in zip(self.global_model.parameters(), dy):
    #         server_param.data += client_param.data.clone() / self.num_join_clients * self.server_learning_rate
    #     for server_param, client_param in zip(self.global_c, dc):
    #         server_param.data += client_param.data.clone() / self.num_clients

    # save GPU memory
    global_model = copy.deepcopy(self.global_model)
    global_c = copy.deepcopy(self.global_c)
    for cid, w in zip(self.uploaded_ids, self.uploaded_weights):
        dy, dc = self.clients[cid].delta_yc()
        for server_param, client_param in zip(global_model.parameters(), dy):
            server_param.data += client_param.data.clone() * w * self.server_learning_rate
        for server_param, client_param in zip(global_c, dc):
            server_param.data += client_param.data.clone() * w
    self.global_model = global_model
    self.global_c = global_c
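
For completeness, the weighted variant assumes self.uploaded_weights sums to 1. A minimal sketch of how such weights could be formed from per-client sample counts (assumed to mirror receive_models in PFLlib's serverbase.py; treat the names here as illustrative, not a quote of the repo):

def receive_models(self):
    # collect the ids and training-set sizes of the clients that reported back
    self.uploaded_ids = []
    self.uploaded_weights = []
    tot_samples = 0
    for client in self.selected_clients:
        self.uploaded_ids.append(client.id)
        self.uploaded_weights.append(client.train_samples)
        tot_samples += client.train_samples
    # normalize so the weights sum to 1; when every client holds the same
    # number of samples this reduces to equal weights 1 / num_join_clients,
    # recovering the unweighted setting of the SCAFFOLD paper
    for i, w in enumerate(self.uploaded_weights):
        self.uploaded_weights[i] = w / tot_samples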
TsingZ0 commented 8 months ago

Regarding num_batches * local_epochs, I have made the fix in the new code. My earlier SCAFFOLD experiments used local_epochs=1 by default, so I overlooked this. Thanks for pointing it out.

As for the second point about parameter aggregation, wouldn't this weighted-aggregation modification no longer follow the algorithm in the original SCAFFOLD paper?