hyperium / tonic

A native gRPC client & server implementation with async/await support.
https://docs.rs/tonic
MIT License

How to implement an async read/write incoming stream with tokio named pipe? #1518

Open leiless opened 1 year ago

leiless commented 1 year ago

Hi all.

I'd like to know how to implement an incoming stream with async read/write on top of tokio::net::windows::named_pipe.

I'm new to tokio and tonic, and my use case requires Windows named pipes as the gRPC transport.

I know the incoming stream has to implement the futures::stream::Stream trait (i.e. its poll_next method), but I have no idea how to write working code for it.

Does anyone have any hints?

leiless commented 1 year ago

Hi, guys. I've implemented a Windows AF_UNIX (Unix domain socket) listener. However, to bring UDS to Windows, I had to modify the tokio and mio code to export some internal structs/functions.

Here is my Windows UDS implementation:

uds_listener.rs

use std::fmt::Formatter;
use std::path::Path;

use mio::{Interest, Registry, Token};
use windows::Win32::Networking::WinSock;

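pub struct SocketWrapper(pub WinSock::SOCKET);

// Close the socket when the wrapper (and thus the mio/tokio source owning it) is dropped;
// without this, every listener and accepted stream leaks its SOCKET handle.
impl Drop for SocketWrapper {
    fn drop(&mut self) {
        unsafe {
            WinSock::closesocket(self.0);
        }
    }
}

impl std::os::windows::io::AsRawSocket for SocketWrapper {
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        self.0.0 as _
    }
}

impl std::os::windows::io::IntoRawSocket for SocketWrapper {
    fn into_raw_socket(self) -> std::os::windows::io::RawSocket {
        let raw = self.0.0 as _;
        // Ownership moves to the caller, so skip our Drop (which would close the socket).
        std::mem::forget(self);
        raw
    }
}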

impl std::os::windows::io::FromRawSocket for SocketWrapper {
    unsafe fn from_raw_socket(sock: std::os::windows::io::RawSocket) -> Self {
        Self(WinSock::SOCKET(sock as _))
    }
}

// Implementation references
//  https://doc.rust-lang.org/std/os/unix/net/struct.UnixListener.html
//  https://github.com/tokio-rs/mio/blob/master/src/sys/unix/uds/listener.rs
pub struct UnixListener {
    io: mio::io_source::IoSource<SocketWrapper>,
}

impl UnixListener {
    #[inline(always)]
    // https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/ms632663(v=vs.85)
    fn make_word(lo: u8, hi: u8) -> u16 {
        lo as u16 | ((hi as u16) << 8)
    }

    #[inline(always)]
    fn wsa_get_last_error() -> i32 {
        let last_error = unsafe { WinSock::WSAGetLastError().0 };
        assert_ne!(last_error, 0);
        last_error
    }

    #[inline]
    fn socket_path_to_sun_path(socket_path: &str) -> [u8; 108] {
        let slice = socket_path.as_bytes();

        let mut array = [0; 108];
        let length = socket_path.len();
        array[..length].copy_from_slice(&slice[..length]);
        array[length] = 0; // Mark the end
        array
    }

    pub fn bind<P: AsRef<Path>>(path: P, backlog: i32) -> anyhow::Result<Self> {
        unsafe {
            let wsa_ver = Self::make_word(2, 2);
            let mut wsa_data = WinSock::WSADATA::default();
            let rc = WinSock::WSAStartup(
                wsa_ver,
                std::ptr::addr_of_mut!(wsa_data),
            );
            if rc != 0 {
                // WSAStartup() returns the error code directly; WSAGetLastError() is not usable here.
                return Err(anyhow::anyhow!("WSAStartup() error {}", rc));
            }
            assert_eq!(wsa_data.wVersion, wsa_ver);

            let socket_fd = WinSock::socket(
                WinSock::AF_UNIX as _,
                WinSock::SOCK_STREAM,
                0, // AF_UNIX sockets take protocol 0; IPPROTO_TCP belongs to AF_INET/AF_INET6
            );
            if socket_fd == WinSock::INVALID_SOCKET {
                return Err(anyhow::anyhow!("socket() error {}", Self::wsa_get_last_error()));
            }

            let socket_path = path.as_ref().as_os_str().to_str().ok_or_else(
                || anyhow::anyhow!("{:?} is not a valid UTF-8 path", path.as_ref())
            )?;
            if socket_path.len() >= 108 {
                return Err(anyhow::anyhow!("socket path length {} is too long", socket_path.len()));
            }
            let sun_path = Self::socket_path_to_sun_path(socket_path);
            let socket_addr = WinSock::SOCKADDR_UN {
                sun_family: WinSock::ADDRESS_FAMILY(WinSock::AF_UNIX),
                sun_path,
            };

            let rc = WinSock::bind(
                socket_fd,
                std::ptr::addr_of!(socket_addr) as _,
                std::mem::size_of::<WinSock::SOCKADDR_UN>() as _,
            );
            if rc == WinSock::SOCKET_ERROR {
                return Err(anyhow::anyhow!("bind() error {}", Self::wsa_get_last_error()));
            }

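            let rc = WinSock::listen(socket_fd, backlog);
            if rc == WinSock::SOCKET_ERROR {
                return Err(anyhow::anyhow!("listen() error {}", Self::wsa_get_last_error()));
            }

            // The listening socket must be non-blocking too: mio/tokio only report readiness,
            // and accept() must fail with WSAEWOULDBLOCK (instead of blocking the runtime
            // thread) when no client is pending.
            let mut i_mode = 1;
            let rc = WinSock::ioctlsocket(
                socket_fd,
                WinSock::FIONBIO,
                std::ptr::addr_of_mut!(i_mode),
            );
            if rc == WinSock::SOCKET_ERROR {
                return Err(anyhow::anyhow!("ioctlsocket(FIONBIO) error {}", Self::wsa_get_last_error()));
            }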

            Ok(Self {
                io: mio::io_source::IoSource::new(SocketWrapper(socket_fd))
            })
        }
    }

    #[inline(always)]
    fn to_socket(&self) -> WinSock::SOCKET {
        WinSock::SOCKET(self.io.0.0)
    }

    pub fn accept(&self) -> std::io::Result<(crate::uds_stream::UnixStream, WinSock::SOCKADDR_UN)> {
        unsafe {
            let mut addr = WinSock::SOCKADDR_UN::default();
            let mut addr_len = std::mem::size_of::<WinSock::SOCKADDR_UN>() as i32;
            let client_fd = WinSock::accept(
                self.to_socket(),
                Some(std::ptr::addr_of_mut!(addr) as _),
                Some(std::ptr::addr_of_mut!(addr_len)),
            );
            if client_fd == WinSock::INVALID_SOCKET {
                return Err(std::io::Error::from_raw_os_error(Self::wsa_get_last_error()));
            }

            // Make the client socket non-blocking
            let mut i_mode = 1;
            let rc = WinSock::ioctlsocket(
                client_fd,
                WinSock::FIONBIO,
                std::ptr::addr_of_mut!(i_mode),
            );
            if rc == WinSock::SOCKET_ERROR {
                let last_error_code = Self::wsa_get_last_error();

                let rc = WinSock::closesocket(client_fd);
                if rc == WinSock::SOCKET_ERROR {
                    eprintln!("closesocket() {} error: {}", client_fd.0, Self::wsa_get_last_error());
                }

                return Err(std::io::Error::from_raw_os_error(last_error_code));
            }

            Ok((crate::uds_stream::UnixStream::from_socket(client_fd)?, addr))
        }
    }

    pub fn local_addr(&self) -> anyhow::Result<WinSock::SOCKADDR_UN> {
        let mut addr = WinSock::SOCKADDR_UN::default();
        let mut addr_len = std::mem::size_of::<WinSock::SOCKADDR_UN>() as i32;
        unsafe {
            let rc = WinSock::getsockname(
                self.to_socket(),
                std::ptr::addr_of_mut!(addr) as _,
                std::ptr::addr_of_mut!(addr_len),
            );
            if rc == WinSock::SOCKET_ERROR {
                return Err(anyhow::anyhow!("getsockname() error {}", Self::wsa_get_last_error()));
            }
        }
        Ok(addr)
    }
}

impl mio::event::Source for UnixListener {
    fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> {
        self.io.register(registry, token, interests)
    }

    fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> {
        self.io.reregister(registry, token, interests)
    }

    fn deregister(&mut self, registry: &Registry) -> std::io::Result<()> {
        self.io.deregister(registry)
    }
}

impl std::fmt::Debug for UnixListener {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("UnixListener");
        builder.field("fd", &self.io.0);
        if let Ok(addr) = self.local_addr() {
            builder.field("local", &addr);
        }
        builder.finish()
    }
}
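The length check in `bind` and the copy in `socket_path_to_sun_path` can be folded into one function that also rejects interior NUL bytes (which would silently truncate the path). A std-only sketch, assuming the 108-byte `sun_path` used above; `to_sun_path` and `UNIX_PATH_MAX` are illustrative names, not part of any crate:

```rust
/// Size of `SOCKADDR_UN::sun_path` as used above (108 bytes).
/// Illustrative constant, assumed to match the Windows definition.
const UNIX_PATH_MAX: usize = 108;

/// Copies `path` into a zero-filled, NUL-terminated `sun_path` buffer,
/// rejecting paths that contain NUL or leave no room for the terminator.
fn to_sun_path(path: &str) -> Result<[u8; UNIX_PATH_MAX], String> {
    let bytes = path.as_bytes();
    if bytes.contains(&0) {
        return Err(format!("{path:?} contains an interior NUL byte"));
    }
    if bytes.len() >= UNIX_PATH_MAX {
        return Err(format!(
            "socket path length {} is too long (max {})",
            bytes.len(),
            UNIX_PATH_MAX - 1
        ));
    }
    let mut sun_path = [0u8; UNIX_PATH_MAX];
    sun_path[..bytes.len()].copy_from_slice(bytes);
    // sun_path[bytes.len()] is already 0, which terminates the string.
    Ok(sun_path)
}
```

Returning an error instead of indexing past the buffer keeps the bounds check next to the copy, so the helper cannot be called unsafely from elsewhere.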

uds_stream.rs

use std::io::{Error, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};
use mio::{Interest, Registry, Token};
use tokio::io::ReadBuf;
use windows::Win32::Networking::WinSock;

struct IoSourcedSocket {
    io: mio::io_source::IoSource<crate::uds_listener::SocketWrapper>,
}

#[inline(always)]
fn wsa_get_last_error() -> i32 {
    let last_error = unsafe { WinSock::WSAGetLastError().0 };
    assert_ne!(last_error, 0);
    last_error
}

impl IoSourcedSocket {
    fn from_socket(socket: WinSock::SOCKET) -> Self {
        Self {
            io: mio::io_source::IoSource::new(
                crate::uds_listener::SocketWrapper(socket)
            ),
        }
    }
}

impl std::io::Read for IoSourcedSocket {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        <&Self as std::io::Read>::read(&mut &*self, buf)
    }
}

impl std::io::Write for IoSourcedSocket {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        <&Self as std::io::Write>::write(&mut &*self, buf)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        <&Self as std::io::Write>::flush(&mut &*self)
    }
}

// https://github.com/tokio-rs/mio/blob/master/src/sys/windows/named_pipe.rs#L510
impl<'a> std::io::Read for &'a IoSourcedSocket {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        unsafe {
            let rc = WinSock::recv(self.io.0, buf, WinSock::SEND_RECV_FLAGS(0));
            if rc == WinSock::SOCKET_ERROR {
                Err(Error::from_raw_os_error(wsa_get_last_error()))
            } else {
                Ok(rc as _)
            }
        }
    }
}

// https://github.com/tokio-rs/mio/blob/master/src/sys/windows/named_pipe.rs#L561
impl<'a> std::io::Write for &'a IoSourcedSocket {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        unsafe {
            let rc = WinSock::send(self.io.0, buf, WinSock::SEND_RECV_FLAGS(0));
            if rc == WinSock::SOCKET_ERROR {
                Err(Error::from_raw_os_error(wsa_get_last_error()))
            } else {
                Ok(rc as _)
            }
        }
    }

    fn flush(&mut self) -> std::io::Result<()> {
        // send() hands data straight to Winsock; this wrapper keeps no user-space buffer to flush
        Ok(())
    }
}

impl mio::event::Source for IoSourcedSocket {
    fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> {
        self.io.register(registry, token, interests)
    }

    fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> {
        self.io.reregister(registry, token, interests)
    }

    fn deregister(&mut self, registry: &Registry) -> std::io::Result<()> {
        self.io.deregister(registry)
    }
}

pub struct UnixStream {
    io: tokio::io::poll_evented::PollEvented<IoSourcedSocket>,
}

impl UnixStream {
    pub fn from_socket(socket: WinSock::SOCKET) -> std::io::Result<Self> {
        Ok(Self {
            io: tokio::io::poll_evented::PollEvented::new(
                IoSourcedSocket::from_socket(socket)
            )?,
        })
    }

    fn local_addr(&self) -> anyhow::Result<WinSock::SOCKADDR_UN> {
        let mut addr = WinSock::SOCKADDR_UN::default();
        let mut addr_len = std::mem::size_of::<WinSock::SOCKADDR_UN>() as i32;
        unsafe {
            let rc = WinSock::getsockname(
                self.io.io.0,
                std::ptr::addr_of_mut!(addr) as _,
                std::ptr::addr_of_mut!(addr_len),
            );
            if rc == WinSock::SOCKET_ERROR {
                return Err(anyhow::anyhow!("getsockname() error {}", wsa_get_last_error()));
            }
        }
        Ok(addr)
    }
}

impl tonic::transport::server::Connected for UnixStream {
    type ConnectInfo = Option<WinSock::SOCKADDR_UN>;

    fn connect_info(&self) -> Self::ConnectInfo {
        // Reports the socket's local (bound) address; `None` if getsockname() fails.
        self.local_addr().ok()
    }
}

impl tokio::io::AsyncRead for UnixStream {
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
        unsafe { self.io.poll_read(cx, buf) }
    }
}

impl tokio::io::AsyncWrite for UnixStream {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
        self.io.poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize, Error>> {
        self.io.poll_write_vectored(cx, bufs)
    }
}

uds.rs

use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use windows::Win32::Networking::WinSock;

pub struct UnixListener {
    io: tokio::io::poll_evented::PollEvented<crate::uds_listener::UnixListener>,
}

impl UnixListener {
    pub fn bind<P: AsRef<Path>>(path: P, backlog: i32) -> anyhow::Result<Self> {
        let listener = crate::uds_listener::UnixListener::bind(path, backlog)?;
        let io = tokio::io::poll_evented::PollEvented::new(listener)?;
        Ok(Self {
            io
        })
    }

    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<std::io::Result<(crate::uds_stream::UnixStream, WinSock::SOCKADDR_UN)>> {
        self.io.registration().poll_read_io(cx, || self.io.accept())
    }
}

pub struct UnixListenerStream {
    inner: UnixListener,
}

impl UnixListenerStream {
    pub fn new(listener: UnixListener) -> Self {
        Self { inner: listener }
    }

    pub fn into_inner(self) -> UnixListener {
        self.inner
    }
}

impl futures::stream::Stream for UnixListenerStream {
    type Item = std::io::Result<crate::uds_stream::UnixStream>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // A listener never ends, so this never yields `None`; accept errors are passed through.
        self.inner.poll_accept(cx).map(|res| Some(res.map(|(stream, _sock_addr)| stream)))
    }
}
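The `poll_next` above is the adapter the first comment asked about: it forwards `Pending` from the accept call and wraps each accepted connection in `Some(..)`. A std-only model of the same pattern, assuming a local `Stream` trait that mirrors the shape of `futures::stream::Stream::poll_next` (so it runs without external crates) and a hypothetical `FakeListener` in place of the real socket:

```rust
use std::pin::Pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

/// Local stand-in mirroring `futures::stream::Stream` (hypothetical here).
trait Stream {
    type Item;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
}

/// Hypothetical listener: returns `Pending` on every other poll (as if no
/// client were queued yet) and otherwise yields the next connection id.
struct FakeListener {
    next_id: u32,
    ready: bool,
}

impl FakeListener {
    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<std::io::Result<(u32, &'static str)>> {
        if !self.ready {
            self.ready = true;
            // A real listener relies on the reactor to wake the task once the
            // socket is readable; waking here just lets the caller poll again.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        self.ready = false;
        self.next_id += 1;
        Poll::Ready(Ok((self.next_id, "peer-addr")))
    }
}

struct ListenerStream {
    inner: FakeListener,
}

impl Stream for ListenerStream {
    type Item = std::io::Result<u32>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Pending` propagates; a finished accept (Ok or Err) becomes `Some(..)`.
        // The stream never returns `None`, since a listener has no natural end.
        self.get_mut()
            .inner
            .poll_accept(cx)
            .map(|res| Some(res.map(|(conn, _addr)| conn)))
    }
}

/// A waker that does nothing, for driving `poll_next` by hand.
fn noop_waker() -> Waker {
    fn clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    // Safe: every vtable function ignores the (null) data pointer.
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}
```

Driving it by hand alternates `Pending` and `Ready(Some(Ok(n)))`, which is the same contract tonic's server loop expects from the real `UnixListenerStream`.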

The code does not compile because of what seem to be tonic/tokio version incompatibilities: for example, rustc complains that AsyncRead/AsyncWrite are not implemented for UnixStream even though they are, and I don't know why.

Can anyone help with this?

How can I use a custom tokio/mio (with some internal structs/functions exported) together with tonic to get this working? The main difficulty is that tonic depends on tokio; I replaced tonic's tokio dependency with my own tokio crate, but it still doesn't compile.
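An error of the form "trait is not implemented, although it is" most likely means two copies of tokio are in the build: to rustc, traits from different crate copies are distinct, even with identical names and paths. A std-only analogy, where the modules `tokio_crates_io` and `tokio_fork` are hypothetical stand-ins for those two copies:

```rust
// Hypothetical stand-ins for two copies of tokio in one build, e.g. the
// crates.io release that tonic depends on and a locally patched fork.
mod tokio_crates_io {
    pub trait AsyncRead {
        fn name(&self) -> &'static str;
    }
}

mod tokio_fork {
    pub trait AsyncRead {
        fn name(&self) -> &'static str;
    }
}

/// Written against the fork, like the custom `UnixStream` above.
pub struct UnixStream;

impl tokio_fork::AsyncRead for UnixStream {
    fn name(&self) -> &'static str {
        "fork"
    }
}

/// Written against the same copy that the "tonic" function below uses.
pub struct SameCopyStream;

impl tokio_crates_io::AsyncRead for SameCopyStream {
    fn name(&self) -> &'static str {
        "crates.io"
    }
}

/// Plays the role of a tonic API whose bound names its own tokio's trait.
fn requires_crates_io_trait<T: tokio_crates_io::AsyncRead>(io: &T) -> &'static str {
    io.name()
}

// `requires_crates_io_trait(&UnixStream)` is rejected with error E0277
// (trait bound not satisfied): the two traits only share a name, so
// implementing one says nothing about the other.
```

If that is the cause, `cargo tree -d` should list tokio (and possibly mio) twice. The usual remedy is Cargo's `[patch.crates-io]` section, so that tonic, hyper, and your own crate all resolve to the single patched copy; the fork has to keep a version number that satisfies tonic's requirement for the patch to apply.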

pronebird commented 9 months ago

I'd like to know how to wrap a named pipe into a stream too. For UDS there is UnixListenerStream, but I don't see anything equivalent for named pipes.

Chaoses-Ib commented 5 months ago

In https://github.com/tokio-rs/tokio/issues/6591#issuecomment-2134427618, a tokio maintainer suggested using the async-stream crate:

Frankly, the easiest way here is that you just use the async-stream crate to create the stream, rather than add something more complex to Tokio.

Unfortunately, I can't find any implementation of this.

catalinsh commented 1 month ago

I made a working example here.