omarfoq / FedEM

Official code for "Federated Multi-Task Learning under a Mixture of Distributions" (NeurIPS'21)
Apache License 2.0

Problem with the 'mix' method #3

Closed NookLook2014 closed 2 years ago

NookLook2014 commented 2 years ago

In aggregator.py, in the method "mix" of class CentralizedAggregator(Aggregator), I found that the average_learners method takes all clients as input rather than only the sampled clients. Is that reasonable? Since clients are sampled each round, I believe it would be better to average the sampled clients only, then push the averaged global model to all clients.

def mix(self):
    self.sample_clients()

    for client in self.sampled_clients:
        client.step()

    for learner_id, learner in enumerate(self.global_learners_ensemble):
        learners = [client.learners_ensemble[learner_id] for client in self.clients]
        average_learners(learners, learner, weights=self.clients_weights)
omarfoq commented 2 years ago

Hello,

Thank you for your question. I think this issue is similar to #1. Please let me know if #1 does not answer your question.

I also want to add that directly performing a weighted average over only the sampled clients with their original weights is not correct in general; instead, you need to reweight the sampled clients, as done in (Li et al., 2019, Table 1) for example.
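To make the reweighting point concrete, here is a toy sketch (not the repository's code, and using scalar "models" instead of learner ensembles for simplicity): simply renormalizing the original weights over a fixed sampled subset is biased in general, whereas one of the schemes analyzed in Li et al., 2019 (sample clients with replacement with probability proportional to their weights, then aggregate the sampled clients uniformly) gives an unbiased estimate of the full weighted average.

```python
import random

def aggregate(values, weights):
    """Weighted average of scalar client 'models'."""
    total = sum(weights)
    return sum(w * v for w, v in zip(values, weights)) / total

def sample_clients_unbiased(weights, m, rng):
    """Sample m client indices with replacement, P(k) proportional to
    weights[k]. Aggregating the sampled clients uniformly (weight 1/m
    each) is then unbiased for the full weighted average."""
    return rng.choices(range(len(weights)), weights=weights, k=m)

# Toy setup (hypothetical numbers): 5 clients with data-size weights
# and scalar "models".
weights = [0.4, 0.3, 0.1, 0.1, 0.1]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
full_avg = aggregate(values, weights)  # ground truth = 2.2

# Monte Carlo check: the per-round estimates average out to full_avg.
rng = random.Random(0)
rounds = 20000
est = 0.0
for _ in range(rounds):
    sampled = sample_clients_unbiased(weights, m=2, rng=rng)
    est += sum(values[k] for k in sampled) / len(sampled)
est /= rounds
print(full_avg, est)  # est is close to 2.2
```

Averaging all clients with fixed weights, as the repository does, sidesteps this bias question entirely, since unsampled clients simply contribute their current parameters.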

I hope this answers your question. Please let me know if I am missing something.

NookLook2014 commented 2 years ago

Ok, thanks for your reply