FedML-AI / FedML

FEDML - The unified and scalable ML library for large-scale distributed training, model serving, and federated learning. FEDML Launch, a cross-cloud scheduler, further enables running any AI jobs on any GPU cloud or on-premise cluster. Built on this library, TensorOpera AI (https://TensorOpera.ai) is your generative AI platform at scale.
Apache License 2.0

Problem with the function " _local_test_on_all_clients" in "https://github.com/FedML-AI/FedML/blob/master/python/fedml/simulation/sp/fedavg/fedavg_api.py" #1578

Open shubham22124 opened 1 year ago

shubham22124 commented 1 year ago

```python
# Excerpt from python/fedml/simulation/sp/fedavg/fedavg_api.py
def _local_test_on_all_clients(self, round_idx):

    logging.info("################local_test_on_all_clients : {}".format(round_idx))

    train_metrics = {"num_samples": [], "num_correct": [], "losses": []}

    test_metrics = {"num_samples": [], "num_correct": [], "losses": []}

    client = self.client_list[0]  # <-- the line in question

    for client_idx in range(self.args.client_num_in_total):
        """
        Note: for datasets like "fed_CIFAR100" and "fed_shakespeare",
        the training client number is larger than the testing client number
        """
        if self.test_data_local_dict[client_idx] is None:
            continue
        client.update_local_dataset(
            0,
            self.train_data_local_dict[client_idx],
            self.test_data_local_dict[client_idx],
            self.train_data_local_num_dict[client_idx],
        )
        # train data
        train_local_metrics = client.local_test(False)
        train_metrics["num_samples"].append(copy.deepcopy(train_local_metrics["test_total"]))
        train_metrics["num_correct"].append(copy.deepcopy(train_local_metrics["test_correct"]))
        train_metrics["losses"].append(copy.deepcopy(train_local_metrics["test_loss"]))

        # test data
        test_local_metrics = client.local_test(True)
        test_metrics["num_samples"].append(copy.deepcopy(test_local_metrics["test_total"]))
        test_metrics["num_correct"].append(copy.deepcopy(test_local_metrics["test_correct"]))
        test_metrics["losses"].append(copy.deepcopy(test_local_metrics["test_loss"]))

    # test on training dataset
    train_acc = sum(train_metrics["num_correct"]) / sum(train_metrics["num_samples"])
    train_loss = sum(train_metrics["losses"]) / sum(train_metrics["num_samples"])

    # test on test dataset
    test_acc = sum(test_metrics["num_correct"]) / sum(test_metrics["num_samples"])
    test_loss = sum(test_metrics["losses"]) / sum(test_metrics["num_samples"])

    stats = {"training_acc": train_acc, "training_loss": train_loss}
    if self.args.enable_wandb:
        wandb.log({"Train/Acc": train_acc, "round": round_idx})
        wandb.log({"Train/Loss": train_loss, "round": round_idx})

    mlops.log({"Train/Acc": train_acc, "round": round_idx})
    mlops.log({"Train/Loss": train_loss, "round": round_idx})
    logging.info(stats)

    stats = {"test_acc": test_acc, "test_loss": test_loss}
    if self.args.enable_wandb:
        wandb.log({"Test/Acc": test_acc, "round": round_idx})
        wandb.log({"Test/Loss": test_loss, "round": round_idx})

    mlops.log({"Test/Acc": test_acc, "round": round_idx})
    mlops.log({"Test/Loss": test_loss, "round": round_idx})
    logging.info(stats)
```

In the 4th line of the function, why is the zeroth client always selected? This way, testing happens only on the model corresponding to the zeroth client, but don't we want the average test error on each client's local dataset?

fedml-dimitris commented 1 year ago

@shubham22124 Thank you for asking this question. In that line, we only take the general client state (e.g., the model) from the first client. The evaluation still happens across all clients (see `client_idx`), as shown in line 195: https://github.com/FedML-AI/FedML/blob/master/python/fedml/simulation/sp/fedavg/fedavg_api.py#L195

shubham22124 commented 1 year ago

But line 195 just updates the dataset. Shouldn't the model also be updated, since each client undergoes local training and ends up with a different model than the one received from the server?

fedml-dimitris commented 1 year ago

@shubham22124 So basically, lines 193-198 are where the global model is evaluated against each client's local dataset. At this point every client's model is the same (the aggregated global model), hence `client = self.client_list[0]`. In other words, the evaluation of the global model is rotated over each client's dataset.
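
To make the pattern concrete, here is a minimal, self-contained sketch (with hypothetical names and a toy model; it is not FedML's actual API) of what the discussion above describes: a single client object holds the shared global model, and only its dataset is swapped per client index, so every dataset is scored against the same global model.

```python
# Sketch of "one reusable client, rotate the dataset" evaluation.
# All names here are illustrative, not FedML's real classes.

class Client:
    def __init__(self, model):
        self.model = model            # shared global model
        self.local_test_data = None

    def update_local_dataset(self, test_data):
        # Only the dataset changes; the model stays the same.
        self.local_test_data = test_data

    def local_test(self):
        # Toy "evaluation": count samples whose label matches the model's prediction.
        correct = sum(1 for x, y in self.local_test_data if self.model(x) == y)
        return {"test_correct": correct, "test_total": len(self.local_test_data)}


def evaluate_global_model(model, test_data_per_client):
    client = Client(model)            # plays the role of client_list[0]
    num_correct, num_samples = 0, 0
    for data in test_data_per_client:  # rotate over every client's local dataset
        client.update_local_dataset(data)
        metrics = client.local_test()
        num_correct += metrics["test_correct"]
        num_samples += metrics["test_total"]
    return num_correct / num_samples   # sample-weighted global accuracy


# Usage: a global "model" that predicts the parity of x.
model = lambda x: x % 2
datasets = [
    [(1, 1), (2, 0), (3, 1)],  # client 0: 3 of 3 correct
    [(4, 1), (5, 1)],          # client 1: 1 of 2 correct
]
print(evaluate_global_model(model, datasets))  # 4/5 = 0.8
```

This also explains why per-client local models are irrelevant here: the function runs after aggregation, so it measures how the single global model performs on each client's data, not how each locally trained model performs.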