kubeedge / ianvs

Distributed Synergy AI Benchmarking
https://ianvs.readthedocs.io
Apache License 2.0

add: add new proposal of Federated Incremental Learning for Label Sca… #124

Closed Yoda-wu closed 2 months ago

Yoda-wu commented 4 months ago


Federated Incremental Learning for Label Scarcity: Based on KubeEdge-Ianvs proposal

What type of PR is this?

/kind: design

What this PR does / why we need it: This PR is a proposal to add a new benchmarking paradigm: the federated class-incremental learning paradigm.

Which issue(s) this PR fixes:

https://github.com/kubeedge/ianvs/issues/97 Fixes #

p.s. There is another proposal; we need to discuss which one is better: https://github.com/Yoda-wu/ianvs/blob/new_proposal/docs/proposals/algorithms/federated-class-incremental-learning/Federated%20Class-Incremental%20and%20Semi-Supervised%20learning%20Proposal.md

MooreZheng commented 3 months ago

@hsj576 might also need to take a look at this proposal

hsj576 commented 3 months ago

Currently, Ianvs does not support federated learning, so it may be necessary to consider how to implement federated learning in Ianvs before implementing the Federated Incremental Learning in this proposal.

Yoda-wu commented 3 months ago

Good to see the proposal. The architecture format looks great to me. This version uses a single node instead of a multi-node simulation.

Remaining work includes:

  1. The procedures include federated learning but do not yet cooperate with the Sedna federated learning lib. Need to run a single-node version of federated learning using the Sedna lib, and mark the new modules in the architecture.
  2. The forget rate looks sample-wise instead of task-wise, and is thus different from FWT and BWT. Need to describe the forget rate correctly and add a formula.

Currently, Ianvs does not support federated learning, so it may be necessary to consider how to implement federated learning in Ianvs before implementing the Federated Incremental Learning in this proposal.

Thanks for the advice! The new proposal has been updated: the new modules are marked in the architecture, and a formula for the forget rate has been added. I believe that federated class-incremental learning is a special kind of federated learning; the difference is that the data processed by each client changes over time. So the new modules in this architecture are optional: if the user chooses not to use them, the architecture reduces to a simple federated learning paradigm.
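
To illustrate the intended degeneration, here is a minimal sketch (with hypothetical names, not the actual proposal code) of how an optional module could make the paradigm collapse into plain federated learning:

    class FederatedClassIncrementalParadigm:
        """Sketch: with no task allocator configured, every round sees the
        same client data, i.e. plain federated learning."""

        def __init__(self, task_allocator=None):
            self.task_allocator = task_allocator  # the optional "new module"

        def round_data(self, client_data, round_id):
            if self.task_allocator is None:
                return client_data                             # plain federated learning
            return self.task_allocator(client_data, round_id)  # class-incremental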

Yoda-wu commented 3 months ago


I wrote a simple demo of my design at https://github.com/Yoda-wu/ianvs/blob/dev_script/core/testcasecontroller/algorithm/paradigm/federated_learning/federeated_learning.py and an example at https://github.com/Yoda-wu/ianvs/tree/dev_script/examples/federated-learning/fedavg. The running result is shown in the attached screenshot.

Since the formula does not render properly on GitHub, I post it as a picture (attached).
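
For readers without the picture: a common task-wise formulation of the forget rate (an assumption based on the review comment above; the authoritative formula is the one in the proposal) is the average forgetting after task k, where a_{l,j} denotes accuracy on task j after training up to task l:

    F_k = \frac{1}{k-1} \sum_{j=1}^{k-1} \max_{l \in \{1,\dots,k-1\}} \left( a_{l,j} - a_{k,j} \right)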

Yoda-wu commented 3 months ago

Since the 8.22 weekly meeting is suspended, I would like to discuss the implementation details on GitHub. To implement the federated learning and federated class-incremental learning paradigms in Ianvs, I propose the following two approaches:

Sedna Lib

Sedna Federated Learning Lib

The core class for federated learning in Sedna is FederatedLearning:

class FederatedLearning(JobBase):

In FederatedLearning, there are three main components: __init__, register, and train, corresponding to class initialization, client registration, and training, respectively.

__init__(self, estimator, aggregation="FedAvg")

source code

The __init__ function is very short: all it does is get the aggregation server's IP and port from the context and build the aggregation URI. It then initializes the aggregation class (which turns out to be unused on the client, as actual aggregation happens on the server). Finally, the federated learning client is registered with the aggregation server through the register function.

    def __init__(self, estimator, aggregation="FedAvg"):

        protocol = Context.get_parameters("AGG_PROTOCOL", "ws")
        agg_ip = Context.get_parameters("AGG_IP", "127.0.0.1")
        agg_port = int(Context.get_parameters("AGG_PORT", "7363"))
        agg_uri = f"{protocol}://{agg_ip}:{agg_port}/{aggregation}"
        config = dict(
            protocol=protocol,
            agg_ip=agg_ip,
            agg_port=agg_port,
            agg_uri=agg_uri
        )
        super(FederatedLearning, self).__init__(
            estimator=estimator, config=config)
        self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation)

        connect_timeout = int(Context.get_parameters("CONNECT_TIMEOUT", "300"))
        self.node = None
        self.register(timeout=connect_timeout)

register(self, timeout=300)

source code

The register function mainly creates the sedna.core.client.AggregationClient and uses it to communicate with the aggregation server. register also instantiates the estimator, which is the main object used to train the model.

    def register(self, timeout=300):
        """
        Deprecated, Client proactively subscribes to the aggregation service.

        Parameters
        ----------
        timeout: int, connect timeout. Default: 300
        """
        self.log.info(
            f"Node {self.worker_name} connect to : {self.config.agg_uri}")
        self.node = AggregationClient(
            url=self.config.agg_uri,
            client_id=self.worker_name,
            ping_timeout=timeout
        )

        FileOps.clean_folder([self.config.model_url], clean=False)
        self.aggregation = self.aggregation()
        self.log.info(f"{self.worker_name} model prepared")
        if callable(self.estimator):
            self.estimator = self.estimator()

train(self, train_data, valid_data=None, post_process=None, **kwargs)

source

The train function is the main workhorse of FederatedLearning; it receives training data and validation data as parameters.

First, it initializes some local variables:

    def train(self, train_data,
              valid_data=None,
              post_process=None,
              **kwargs):
        callback_func = None
        if post_process:
            callback_func = ClassFactory.get_cls(
                ClassType.CALLBACK, post_process)

        round_number = 0
        num_samples = len(train_data)
        _flag = True
        start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        res = None

Then comes the training loop. Sedna implements it as a while-true loop and leaves the exit logic to the server. At the beginning of each round, if the client received the server's response in the previous round, it continues to train locally; the local training logic is implemented by the estimator. After training, the results are sent to the server through the node object initialized in register, and node then receives the messages coming back from the server.

        while 1:
            if _flag:
                round_number += 1
                start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                self.log.info(
                    f"Federated learning start, round_number={round_number}")
                res = self.estimator.train(
                    train_data=train_data, valid_data=valid_data, **kwargs)

                current_weights = self.estimator.get_weights()
                send_data = {"num_samples": num_samples,
                             "weights": current_weights}
                self.node.send(
                    send_data, msg_type="update_weight", job_name=self.job_name
                )
            received = self.node.recv(wait_data_type="recv_weight")
            if not received:
                _flag = False
                continue
            _flag = True

Sedna then parses the message received from the server, logs it, and loads the aggregated parameters into the estimator. As you can see, if the received exit_flag is "ok", the training loop ends.

            rec_data = received.get("data", {})
            exit_flag = rec_data.get("exit_flag", "")
            server_round = int(rec_data.get("round_number"))
            total_size = int(rec_data.get("total_sample"))
            self.log.info(
                f"Federated learning recv weight, "
                f"round: {server_round}, total_sample: {total_size}"
            )
            n_weight = rec_data.get("weights")
            self.estimator.set_weights(n_weight)
            task_info = {
                'currentRound': round_number,
                'sampleCount': total_size,
                'startTime': start,
                'updateTime': time.strftime(
                    "%Y-%m-%d %H:%M:%S", time.localtime())
            }
            model_paths = self.estimator.save()
            task_info_res = self.estimator.model_info(
                model_paths, result=res, relpath=self.config.data_path_prefix)
            if exit_flag == "ok":
                self.report_task_info(
                    task_info,
                    K8sResourceKindStatus.COMPLETED.value,
                    task_info_res)
                self.log.info(f"exit training from [{self.worker_name}]")
                return callback_func(
                    self.estimator) if callback_func else self.estimator
            else:
                self.report_task_info(
                    task_info,
                    K8sResourceKindStatus.RUNNING.value,
                    task_info_res)

That is all the functionality of FederatedLearning; this class is Sedna's definition of the client side of federated learning. Next, let's introduce the server-side functionality provided by Sedna.
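
Putting the client side together, usage looks roughly like this. This is a sketch, not Sedna's documented example: MyEstimator is a stub standing in for a real model wrapper, and it assumes an aggregation server is already reachable at AGG_IP:AGG_PORT.

    import os

    # the client reads the aggregation server address from the context (see __init__ above)
    os.environ.setdefault("AGG_IP", "127.0.0.1")
    os.environ.setdefault("AGG_PORT", "7363")

    from sedna.core.federated_learning import FederatedLearning

    class MyEstimator:
        # stub estimator: implements the methods exercised by the train loop above
        def train(self, train_data, valid_data=None, **kwargs):
            return {}          # run local training here
        def get_weights(self):
            return []          # return the local model weights
        def set_weights(self, weights):
            pass               # load the aggregated weights
        def save(self):
            return ""          # persist a checkpoint, return its path
        def model_info(self, model_paths, result=None, relpath=None):
            return []

    fl_job = FederatedLearning(estimator=MyEstimator, aggregation="FedAvg")
    # blocks inside the while-loop shown above until the server replies exit_flag == "ok"
    fl_job.train(train_data=[])  # pass a real dataset object here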

Sedna Aggregate Server Lib

source code The core class of the Sedna aggregation server is as follows:

class Aggregator(WSServerBase):

It consists of __init__, async send_message, and exit_check: initialization, asynchronous message handling (aggregation is triggered once updates have arrived from all clients), and detection of the end of training.

__init__(self, **kwargs)

Initialization instantiates the aggregation algorithm implemented by the user and sets some training parameters, such as the exit round (the total number of training rounds), the number of participating clients, and the current training round.

    def __init__(self, **kwargs):
        super(Aggregator, self).__init__()
        self.exit_round = int(kwargs.get("exit_round", 3))
        aggregation = kwargs.get("aggregation", "FedAvg")
        self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation)
        if callable(self.aggregation):
            self.aggregation = self.aggregation()
        self.participants_count = int(kwargs.get("participants_count", "1"))
        self.current_round = 0

async def send_message(self, client_id: str, msg: Dict)

This is the main message-handling function. The server and client in Sedna communicate via the WebSocket protocol, so it is implemented asynchronously. As you can see, the function collects the update from each client; once updates from all participating clients have arrived, it calls the aggregate function of the aggregation algorithm and broadcasts the aggregated parameters back to the clients.

    async def send_message(self, client_id: str, msg: Dict):
        data = msg.get("data")
        if data and msg.get("type", "") == "update_weight":
            info = AggClient()
            info.num_samples = int(data["num_samples"])
            info.weights = data["weights"]
            self._client_meta[client_id].info = info
            current_clinets = [
                x.info for x in self._client_meta.values() if x.info
            ]
            # exit while aggregation job is NOT start
            if len(current_clinets) < self.participants_count:
                return
            self.current_round += 1
            weights = self.aggregation.aggregate(current_clinets)
            exit_flag = "ok" if self.exit_check() else "continue"

            msg["type"] = "recv_weight"
            msg["round_number"] = self.current_round
            msg["data"] = {
                "total_sample": self.aggregation.total_size,
                "round_number": self.current_round,
                "weights": weights,
                "exit_flag": exit_flag
            }
        for to_client, websocket in self._clients.items():
            try:
                await websocket.send_json(msg)
            except Exception as err:
                LOGGER.error(err)
            else:
                if msg["type"] == "recv_weight":
                    self._client_meta[to_client].info = None

exit_check

This function simply checks whether current_round > exit_round.
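
In other words, a sketch matching the description above (the exact comparison used by Sedna may differ):

    def exit_check(self):
        # stop once the current round has passed the configured exit round
        return self.current_round > self.exit_round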

Problem

These two components are exposed to the user through the ClassFactory interface; the user implements estimator and aggregation to realize federated learning according to their requirements.

Sedna hides the communication details from the user: you only need to implement the local training logic in the estimator and the aggregation logic in the aggregation class.
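
Concretely, the aggregation extension point is registered through ClassFactory, so that ClassFactory.get_cls(ClassType.FL_AGG, ...) can resolve it by name as seen in the source above. A minimal sketch, assuming each client entry exposes num_samples and weights (numpy arrays) as in AggClient:

    from sedna.common.class_factory import ClassFactory, ClassType

    @ClassFactory.register(ClassType.FL_AGG)
    class MyFedAvg:
        def __init__(self):
            self.total_size = 0  # read by the server when building the reply

        def aggregate(self, clients):
            # sample-weighted average of the client weights (plain FedAvg)
            self.total_size = sum(c.num_samples for c in clients)
            num_layers = len(clients[0].weights)
            return [
                sum(c.weights[i] * (c.num_samples / self.total_size) for c in clients)
                for i in range(num_layers)
            ]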

We can do the same in Ianvs: omit the communication layer and have the program collect and aggregate parameters directly, still essentially using estimator and aggregation, but exchanging them in memory instead of over WebSockets.

For example, in the algorithm module, when building the paradigm:

from core.testcasecontroller.algorithm.paradigm import FederatedClassIncrementalLearning

if self.paradigm_type == ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value:
    return FederatedClassIncrementalLearning(workspace, **config)

and the FederatedClassIncrementalLearning:

import copy

class FederatedClassIncrementalLearning(FederatedLearning):

    def __init__(self, workspace, **kwargs):
        super(FederatedClassIncrementalLearning, self).__init__(workspace, **kwargs)
        self.rounds = kwargs.get("incremental_rounds", 1)
        self.task_size = kwargs.get("task_size", 10)
        self.system_metric_info = {}
        self.lock = RLock()
        self.aggregate_clients = []
        self.train_infos = []
        self.aggregation, self.aggregator = self.module_instances.get(ModuleType.AGGREGATION.value)

    def init_client(self):
        # each client is a deep copy of one paradigm job template
        template = self.build_paradigm_job(ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value)
        self.clients = [copy.deepcopy(template) for _ in range(self.task_size)]

Here we can directly access the aggregator and the estimators, which play the roles of server and clients. The training process is:

    def run(self):
        self.init_client()
        dataset_files = self._split_dataset(self.task_size)
        for r in range(self.rounds):
            task_id = r // self.task_size
            LOGGER.info(f"Round {r} task id: {task_id}")
            # build the training data for the current incremental task
            train_datasets = self.task_definition(dataset_files, task_id)
            # client side: local training on every simulated client
            self._train(train_datasets, task_id=task_id, round=r, task_size=self.task_size)
            # server side: aggregate the collected client updates in memory
            global_weights = self.aggregator.aggregate(self.aggregate_clients)
            if hasattr(self.aggregator, "helper_function"):
                self.helper_function(self.train_infos)
            self.send_weights_to_clients(global_weights)
            self.aggregate_clients.clear()
            self.train_infos.clear()
        test_res = self.predict(self.dataset.test_url)
        return test_res, self.system_metric_info
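
The demo above does not show _split_dataset and task_definition; a plausible sketch of a class-incremental split (hypothetical, including the dataset attributes it touches) could be:

    def _split_dataset(self, task_size):
        # partition the label set into tasks of `task_size` classes each
        labels = sorted(set(self.dataset.train_labels))  # hypothetical attribute
        return [labels[i:i + task_size] for i in range(0, len(labels), task_size)]

    def task_definition(self, dataset_files, task_id):
        # each client only sees samples whose labels belong to the current task
        current_classes = set(dataset_files[task_id])
        return [
            [(x, y) for (x, y) in client_data if y in current_classes]
            for client_data in self.client_datasets  # hypothetical attribute
        ]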

Pros

Which plan should we adopt? If there are any other options, please feel free to suggest them.

Yoda-wu commented 2 months ago

The formula is clarified. Now the acc is a per-task/per-class metric instead of a per-sample one. That is revised in the proposal now.

Two suggestions:

  1. As for the new federated learning scheme in Ianvs, Plan B is fine with me: it adds merely the algorithm part of Sedna into Ianvs, while Plan A introduces server functions that might not be necessary for all Ianvs users. A general federated learning scheme needs to be added into the Ianvs core. The examples also need the advanced versions, i.e., the WebSocket version of federated learning and the incremental-learning version of federated learning. Then the architecture can be revised accordingly.
  2. Another great contribution at the routine meeting is the WebSocket issue mentioned by Yoda-wu: Sedna lacks a heartbeat message to keep the network connection alive for federated learning over WebSocket (a sketch follows this list). An issue could be raised in Sedna.
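
For context on the heartbeat point: keeping a WebSocket alive usually amounts to periodic ping/pong frames. A minimal sketch with the websockets package (not Sedna's actual client; the URI just follows the agg_uri format shown earlier):

    import asyncio
    import websockets

    async def client_with_heartbeat(uri: str, interval: float = 20.0):
        # ping_interval sends periodic WebSocket pings; the connection is
        # closed if no pong arrives within ping_timeout
        async with websockets.connect(uri, ping_interval=interval,
                                      ping_timeout=interval) as ws:
            await ws.send("update_weight")
            print(await ws.recv())

    asyncio.run(client_with_heartbeat("ws://127.0.0.1:7363/FedAvg"))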

Thanks for the suggestions! I updated the proposal and raised an issue in Sedna: https://github.com/kubeedge/sedna/issues/442

Yoda-wu commented 2 months ago

Great to see the updated version. As mentioned, the architecture can be revised accordingly. It would be appreciated if the developed parts could be highlighted in the architecture.

  1. A general federated learning scheme needs to be added into the Ianvs core (within the controller).
  2. (The directory of) Examples also needs the advanced versions, i.e., the WebSocket version of federated learning and the incremental-learning version of federated learning.

Thanks for the advice! I have updated the proposal:

  1. I highlighted the developed parts in the architecture:

    • For the federated learning scheme: (architecture diagram attached)
    • For the federated class-incremental learning scheme: (architecture diagram attached)
  2. The WebSocket version of federated learning and the incremental-learning version of federated learning examples correspond to sections 3.5.3 and 3.5.2 (screenshots attached).

Yoda-wu commented 2 months ago

After the weekly meeting, the architecture for the federated learning paradigm is as shown in the first attached diagram, and the federated class-incremental learning paradigm as shown in the second.

MooreZheng commented 2 months ago

/lgtm

kubeedge-bot commented 2 months ago

[APPROVALNOTIFIER] This PR is APPROVED

This pull-request has been approved by: hsj576, MooreZheng

The full list of commands accepted by this bot can be found here.

The pull request process is described here.

Needs approval from an approver in each of these files:

  - ~~[OWNERS](https://github.com/kubeedge/ianvs/blob/main/OWNERS)~~ [MooreZheng]

Approvers can indicate their approval by writing `/approve` in a comment. Approvers can cancel approval by writing `/approve cancel` in a comment.