adap / flower

Flower: A Friendly Federated AI Framework
https://flower.ai
Apache License 2.0
5.02k stars 862 forks source link

Connection overriding problem #2105

Open saurav935 opened 1 year ago

saurav935 commented 1 year ago

Describe the bug

I am working with Flower which is a federated learning framework. In its grpc connection file they are only creating 1 channel whereas I want 2-3 channels. But when I created 1 more channel with server_address localhost:5040, the previous channel with server address localhost:8080 is getting overridden. How can I avoid that and use both the channels?

# Copyright 2020 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contextmanager for a gRPC streaming channel to the Flower server."""

from contextlib import contextmanager
from logging import DEBUG
from pathlib import Path
from queue import Queue
from typing import Callable, Iterator, Optional, Tuple, Union

from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.proto.transport_pb2_grpc import FlowerServiceStub

# The following flags can be uncommented for debugging. Other possible values:
# https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
# import os
# os.environ["GRPC_VERBOSITY"] = "debug"
# os.environ["GRPC_TRACE"] = "tcp,http"

def on_channel_state_change(channel_connectivity: str) -> None:
    """Log channel connectivity."""
    log(DEBUG, channel_connectivity)

@contextmanager
def grpc_connection(
    server_address: str,
    max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    root_certificates: Optional[Union[bytes, str]] = None,
) -> Iterator[Tuple[Callable[[], ServerMessage], Callable[[ClientMessage], None]]]:
    """Establish a gRPC connection to a gRPC server.

    Parameters
    ----------
    server_address : str
        The IPv4 or IPv6 address of the server. If the Flower server runs on the same
        machine on port 8080, then `server_address` would be `"[0.0.0.0:8080](http://0.0.0.0:8080/)"` or
        `"[::]:8080"`.
    max_message_length : int
        The maximum length of gRPC messages that can be exchanged with the Flower
        server. The default should be sufficient for most models. Users who train
        very large models might need to increase this value. Note that the Flower
        server needs to be started with the same value
        (see `flwr.server.start_server`), otherwise it will not know about the
        increased limit and block larger messages.
        (default: 536_870_912, this equals 512MB)
    root_certificates : Optional[bytes] (default: None)
        The PEM-encoded root certificates as a byte string or a path string.
        If provided, a secure connection using the certificates will be
        established to an SSL-enabled Flower server.

    Returns
    -------
    receive, send : Callable, Callable

    Examples
    --------
    Establishing a SSL-enabled connection to the server:

    >>> from pathlib import Path
    >>> with grpc_connection(
    >>>     server_address,
    >>>     max_message_length=max_message_length,
    >>>     root_certificates=Path("/crts/root.pem").read_bytes(),
    >>> ) as conn:
    >>>     receive, send = conn
    >>>     server_message = receive()
    >>>     # do something here
    >>>     send(client_message)
    """
    if isinstance(root_certificates, str):
        root_certificates = Path(root_certificates).read_bytes()

    channel = create_channel(
        server_address='localhost:8080',
        root_certificates=root_certificates,
        max_message_length=max_message_length,
    )
    channel.subscribe(on_channel_state_change)

    queue: Queue[ClientMessage] = Queue(  # pylint: disable=unsubscriptable-object
        maxsize=1
    )
    stub = FlowerServiceStub(channel)

    server_message_iterator: Iterator[ServerMessage] = stub.Join(iter(queue.get, None))

# Adding one more channel over here. They are getting over-riden that is only the next one (localhost:5040)
#####################

    channel_2 = create_channel(
        server_address='localhost:5040',
        root_certificates=root_certificates,
        max_message_length=max_message_length,
    )
    channel_2.subscribe(on_channel_state_change)

    queue: Queue[ClientMessage] = Queue(  # pylint: disable=unsubscriptable-object
        maxsize=1
    )
    stub = FlowerServiceStub(channel_2)

    server_message_iterator: Iterator[ServerMessage] = stub.Join(iter(queue.get, None))

#####################

    def receive() -> ServerMessage:
        return next(server_message_iterator)

    def send(msg: ClientMessage) -> None:
        return queue.put(msg, block=False)

    try:
        yield (receive, send)
    finally:
        # Make sure to have a final
        channel.close()
        log(DEBUG, "gRPC channel closed")

Steps/Code to Reproduce

The code to reproduce the error is mentioned above.

Expected Results

I expect to see the connections to both the servers without getting overridden : )

Actual Results

Flower working successfully with multiple servers instead of just 1 server.

danieljanes commented 1 year ago

Thanks for reaching out @saurav935. Before discussing the solution, could you elaborate more on the requirements?

Is the intent to connect to two FL servers simultaneously? In that case, it might be an option to call start_client twice from two different threads (since start_client is blocking).

saurav935 commented 1 year ago

Thanks for your response @danieljanes. I know that the multithreading approach would work, but I am also exploring an approach without using multithreading. Basically I want to do the federated learning using 3 servers instead of 1 server. Calling start_client function twice will also do the training twice and I don't want that. Basically I want the 3 servers to send the information, I will perform an operation with the data received from the 3 servers, then do the training once (the handle function), and send the trained result to all the 3 servers. So, receive from 3 servers, perform an operation, train once, and send to all the 3 servers.

For example:

receive_from_all_3_servers()

perform_operation()

train_once() # handle function

send_to_all_3_servers()
danieljanes commented 1 year ago

Thanks, that's helpful context. Is this in the context of a research project or is this system intended for production?

For research/prototyping, the multithreading option might still be the easiest way to do it (not the cleanest, just the easiest). You could have all three threads running, they get a message from the server in fit, they all put their message in a shared data structure and wait until all three messages are available, then only one of them does the training and writes the results back to the shared data structure, and, once the result is available, all three send the result back to their respective server.

saurav935 commented 1 year ago

Thanks! I knew that the flow of multithreading would be like that, but I am currently working on how to do it without multithreading. I tried adding more Join functions in the transport_pb2_grpc.py file and it worked.

# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2

class FlowerServiceStub(object):
    """Missing associated documentation comment in .proto file."""

    def __init__(self, channel, channel_1, channel_2):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # Join method for 1st server        
        self.Join = channel.stream_stream(
                '/flwr.proto.FlowerService/Join',
                request_serializer=flwr_dot_proto_dot_transport__pb2.ClientMessage.SerializeToString,
                response_deserializer=flwr_dot_proto_dot_transport__pb2.ServerMessage.FromString,
                )
        # Join method for 2nd server        
        self.Join_1 = channel_1.stream_stream(
                '/flwr.proto.FlowerService/Join',
                request_serializer=flwr_dot_proto_dot_transport__pb2.ClientMessage.SerializeToString,
                response_deserializer=flwr_dot_proto_dot_transport__pb2.ServerMessage.FromString,
                )

        # Join method for 3rd server
        self.Join_2 = channel_2.stream_stream(
                '/flwr.proto.FlowerService/Join',
                request_serializer=flwr_dot_proto_dot_transport__pb2.ClientMessage.SerializeToString,
                response_deserializer=flwr_dot_proto_dot_transport__pb2.ServerMessage.FromString,
                )

class FlowerServiceServicer(object):
    """Missing associated documentation comment in .proto file."""

    def Join(self, request_iterator, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

def add_FlowerServiceServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'Join': grpc.stream_stream_rpc_method_handler(
                    servicer.Join,
                    request_deserializer=flwr_dot_proto_dot_transport__pb2.ClientMessage.FromString,
                    response_serializer=flwr_dot_proto_dot_transport__pb2.ServerMessage.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'flwr.proto.FlowerService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))

 # This class is part of an EXPERIMENTAL API.
class FlowerService(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
    def Join(request_iterator,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        print("\ninside Join\n")
        return grpc.experimental.stream_stream(request_iterator, target, '/flwr.proto.FlowerService/Join',
            flwr_dot_proto_dot_transport__pb2.ClientMessage.SerializeToString,
            flwr_dot_proto_dot_transport__pb2.ServerMessage.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

I also made changes in the connections.py file by adding the server addresses of other servers.