ki-ljl / FedProx-PyTorch

PyTorch implementation of FedProx (Federated Optimization for Heterogeneous Networks, MLSys 2020).
MIT License

A question on the proximal term #1

Open Hanc1999 opened 2 years ago

Hanc1999 commented 2 years ago

As I remember it, Li et al.'s paper defines the proximal term as the squared L2 norm of (w - w^t), but client.py implements it as a sum of L2 norms, one per parameter tensor.

Screenshot from the paper 'Federated Optimization in Heterogeneous Networks':

[image]
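For reference, the quantities under discussion can be written out as follows. The first line is the local objective of client k as given in the FedProx paper. The split of w into L parameter tensors w_1, ..., w_L (one per layer) and the two unsquared variants are illustrative assumptions added here, not formulas taken from the paper or from this repo. The identity on the first line shows that a layer-wise implementation matches the paper exactly when each per-layer norm is squared, and differs only when it is not.

```latex
% Local objective of client k in FedProx (w^t: current global model)
h_k(w; w^t) = F_k(w) + \frac{\mu}{2} \lVert w - w^t \rVert^2,
\qquad
\lVert w - w^t \rVert^2 = \sum_{l=1}^{L} \lVert w_l - w_l^t \rVert^2

% Gradient of the proximal term w.r.t. layer l: depends on layer l only
\nabla_{w_l} \frac{\mu}{2} \lVert w - w^t \rVert^2 = \mu \, (w_l - w_l^t)

% Assumed unsquared variants (same \mu/2 kept; denominators must be nonzero)
\nabla_{w_l} \frac{\mu}{2} \sum_{j=1}^{L} \lVert w_j - w_j^t \rVert
  = \frac{\mu}{2} \, \frac{w_l - w_l^t}{\lVert w_l - w_l^t \rVert},
\qquad
\nabla_{w_l} \frac{\mu}{2} \lVert w - w^t \rVert
  = \frac{\mu}{2} \, \frac{w_l - w_l^t}{\lVert w - w^t \rVert}
```

Only the last expression, an unsquared norm taken over all layers at once, makes the gradient for layer l depend on the other layers (through its denominator).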

JonOnEarth commented 1 year ago

I think you are right; the code does not match what the paper proposed.

vaseline555 commented 1 year ago

Aren't they equivalent on the computation graph when taking derivatives? FYI, see the implementation in the Flower package (https://flower.dev/docs/apiref-flwr.html#server-strategy-fedprox). It seems to be the same as the one in this repo.

Hanc1999 commented 1 year ago


I think they are different. Say W = {W1, W2}: the derivative of norm2(W - Wt) with respect to W1 depends on W2, while the derivative of norm2(W1 - Wt1) + norm2(W2 - Wt2) does not.
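A minimal, dependency-free sketch that checks this claim numerically, with plain Python lists standing in for the two layers' tensors and central finite differences standing in for autograd. All function names are hypothetical, and both squared and unsquared norms are tested because the answer depends on which one is used.

```python
import math


def sq_norm(v):
    """Squared L2 norm of a flat list of floats."""
    return sum(x * x for x in v)


def diff(a, b):
    """Element-wise difference a - b of two flat lists."""
    return [x - y for x, y in zip(a, b)]


# Four candidate proximal terms for a model split into W = {W1, W2}.
def full_squared(w1, w2, t1, t2):
    """||W - Wt||^2 computed once over the concatenated vector."""
    return sq_norm(diff(w1 + w2, t1 + t2))


def layerwise_squared(w1, w2, t1, t2):
    """||W1 - Wt1||^2 + ||W2 - Wt2||^2."""
    return sq_norm(diff(w1, t1)) + sq_norm(diff(w2, t2))


def full_unsquared(w1, w2, t1, t2):
    """||W - Wt|| (no square) over the concatenated vector."""
    return math.sqrt(full_squared(w1, w2, t1, t2))


def layerwise_unsquared(w1, w2, t1, t2):
    """||W1 - Wt1|| + ||W2 - Wt2||."""
    return math.sqrt(sq_norm(diff(w1, t1))) + math.sqrt(sq_norm(diff(w2, t2)))


def grad_w1(f, w1, w2, t1, t2, eps=1e-6):
    """Central finite-difference gradient of f with respect to W1."""
    grad = []
    for i in range(len(w1)):
        up, down = list(w1), list(w1)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up, w2, t1, t2) - f(down, w2, t1, t2)) / (2 * eps))
    return grad


def w1_grad_depends_on_w2(f):
    """True if dF/dW1 changes when only W2 is changed."""
    w1, t1, t2 = [1.0, -2.0], [0.5, 0.5], [0.0, 0.0]
    g_a = grad_w1(f, w1, [3.0, 4.0], t1, t2)
    g_b = grad_w1(f, w1, [30.0, 40.0], t1, t2)
    return any(abs(a - b) > 1e-4 for a, b in zip(g_a, g_b))


for f in (full_squared, layerwise_squared, full_unsquared, layerwise_unsquared):
    print(f.__name__, w1_grad_depends_on_w2(f))
# prints: full_squared False, layerwise_squared False,
#         full_unsquared True, layerwise_unsquared False (one per line)
```

In this sketch only the unsquared norm of the concatenated vector (full_unsquared) makes the gradient for W1 change with W2; with squared norms, the whole-vector and layer-wise forms have identical gradients.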

vaseline555 commented 1 year ago

@Hanc1999 what do you mean by W = {W1, W2}? Does W denote the whole set of model parameters, with W1 the weight of the 1st layer and W2 the weight of the 2nd layer? I presume you are considering the case where norm2([W1; W2] - Wt) yields an interaction term, say W1*W2, right? (Correct me if I have misunderstood.)

But note that the norm-squared term is calculated layer-wise by iterating through the generator returned by parameters(), not all at once. You can see this in the code of this repository: https://github.com/ki-ljl/FedProx-PyTorch/blob/ec79be9ac3b849a36f4a1c44ed59d5569d3f4eaa/client.py#L64
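A sketch of what such a layer-wise loop computes, assuming each layer is a flat list of floats standing in for the tensors yielded by model.parameters() and global_model.parameters(). The function name proximal_term and its squared flag are hypothetical. In PyTorch the squared per-layer term could be written as ((w - w_t) ** 2).sum() or (w - w_t).norm(2) ** 2, whereas (w - w_t).norm(2) alone is the unsquared norm.

```python
import math


def proximal_term(local_params, global_params, mu, squared=True):
    """Hypothetical layer-wise proximal term.

    local_params, global_params: one flat list of floats per layer, standing in
    for zip(model.parameters(), global_model.parameters()).
    squared=True  -> (mu / 2) * sum_l ||w_l - w_l^t||^2
    squared=False -> (mu / 2) * sum_l ||w_l - w_l^t||
    """
    total = 0.0
    for w, w_t in zip(local_params, global_params):
        sq = sum((a - b) ** 2 for a, b in zip(w, w_t))  # ||w_l - w_l^t||^2
        total += sq if squared else math.sqrt(sq)
    return mu / 2 * total


local_layers = [[1.0, 2.0], [3.0]]   # two "layers" of the local model
global_layers = [[0.0, 0.0], [0.0]]  # matching layers of the global model
print(proximal_term(local_layers, global_layers, mu=1.0))                 # 7.0
print(proximal_term(local_layers, global_layers, mu=1.0, squared=False))  # (sqrt(5) + 3) / 2, about 2.618
```

With squared=True the loop gives exactly (\mu/2)||w - w^t||^2 over the whole model, because squared norms add across layers; squared=False gives a different quantity.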

You can also find the same implementation in other sources, e.g., the code of the Flower package I linked before (https://flower.dev/docs/apiref-flwr.html#server-strategy-fedprox), FedNTD (NeurIPS22; https://github.com/Lee-Gihun/FedNTD/blob/be00ee598139654abe650c446a7ba3bc4e340233/algorithms/fedprox/ClientTrainer.py#L64), and FedBABU (ICLR22; https://github.com/jhoon-oh/FedBABU/blob/e34c15b4288224f0af1bc13670bc0cdc373c9576/models/Update.py#L302).

Could you please let me know if I have missed your point? Thank you!

Hanc1999 commented 1 year ago

Hi @vaseline555, yes: for the first part, by W I mean the set of all layers' weights. This repo's code does calculate each layer's proximal term separately and sums them into the term used in the loss function (CASE-B). However, if my memory is correct, in Li's paper the term should be calculated once over all layers' weights together, which yields a different derivative (CASE-A). Therefore I think the implementation differs from the paper.

CASE-A: [image]

CASE-B: [image]

vaseline555 commented 1 year ago

@Hanc1999 Thank you for the clarification. I cannot find the content you mentioned in the FedProx paper... could you point me to the part you have in mind?

Regarding your CASE-A scenario: no, I think we don't have to consider interactions with other layers as in CASE-A, since under the backpropagation rule each neuron receives its own gradient independently (i.e., via element-wise multiplication and summation), unless we explicitly introduce some special interaction operation between neurons.

In fact, the key quantity in both the original regularization term and its derivative is the norm of the difference between w and w_t, right? If the network structure is, for example, 784-100-10, then w and w_t each consist of a 784x100 and a 100x10 matrix (assuming no bias terms). Since we cannot directly take the norm of a difference across these inconsistent shapes, let us flatten and concatenate each into a (784x100 + 100x10) x 1 vector. Now we can consider the norm of the difference between the (vectorised) w and w_t.
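A dependency-free sketch of the flatten-and-concatenate argument for the 784-100-10 example. Random Gaussian weights stand in for real local and global models (an assumption for illustration only), and the helper names are hypothetical.

```python
import math
import random

random.seed(0)


def random_matrix(rows, cols):
    """rows x cols matrix (list of lists) of standard-normal floats."""
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def flatten(matrix):
    """Row-major flattening of a list-of-lists matrix."""
    return [x for row in matrix for x in row]


# Hypothetical 784-100-10 network without biases: two weight matrices per model.
shapes = [(784, 100), (100, 10)]
w = [random_matrix(r, c) for r, c in shapes]    # local weights
w_t = [random_matrix(r, c) for r, c in shapes]  # global weights

# Layer-wise: squared Frobenius norm of each layer's difference, then summed.
layer_sq = [
    sum((a - b) ** 2 for a, b in zip(flatten(wl), flatten(tl)))
    for wl, tl in zip(w, w_t)
]
layerwise = sum(layer_sq)

# Whole-vector: flatten and concatenate into (784*100 + 100*10) x 1 vectors.
vec_w = [x for wl in w for x in flatten(wl)]
vec_t = [x for tl in w_t for x in flatten(tl)]
whole = sum((a - b) ** 2 for a, b in zip(vec_w, vec_t))

print(len(vec_w))                      # 79400
print(math.isclose(layerwise, whole))  # True
# Dropping the square breaks the equivalence: sqrt(a) + sqrt(b) > sqrt(a + b).
print(sum(math.sqrt(s) for s in layer_sq) > math.sqrt(whole))  # True
```

The squared norm of the concatenated vector equals the sum of per-layer squared norms up to floating-point rounding, while a sum of unsquared per-layer norms is a different (larger) number.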

As I stated in the first paragraph, since backpropagation happens in an element-wise manner, we can also calculate the gradient separately for subgroups of neurons, which can correspond to different layers. That's why there's no problem with existing implementations calculating layer-wise norm differences.

Well, I think even if it IS CASE-A, its difference from CASE-B is just a matter of scale, which can be handled by regularisation techniques, learning-rate scheduling, and tuning of the constant \mu (note that \mu can also be tuned adaptively, as stated in the FedProx paper).
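To make "a matter of scale" concrete, here is a small sketch comparing the gradient of the paper's (\mu/2)||w - w^t||^2, which is \mu(w_l - w_l^t) for layer l, against two unsquared alternatives, assuming they keep the same \mu/2 coefficient. Plain lists stand in for tensors, and the helper names are hypothetical.

```python
import math


def l2(v):
    """L2 norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in v))


def per_layer_scale(local_layers, global_layers):
    """Ratio of the gradient of (mu/2) * sum_l ||w_l - w_l^t|| to the gradient
    mu * (w_l - w_l^t) of the paper's term, for each layer l.

    The ratio is 1 / (2 * ||w_l - w_l^t||) (mu cancels): positive, but
    different for every layer, and undefined if a layer has not moved.
    """
    return [1.0 / (2.0 * l2([a - b for a, b in zip(w, t)]))
            for w, t in zip(local_layers, global_layers)]


def whole_vector_scale(local_layers, global_layers):
    """Same ratio for the unsquared whole-vector norm (mu/2) * ||w - w^t||:
    a single scalar 1 / (2 * ||w - w^t||) shared by all layers."""
    d = [a - b for w, t in zip(local_layers, global_layers) for a, b in zip(w, t)]
    return 1.0 / (2.0 * l2(d))


local_layers = [[3.0, 4.0], [0.0, 1.0]]   # distances from global: 5.0 and 1.0
global_layers = [[0.0, 0.0], [0.0, 0.0]]
print(per_layer_scale(local_layers, global_layers))  # [0.1, 0.5]
print(whole_vector_scale(local_layers, global_layers))
```

Under these assumptions the gradients differ only by positive factors, but for a per-layer sum of unsquared norms the factor 1/(2||w_l - w_l^t||) varies from layer to layer and as training proceeds; whether a single tuned \mu compensates for that is a separate question.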

Hope this helps. Thank you!

Best, Adam

Hanc1999 commented 1 year ago


Hi Adam, thanks for the reply; here's a quick response to the last point. Yes, I think the gradients in CASE-A and CASE-B differ by a scalar factor. But I think the most important thing is to check whether the convergence proof of FedProx still holds in this new setting; otherwise there will be a big problem. I haven't checked the math yet; if you have any findings, please post them here and let us know. Thanks.

Regards, Yimin