svinota / pyroute2

Python Netlink and PF_ROUTE library — network configuration and monitoring
https://pyroute2.org/

python 3.4+ asyncio support #466

Open rjarry opened 6 years ago

rjarry commented 6 years ago

Hi,

If I understand correctly, the only way to make pyroute2 work with asyncio for now is to run ipdb in an executor thread. This makes it hard to interact with non-blocking code outside of the thread.
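
For reference, the executor approach looks roughly like the sketch below. `blocking_query()` is a hypothetical stand-in for a blocking pyroute2 call (it only sleeps), so nothing here touches netlink.

```python
import asyncio
import time


def blocking_query() -> list:
    # Hypothetical stand-in for a blocking pyroute2 call; it just sleeps.
    time.sleep(0.05)
    return ["lo", "eth0"]


async def main() -> list:
    loop = asyncio.get_running_loop()
    # Run the blocking call in the default thread pool so the loop stays free.
    pending = loop.run_in_executor(None, blocking_query)
    # Other coroutines can run here while the worker thread is busy.
    await asyncio.sleep(0)
    return await pending


if __name__ == "__main__":
    print(asyncio.run(main()))
```

The result only comes back through the future, which is why anything living inside the worker thread is awkward to reach from coroutines.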

I wonder whether it would be possible to add support for non-blocking I/O using asyncio coroutines. I'm not sure whether all of the low-level socket API is available with asyncio sockets.

https://docs.python.org/3/library/asyncio-eventloop.html#creating-connections
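
As a small illustration of the low-level side of that API: `loop.sock_sendall()` and `loop.sock_recv()` work on a plain non-blocking socket object. The sketch below uses a `socket.socketpair()` rather than a netlink socket, which is an assumption made only to keep it runnable anywhere.

```python
import asyncio
import socket


async def echo_once(payload: bytes) -> bytes:
    loop = asyncio.get_running_loop()
    left, right = socket.socketpair()
    # The loop.sock_* helpers require non-blocking sockets.
    left.setblocking(False)
    right.setblocking(False)
    try:
        await loop.sock_sendall(left, payload)
        return await loop.sock_recv(right, 65536)
    finally:
        left.close()
        right.close()


if __name__ == "__main__":
    print(asyncio.run(echo_once(b"ping")))
```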

I might have a shot at this, but I'd like some guidance on what has to be changed. I saw a lot of threading locks and other blocking synchronization routines which do not play well with asyncio. Maybe this would require splitting some parts of the code and having two versions of IPRoute and IPDB: the standard one and a non-blocking asyncio one with coroutines.

vodik commented 6 years ago

I would really like to see this too, personally.

Passing the netlink socket directly to create_connection would indeed work, and as I understand the library, there's already some logic that makes things act kind of like futures, since results get filled in as data arrives on different threads.

The problem, I'd imagine, would be adding the necessary abstractions to the code base to support both...

It might make sense to look and see if concurrent.futures can be used first. That can glue better into asyncio.
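
A minimal sketch of that glue, assuming a reader thread that fills in a `concurrent.futures.Future` when data arrives. `start_reader()` is hypothetical and not part of pyroute2; `asyncio.wrap_future()` is the standard call that makes such a future awaitable.

```python
import asyncio
import concurrent.futures
import threading


def start_reader() -> concurrent.futures.Future:
    # Hypothetical reader thread: it resolves the future when "data arrives".
    fut = concurrent.futures.Future()
    worker = threading.Thread(target=fut.set_result, args=(b"reply",), daemon=True)
    worker.start()
    return fut


async def main() -> bytes:
    # wrap_future() bridges the thread-side future into the running event loop.
    return await asyncio.wrap_future(start_reader())


if __name__ == "__main__":
    print(asyncio.run(main()))
```

The threaded code keeps working as before, and only the callers that want to await need to know about asyncio.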

rjarry commented 6 years ago

I am not optimistic about making the same codebase work with both async and blocking code.

There are a lot of parts that could be shared (netlink parsing most of all). But all socket operations must use the asyncio versions of the objects (which return coroutines). And these make absolutely no sense when there is no event loop.

I guess the most straightforward implementation would subclass the NetlinkSocket, IPRoute and IPDB classes to support async code (removing all threading.* stuff and anything blocking the interpreter).

svinota commented 6 years ago

A big and complicated task, but I think it must be done at some point, as OpenStack uses async as well and we have to comply with that somehow.

svinota commented 6 years ago

I'm going to start a project within pyroute2 to make a proof-of-concept using Python 3.7+ AIO capabilities.

If you have any objections against Python 3.7+, e.g. if you prefer eventlet etc., please let me know.

rjarry commented 6 years ago

Hi @svinota,

About 3.7+, it doesn't look really necessary. You could stick to 3.5+, as the async def syntax was added there. 3.6 and 3.7 do not add major improvements.

I have a very basic implementation of an asyncio socket. It works for basic operations. I didn't test the full scope.

# Copyright 2013-2018 the pyroute2 authors
# Copyright 2018 6WIND S.A.

from socket import SOCK_RAW
from socket import SOL_SOCKET
from socket import SO_RCVBUF
from socket import SO_SNDBUF
from socket import socket
import asyncio
import errno
import logging
import os
import random
import struct
import time

from pyroute2.config import AF_NETLINK
from pyroute2.netlink import NETLINK_ADD_MEMBERSHIP
from pyroute2.netlink import NETLINK_DROP_MEMBERSHIP
from pyroute2.netlink import NLMSG_DONE
from pyroute2.netlink import NLMSG_ERROR
from pyroute2.netlink import NLM_F_DUMP
from pyroute2.netlink import NLM_F_MULTI
from pyroute2.netlink import NLM_F_REQUEST
from pyroute2.netlink import SOL_NETLINK
from pyroute2.netlink import nlmsg
from pyroute2.netlink.exceptions import NetlinkDecodeError
from pyroute2.netlink.exceptions import NetlinkError
from pyroute2.netlink.exceptions import NetlinkHeaderDecodeError

#------------------------------------------------------------------------------
class AioNetlinkSocket(object):

    def __init__(self, error_handler=None):
        if error_handler and not asyncio.iscoroutinefunction(error_handler):
            raise TypeError('error_handler must be a coroutine')
        self.error_handler = error_handler
        self.sock = None
        self.lock = asyncio.Lock()
        self.data_queue = asyncio.Queue()
        self.msg_queue = {}
        self.recv_task = None
        self.callbacks = {}
        self.port = None
        self.pid = None
        self.seq = 0
        self.log = logging.getLogger(
            'aionetlink.%s.%x' % (self.__class__.__name__, id(self)))

    # global list of available ports for this process
    free_ports = list(range(1024))
    # netlink proto, must be specified by subclasses
    proto = None

    async def connect(self):
        """
        Open the socket and bind it, automatically picking a free port.
        Schedule :meth:`.recv_loop` in the background.
        """
        if self.sock is not None:
            await self.close()
        self.sock = socket(AF_NETLINK, SOCK_RAW, self.proto)
        # Non blocking sockets are mandatory for asyncio.
        self.sock.setblocking(False)
        # Use big buffers to avoid ENOBUFS errors with netlink storms.
        self.sock.setsockopt(SOL_SOCKET, SO_SNDBUF, 1024 * 1024)
        self.sock.setsockopt(SOL_SOCKET, SO_RCVBUF, 1024 * 1024)
        self.seq = int(time.monotonic())

        bad_ports = []
        try:
            # Try 5 times to find a free port to bind the socket, then give up.
            while True:
                i = random.randrange(len(AioNetlinkSocket.free_ports))
                port = AioNetlinkSocket.free_ports.pop(i)
                pid = (os.getpid() & 0x3fffff) + (port << 22)
                try:
                    self.sock.bind((pid, 0))
                    self.port = port
                    self.pid = pid
                    self.log.debug('connect: using port %d', port)
                    break
                except OSError as e:
                    bad_ports.append(port)
                    if e.errno == errno.EADDRINUSE and len(bad_ports) < 5:
                        self.log.debug('connect: port %d already used', port)
                        continue
                    raise
        finally:
            AioNetlinkSocket.free_ports.extend(bad_ports)

        # Start the receive loop in the background.
        self.recv_task = asyncio.ensure_future(self.recv_loop())

    async def close(self):
        """
        Stop the :meth:`.recv_loop` task and close the socket.
        Free the used port.
        """
        try:
            if self.recv_task is not None:
                self.recv_task.cancel()
                await self.recv_task
                self.recv_task = None
        finally:
            if self.port is not None:
                AioNetlinkSocket.free_ports.append(self.port)
                # Reset so that a second close() does not free the port twice.
                self.port = None
            if self.sock is not None:
                self.sock.close()
                self.sock = None

    def add_callback(self, msg_type, callback):
        """
        Register a callback for the given message type. Callbacks are only
        invoked for multicast messages (see :meth:`.subscribe`).

        :arg int msg_type:
            The type of netlink message.
        :arg <coro> callback:
            Coroutine that will be called/awaited upon the reception of all
            messages of msg_type with the message as only parameter. The return
            value and any exception are ignored.
        """
        if not asyncio.iscoroutinefunction(callback):
            raise TypeError('callback must be a coroutine')
        self.callbacks.setdefault(msg_type, []).append(callback)

    def del_callback(self, msg_type, callback):
        """
        Unregister a previously registered callback.
        """
        self.callbacks[msg_type].remove(callback)

    def reset_callbacks(self, msg_type=None):
        """
        Unregister all callbacks for a given type.

        :arg int msg_type:
            The type of netlink message for which to unregister all callbacks.
            If unset, callbacks for all types are unregistered.
        """
        if msg_type is not None:
            del self.callbacks[msg_type]
        else:
            self.callbacks.clear()

    def subscribe(self, *groups):
        """
        Subscribe to multicast messages of the given groups. 
        """
        for grp in groups:
            self.sock.setsockopt(SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, grp)

    def unsubscribe(self, *groups):
        """
        Unsubscribe from multicast messages of the given groups.
        """
        for grp in groups:
            self.sock.setsockopt(SOL_NETLINK, NETLINK_DROP_MEMBERSHIP, grp)

    async def request(self, msg, msg_type, msg_flags=NLM_F_REQUEST | NLM_F_DUMP,
                      terminate=None, callback=None, *args, **kw):
        """
        Send message and wait for a response from the kernel.

        :arg nlmsg/dict msg:
            The message to send to the kernel. If the message is a dict,
            `msg_type` will be used to determine which subclass of :cls:`.nlmsg`
            should be used to encode the message.
        :arg int msg_type:
            The type of message (will be put in the message header).
        :arg int msg_flags:
            Message flags (will be put in the message header).
        :arg <coroutine> terminate:
            Optional coroutine to determine if the request sequence is "done".
            By default, the sequence stops on the first NLMSG_DONE message.
        :arg <coroutine> callback:
            Optional coroutine. Invoked for each parsed message.

        :returns:
            A list of :cls:`.nlmsg` objects. If `callback` is set, the list
            will be empty.
        """
        seq = self.next_sequence_number()
        await self.put(seq, msg, msg_type, msg_flags)
        return await self.get(seq, terminate, callback)

    nlm_request = request  # compat with pyroute2 api

    async def put(self, seq, msg, msg_type, msg_flags=NLM_F_REQUEST):
        """
        Send a message to the kernel. Normally, you should not have to call
        this directly, it is called by :meth:`.request` by specifying a unique
        sequence number.

        :arg int seq:
            The sequence number to write in the message header. The same
            sequence number **must** be used in :meth:`.get` to make sure to
            receive the responses.
        :arg nlmsg/dict msg:
            The message to send to the kernel. If the message is a dict,
            `msg_type` will be used to determine which subclass of :cls:`.nlmsg`
            should be used to encode the message.
        :arg int msg_type:
            The type of message (will be put in the message header).
        :arg int msg_flags:
            Message flags (will be put in the message header).
        """
        if seq != 0:
            async with self.lock:
                if seq in self.msg_queue:
                    raise ValueError('duplicate sequence number: %s' % seq)
                self.msg_queue[seq] = asyncio.Queue()

        if not isinstance(msg, nlmsg):
            msg = self.msg_types[msg_type](msg)
        msg['header']['type'] = msg_type
        msg['header']['flags'] = msg_flags
        msg['header']['pid'] = self.pid
        msg['header']['sequence_number'] = seq
        self.log.debug('put: sending message: %s', msg)
        msg.reset()
        msg.encode()
        self.log.debug('put: raw message: %s', msg.data)
        await self.sendall(msg.data)

    async def get(self, seq, terminate=None, callback=None):
        """
        Read responses from the kernel. Normally, you should not have to call
        this directly, it is called by :meth:`.request` by specifying the
        unique sequence number that was used to send the request.

        :arg int seq:
            The request sequence number.
        :arg <coroutine> terminate:
            Optional coroutine to determine if the request sequence is "done".
            By default, the sequence stops on the first NLMSG_DONE message.
        :arg <coroutine> callback:
            Optional coroutine. Invoked for each parsed message.

        :returns:
            A list of :cls:`.nlmsg` objects. If `callback` is set, the list
            will be empty.
        """
        msg_queue = self.msg_queue.get(seq)
        if msg_queue is None:
            raise ValueError('unknown sequence number: %s' % seq)

        if terminate is None:
            async def terminate(msg):
                return msg['header']['type'] == NLMSG_DONE
        elif not asyncio.iscoroutinefunction(terminate):
            raise TypeError('terminate must be a coroutine')

        if callback and not asyncio.iscoroutinefunction(callback):
            raise TypeError('callback must be a coroutine')

        msgs = []
        try:
            done = False
            while not done:
                msg = await msg_queue.get()
                if isinstance(msg, Exception):
                    raise msg
                done = await terminate(msg)
                if not done:
                    if callback is not None:
                        await callback(msg)
                    else:
                        msgs.append(msg)
                    if not msg['header']['flags'] & NLM_F_MULTI:
                        # Not a multipart message, stop here.
                        done = True
        finally:
            async with self.lock:
                if seq in self.msg_queue:
                    del self.msg_queue[seq]

        return msgs

    async def recv_loop(self):
        """
        Receive loop (started by :meth:`.connect`).

        It waits for messages on the socket and parses them. After which:

        - If the sequence number is 0 (multicast message), execute all
          registered callbacks for the message's type. If there are no
          registered callback, drop the message.
        - If the sequence number is valid (i.e., part of an ongoing request),
          put the message in the correct message queue. Otherwise drop the
          message.

        If an error occurs (socket or parse error), call `self.error_handler`
        with this socket and the error as parameters. If `self.error_handler`
        is unspecified, log the exception.
        """
        loop = asyncio.get_event_loop()
        self.log.debug('recv_loop: starting')
        while True:
            try:
                data = await loop.sock_recv(self.sock, 65536)
                self.log.debug('recv_loop: got data: %s', data)

                for msg in self.parse(data):
                    self.log.debug('recv_loop: parsed msg: %s', msg)
                    seq = msg['header']['sequence_number']
                    msg_queue = self.msg_queue.get(seq)
                    if msg_queue is not None:
                        # reply to a request, put the message in the proper queue
                        await msg_queue.put(msg)

                    else:
                        # multicast message, apply callbacks for this message type
                        callbacks = self.callbacks.get(msg['header']['type'], [])
                        futs = [callback(msg) for callback in callbacks]
                        res = await asyncio.gather(*futs, return_exceptions=True)
                        for i, result in enumerate(res):
                            if not isinstance(result, Exception):
                                continue
                            callback = callbacks[i]
                            self.log.error('recv_loop: error in callback %s',
                                           callback, exc_info=result)

            except asyncio.CancelledError:
                self.log.debug('recv_loop: stopping')
                break

            except Exception as e:
                # notify everyone of the error
                futs = [q.put(e) for q in tuple(self.msg_queue.values())]
                await asyncio.gather(*futs, return_exceptions=True)

                if self.error_handler is not None:
                    await self.error_handler(self, e)
                else:
                    self.log.exception('recv_loop: unhandled error')

    # mapping of message types to message classes
    # must be specified by subclasses
    msg_types = {}
    msg_type_offset = 4
    msg_type_format = 'H'

    def parse(self, data):
        """
        Parse netlink messages from a byte array.

        :returns:
            An iterator yielding :cls:`.nlmsg` objects.
        """
        offset = 0
        # there must be at least one header in the buffer,
        # 'IHHII' == 16 bytes
        end = len(data) - 16
        while offset <= end:
            # pick type and length
            length, = struct.unpack_from('I', data, offset)
            if length == 0:
                break
            error = None
            msg_type, = struct.unpack_from(self.msg_type_format, data,
                                           offset + self.msg_type_offset)
            if msg_type == NLMSG_ERROR:
                code, = struct.unpack_from('i', data, offset + 16)
                if code != 0:
                    error = NetlinkError(abs(code))

            msg_class = self.msg_types.get(msg_type, nlmsg)
            msg = msg_class(data, offset=offset)

            try:
                msg.decode()
                msg['header']['error'] = error
                # try to decode encapsulated error message
                if error is not None:
                    enc_type, = struct.unpack_from('H', data, offset + 24)
                    enc_class = self.msg_types.get(enc_type, nlmsg)
                    enc = enc_class(data, offset=offset + 20)
                    enc.decode()
                    msg['header']['errmsg'] = enc
            except NetlinkHeaderDecodeError as e:
                # in the case of header decoding error,
                # create an empty message
                msg = nlmsg()
                msg['header']['error'] = e
            except NetlinkDecodeError as e:
                msg['header']['error'] = e

            yield msg

            offset += msg.length

    def next_sequence_number(self):
        # To avoid confusion with multicast messages. The request sequence
        # numbers must NOT be zero. It is encoded in a uint32 field in the
        # messages, when it overflows, wrap it to 1.
        self.seq += 1
        if self.seq > 0xffffffff:
            self.seq = 1
        return self.seq

    async def sendall(self, buf):
        """
        Non blocking socket.sendall.
        """
        loop = asyncio.get_event_loop()
        return await loop.sock_sendall(self.sock, buf)

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()
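
To make the address arithmetic in connect() easier to follow, here is a small standalone sketch of how the bind address is composed: the low 22 bits carry the process id and the high 10 bits carry one of the 1024 per-process ports. The helper names are made up for illustration and are not part of the class above.

```python
import os
from typing import Optional, Tuple

PID_MASK = 0x3fffff  # low 22 bits: process id
PORT_SHIFT = 22      # high 10 bits: per-process port index (0..1023)


def make_nl_pid(port: int, process_id: Optional[int] = None) -> int:
    # Same formula as in connect(): masked process id plus shifted port.
    if process_id is None:
        process_id = os.getpid()
    return (process_id & PID_MASK) + (port << PORT_SHIFT)


def split_nl_pid(nl_pid: int) -> Tuple[int, int]:
    # Inverse operation: recover (masked process id, port).
    return nl_pid & PID_MASK, nl_pid >> PORT_SHIFT


if __name__ == "__main__":
    print(hex(make_nl_pid(3, 1234)), split_nl_pid(make_nl_pid(3, 1234)))
```

Since 10 + 22 = 32 bits, the result always fits the 32-bit pid field of the netlink address.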

rjarry commented 6 years ago

btw, I took inspiration from libnl (https://github.com/thom311/libnl/blob/master/lib/socket.c#L66) for the connect() method.

celebdor commented 6 years ago

I agree with 3.5

asteven commented 5 years ago

If you could refactor the protocol parts out into a sans-io protocol library, then it could be used by anybody in Python land.

I have no idea how much work that would be or if it's possible at all without rewriting everything. Just thought I'd point it out.
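
To illustrate what "sans-io" means here, a minimal sketch: a parser that is fed bytes and returns decoded netlink headers, without ever touching a socket. It only handles the 16-byte `IHHII` header mentioned in the code above; `HeaderParser` is a hypothetical name, not something pyroute2 provides.

```python
import struct

# length, type, flags, sequence number, pid
NLMSG_HDR = struct.Struct("IHHII")


class HeaderParser:
    """Accepts raw bytes and returns parsed headers; performs no I/O."""

    def __init__(self) -> None:
        self._buf = b""

    def feed(self, data: bytes) -> list:
        self._buf += data
        headers = []
        while len(self._buf) >= NLMSG_HDR.size:
            length, msg_type, flags, seq, pid = NLMSG_HDR.unpack_from(self._buf)
            if length < NLMSG_HDR.size:
                raise ValueError("invalid netlink message length: %d" % length)
            if len(self._buf) < length:
                break  # incomplete message, wait for more data
            headers.append({"length": length, "type": msg_type, "flags": flags,
                            "sequence_number": seq, "pid": pid})
            # Netlink messages are padded to 4-byte boundaries.
            self._buf = self._buf[(length + 3) & ~3:]
        return headers
```

The same object could then be driven by a blocking recv() loop, an asyncio task, or trio, since the caller owns all of the I/O.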

Personally, I would love it if pyroute2 worked with trio.

smurfix commented 3 years ago

I have written a library which accomplishes this semi-transparently.

https://pypi.org/project/aevent/

aevent currently works with Trio instead of asyncio, but (a) that is fixable and (b) trio is the better choice anyway. ;-)

Caveat 2: Version 0.0.6 works for me and passes a lot of pyroute2's test cases, but it's far from finished. Help very welcome.

h-khalili commented 4 months ago

Hi @svinota, any updates on whether this is being implemented at all?

svinota commented 3 months ago

An asyncio-compatible branch has been started (again :) )

I haven't lost hope of fixing it one day, since I want to get rid of the lock-based code in the socket wrapper.

rjarry commented 3 months ago

Hi folks, for the record, it should be noted that Netlink sockets are a bit different from standard TCP sockets. asyncio code expects that a non-blocking socket will never block.

Due to the way Netlink works, the sendmsg() syscall may not return EWOULDBLOCK (the socket buffer is not full) but may block anyway, even if the socket was set to non-blocking. The kernel will not return to user space until the operation is complete. For GET operations, this is instantaneous, but for DEL, NEW and SET, it may take a while.

There is no way around this since the return code of the system call needs to report the operation status. I got bitten by this against my will :)

The workaround I found was to call sendmsg() in a thread pool to avoid blocking the main thread.

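import asyncio
import socket

async def sendmsg(sock: socket.socket, buf: bytes):
    loop = asyncio.get_event_loop()
    while True:
        try:
            # sendmsg() may block inside the kernel, so run it in the
            # default thread pool instead of on the event loop thread.
            return await loop.run_in_executor(None, sock.sendmsg, [buf])
        except (BlockingIOError, InterruptedError):
            # The send could not complete: wait until the socket is
            # writable again, then retry.
            writeable = asyncio.Event()
            loop.add_writer(sock.fileno(), writeable.set)
            try:
                await writeable.wait()
            finally:
                loop.remove_writer(sock.fileno())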

svinota commented 3 weeks ago

Finally, I can announce that the project core is migrating to asyncio right now.

The branch

https://github.com/svinota/pyroute2/tree/thread-unsafe

The status

will be updated here:

NetlinkSocket and async API

New init parameter: loop

New methods: async_get(); planned: async_put()

The legacy sync method get() now evaluates the result and returns a tuple. This is to be changed: it will return a generator again.
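
The status note does not show how the sync get() is layered over the async one, so the following is only a sketch of one possible shape, under stated assumptions: `async_get()` here is a stand-in async generator yielding fake messages, and `get()` is a sync generator that drives a given event loop one item at a time. Neither body is taken from pyroute2.

```python
import asyncio
from typing import AsyncIterator, Iterator


async def async_get() -> AsyncIterator[int]:
    # Hypothetical stand-in for an async receive method; yields fake messages.
    for seq in range(3):
        await asyncio.sleep(0)
        yield seq


async def _step(agen):
    # Fetch one item, reporting exhaustion as a flag instead of an exception.
    try:
        return True, await agen.__anext__()
    except StopAsyncIteration:
        return False, None


def get(loop: asyncio.AbstractEventLoop) -> Iterator[int]:
    # Sync generator: each next() runs the loop until one message is ready.
    agen = async_get()
    try:
        while True:
            ok, msg = loop.run_until_complete(_step(agen))
            if not ok:
                return
            yield msg
    finally:
        loop.run_until_complete(agen.aclose())
```

Returning a generator rather than a tuple keeps the sync path lazy: a caller that stops iterating early never pays for the remaining messages.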

NetNS

The child proxy process is removed completely. The netns-bound socket is sent back to the parent process. Thus the netns-bound sockets API becomes generic to any socket that inherits from CoreSocket. NetNS is preserved only for compatibility; it is not required anymore.
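
The note does not say how the socket travels between processes. One common mechanism on Unix is passing the file descriptor over a Unix socket (SCM_RIGHTS), which Python 3.9+ exposes as `socket.send_fds()` and `socket.recv_fds()`. The sketch below assumes that mechanism and, to stay runnable without privileges, passes a UDP socket within a single process instead of a netlink socket from a child that has entered a namespace. The helper names are hypothetical.

```python
import socket


def send_socket(channel: socket.socket, sock: socket.socket) -> None:
    # Send one byte of regular data alongside the descriptor.
    socket.send_fds(channel, [b"\0"], [sock.fileno()])


def recv_socket(channel: socket.socket) -> socket.socket:
    _data, fds, _flags, _addr = socket.recv_fds(channel, 1, 1)
    # Family, type and proto are auto-detected from the descriptor.
    return socket.socket(fileno=fds[0])


def demo() -> bool:
    parent_end, child_end = socket.socketpair()
    # Stand-in for the socket a child would create inside the namespace.
    original = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    original.bind(("127.0.0.1", 0))
    try:
        send_socket(child_end, original)
        received = recv_socket(parent_end)
        try:
            # Both objects refer to the same underlying kernel socket.
            return received.getsockname() == original.getsockname()
        finally:
            received.close()
    finally:
        original.close()
        parent_end.close()
        child_end.close()


if __name__ == "__main__":
    print(demo())
```

A socket keeps the network namespace it was created in, so once the descriptor has been received the parent can use it without any proxy process.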

NDB

Is being rewritten, much work, so tough.

Deprecated & dropped

9p2000 support

Added to the core, all internal IPC will migrate to 9p2000