shaoxiongji / federated-learning

A PyTorch Implementation of Federated Learning http://doi.org/10.5281/zenodo.4321561
MIT License

The model aggregation step seems to differ from what the original FedAvg paper describes #13

Closed downing19 closed 4 years ago

downing19 commented 4 years ago

I've recently been studying federated learning with your code and came across something that confuses me. In the original paper, each global round randomly selects a fraction of all clients to perform local updates (not every client participates), but at aggregation time the paper describes averaging the models of all clients, i.e., clients that did not participate in the update are also included in the average. The aggregation step in the code, however, only averages the models of the clients that participated in the round. Is this a bug in the code, or am I misunderstanding something?

for iter in range(args.epochs):
    w_locals, loss_locals = [], []
    # sample a fraction of the clients to train in this round
    m = max(int(args.frac * args.num_users), 1)
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)
    for idx in idxs_users:
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
        w_locals.append(copy.deepcopy(w))
        loss_locals.append(copy.deepcopy(loss))
    # update global weights: only the m selected clients are averaged
    w_glob = FedAvg(w_locals)

    # copy weight to net_glob
    net_glob.load_state_dict(w_glob)
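
For comparison, here is a minimal sketch of the paper-style aggregation the question describes: every one of the `args.num_users` clients contributes a model to the average, and clients not selected this round simply carry the current global weights. The loop reuses the names from the snippet above (`w_glob`, `net_glob`, `dict_users`, `LocalUpdate`, `FedAvg`); the all-clients bookkeeping is an assumption for illustration, not the repository's original code.

# sketch (assumed): paper-style aggregation over ALL clients
w_glob = net_glob.state_dict()  # initial global weights
for iter in range(args.epochs):
    # every client starts the round holding the current global weights
    w_locals = [copy.deepcopy(w_glob) for _ in range(args.num_users)]
    loss_locals = []
    m = max(int(args.frac * args.num_users), 1)
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)
    for idx in idxs_users:
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
        # only selected clients overwrite their slot; the rest keep w_glob
        w_locals[idx] = copy.deepcopy(w)
        loss_locals.append(copy.deepcopy(loss))
    # average over all num_users models, including the untouched ones
    w_glob = FedAvg(w_locals)
    net_glob.load_state_dict(w_glob)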
shaoxiongji commented 4 years ago

Your understanding is correct; this part differs from the original paper. If you're interested, feel free to open a PR.

shaoxiongji commented 4 years ago

fixed https://github.com/shaoxiongji/federated-learning/commit/885c93e90eb23cf5a350cb3330b81304ca50a5b1
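
For context, `FedAvg` as called above reduces to an element-wise average over a list of client state dicts. A hypothetical implementation consistent with that call might look like the following; the unweighted mean is an assumption (the paper weights each client k by its sample fraction n_k/n), so treat this as a sketch rather than the repository's actual code:

import copy
import torch

def FedAvg(w):
    # element-wise average of a list of PyTorch state dicts (sketch)
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k]
        # unweighted mean; the FedAvg paper weights client k by n_k / n
        w_avg[k] = torch.div(w_avg[k], len(w))
    return w_avg

With the linked fix applied, this average presumably runs over all clients' models rather than only the selected ones, so non-participating clients contribute the unchanged global weights, matching the behavior the question attributes to the paper.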