litian96 / FedProx

Federated Optimization in Heterogeneous Networks (MLSys '20)
MIT License

Should the global model replace the client model? #25

Closed · wnma3mz closed this issue 3 years ago

wnma3mz commented 3 years ago

Hi, I read your paper and code, and this work has greatly inspired my own research on federated learning optimization. I am trying to reproduce FedProx in PyTorch and am confused about a small detail. In the algorithm in the paper, the local client model does not seem to be replaced, i.e., $w_k^t = w^t$:

[screenshot of the FedProx algorithm from the paper]

But when I read your code, I found that there actually is a replace operation:

https://github.com/litian96/FedProx/blob/d2a4501f319f1594b732d88315c5ca1a72855f50/flearn/trainers/fedprox.py#L77-L78

I also found a similar operation in a PyTorch replication repo:

FedMA

https://github.com/IBM/FedMA/blob/4b586a5a22002dc955d025b890bc632daa3c01c7/main.py#L863-L883

Q1: Should the aggregated global model replace the local client model after aggregation?

Q2: If there is no replacement, the proximal term can be interpreted as making the local model $w_k^t$ approximate the global model $w^t$. From another point of view, could this be seen as alleviating the catastrophic forgetting problem?
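(For reference, the local subproblem each device solves in the paper is

$$h_k(w; w^t) = F_k(w) + \frac{\mu}{2}\lVert w - w^t \rVert^2,$$

so the proximal term penalizes how far the local model drifts from the current global model $w^t$.)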

If I have misunderstood something, please let me know. I look forward to hearing from you.

litian96 commented 3 years ago

Q1. $w^t$ is the global model we would like to learn (and output). In theory, for FedProx, each device does not necessarily need to start solving $h_k$ from $w^t$; all we need is that $w_k^{t+1}$ is a $\gamma_k^t$-inexact minimizer of $h_k$. FedProx allows any local solver to be used for $h_k$. In practice (in the code), we let each device start from the current global model $w^t$ and run multiple local iterations of mini-batch SGD (one way, and a natural way, of obtaining an inexact solution). This recovers the FedAvg algorithm when $\mu = 0$, and recovers distributed SGD when the number of local iterations is 1.
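As a rough illustration of this local update in PyTorch (since the original code is in TensorFlow), a minimal sketch might look like the following; the function and argument names here are illustrative, not the repository's API:

```python
import copy
import torch

def local_update(global_model, loader, loss_fn, mu=0.01, lr=0.01, epochs=5):
    # Start solving h_k from the current global model w^t (the "replace" step).
    model = copy.deepcopy(global_model)
    global_params = [p.detach().clone() for p in global_model.parameters()]
    opt = torch.optim.SGD(model.parameters(), lr=lr)

    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss = loss_fn(model(x), y)
            # Proximal term (mu/2) * ||w - w^t||^2; with mu = 0 this reduces
            # to FedAvg's local step.
            prox = sum((p - g).pow(2).sum()
                       for p, g in zip(model.parameters(), global_params))
            (loss + 0.5 * mu * prox).backward()
            opt.step()

    # The returned weights are an inexact minimizer of h_k, sent back for aggregation.
    return model
```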

Q2. I am not sure what catastrophic forgetting means in federated settings. Usually, people consider it in continual learning, where a sequence of tasks is learned. In federated learning, by contrast, we learn from multiple devices/tasks simultaneously to produce a single global model or personalized models. But I think the proximal term (or variants of it) could be useful in other problems as well.

wnma3mz commented 3 years ago

Thank you for your detailed reply; I have gained a lot from it. I look forward to your future work.