Open leiless opened 1 year ago
Hi all. I've implemented an AF_UNIX (Unix domain socket, UDS) listener and stream for Windows. To get it working, I had to modify tokio and mio to export some internal structs and functions.
Here is my Windows UDS implementation:
uds_listener.rs
use std::fmt::Formatter;
use std::path::Path;
use mio::{Interest, Registry, Token};
use windows::Win32::Networking::WinSock;
pub struct SocketWrapper(pub WinSock::SOCKET);
impl std::os::windows::io::AsRawSocket for SocketWrapper {
fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
self.0.0 as _
}
}
impl std::os::windows::io::IntoRawSocket for SocketWrapper {
fn into_raw_socket(self) -> std::os::windows::io::RawSocket {
self.0.0 as _
}
}
impl std::os::windows::io::FromRawSocket for SocketWrapper {
unsafe fn from_raw_socket(sock: std::os::windows::io::RawSocket) -> Self {
Self(WinSock::SOCKET(sock as _))
}
}
// Implementation references
// https://doc.rust-lang.org/std/os/unix/net/struct.UnixListener.html
// https://github.com/tokio-rs/mio/blob/master/src/sys/unix/uds/listener.rs
pub struct UnixListener {
io: mio::io_source::IoSource<SocketWrapper>,
}
impl UnixListener {
// Equivalent of the Win32 MAKEWORD macro:
// https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/ms632663(v=vs.85)
#[inline(always)]
fn make_word(lo: u8, hi: u8) -> u16 {
lo as u16 | ((hi as u16) << 8)
}
#[inline(always)]
fn wsa_get_last_error() -> i32 {
let last_error = unsafe { WinSock::WSAGetLastError().0 };
assert_ne!(last_error, 0);
last_error
}
#[inline]
fn socket_path_to_sun_path(socket_path: &str) -> [u8; 108] {
let slice = socket_path.as_bytes();
let mut array = [0; 108];
let length = socket_path.len();
array[..length].copy_from_slice(&slice[..length]);
array[length] = 0; // Mark the end
array
}
pub fn bind<P: AsRef<Path>>(path: P, backlog: i32) -> anyhow::Result<Self> {
unsafe {
let wsa_ver = Self::make_word(2, 2);
let mut wsa_data = WinSock::WSADATA::default();
let rc = WinSock::WSAStartup(
wsa_ver,
std::ptr::addr_of_mut!(wsa_data),
);
if rc != 0 {
// WSAStartup() returns its error code directly; WSAGetLastError() is not valid when it fails
return Err(anyhow::anyhow!("WSAStartup() error {}", rc));
}
assert_eq!(wsa_data.wVersion, wsa_ver);
let socket_fd = WinSock::socket(
WinSock::AF_UNIX as _,
WinSock::SOCK_STREAM,
0, // AF_UNIX sockets have no TCP layer; 0 selects the default protocol
);
if socket_fd == WinSock::INVALID_SOCKET {
return Err(anyhow::anyhow!("socket() error {}", Self::wsa_get_last_error()));
}
let socket_path = path.as_ref().as_os_str().to_str().ok_or_else(
|| anyhow::anyhow!("{:?} is not a valid UTF-8 path", path.as_ref())
)?;
if socket_path.len() >= 108 {
return Err(anyhow::anyhow!("socket path length {} is too long", socket_path.len()));
}
let sun_path = Self::socket_path_to_sun_path(socket_path);
let socket_addr = WinSock::SOCKADDR_UN {
sun_family: WinSock::ADDRESS_FAMILY(WinSock::AF_UNIX),
sun_path,
};
let rc = WinSock::bind(
socket_fd,
std::ptr::addr_of!(socket_addr) as _,
std::mem::size_of::<WinSock::SOCKADDR_UN>() as _,
);
if rc == WinSock::SOCKET_ERROR {
return Err(anyhow::anyhow!("bind() error {}", Self::wsa_get_last_error()));
}
let rc = WinSock::listen(socket_fd, backlog);
if rc == WinSock::SOCKET_ERROR {
return Err(anyhow::anyhow!("listen() error {}", Self::wsa_get_last_error()));
}
Ok(Self {
io: mio::io_source::IoSource::new(SocketWrapper(socket_fd))
})
}
}
#[inline(always)]
fn to_socket(&self) -> WinSock::SOCKET {
WinSock::SOCKET(self.io.0.0)
}
pub fn accept(&self) -> std::io::Result<(crate::uds_stream::UnixStream, WinSock::SOCKADDR_UN)> {
unsafe {
let mut addr = WinSock::SOCKADDR_UN::default();
let mut addr_len = std::mem::size_of::<WinSock::SOCKADDR_UN>() as i32;
let client_fd = WinSock::accept(
self.to_socket(),
Some(std::ptr::addr_of_mut!(addr) as _),
Some(std::ptr::addr_of_mut!(addr_len)),
);
if client_fd == WinSock::INVALID_SOCKET {
return Err(std::io::Error::from_raw_os_error(Self::wsa_get_last_error()));
}
// Put the accepted client socket into non-blocking mode
let mut i_mode = 1;
let rc = WinSock::ioctlsocket(
client_fd,
WinSock::FIONBIO,
std::ptr::addr_of_mut!(i_mode),
);
if rc == WinSock::SOCKET_ERROR {
let last_error_code = Self::wsa_get_last_error();
let rc = WinSock::closesocket(client_fd);
if rc == WinSock::SOCKET_ERROR {
eprintln!("closesocket() {} error: {}", client_fd.0, Self::wsa_get_last_error());
}
return Err(std::io::Error::from_raw_os_error(last_error_code));
}
Ok((crate::uds_stream::UnixStream::from_socket(client_fd)?, addr))
}
}
pub fn local_addr(&self) -> anyhow::Result<WinSock::SOCKADDR_UN> {
let mut addr = WinSock::SOCKADDR_UN::default();
let mut addr_len = std::mem::size_of::<WinSock::SOCKADDR_UN>() as i32;
unsafe {
let rc = WinSock::getsockname(
self.to_socket(),
std::ptr::addr_of_mut!(addr) as _,
std::ptr::addr_of_mut!(addr_len),
);
if rc == WinSock::SOCKET_ERROR {
return Err(anyhow::anyhow!("getsockname() error {}", Self::wsa_get_last_error()));
}
}
Ok(addr)
}
}
impl mio::event::Source for UnixListener {
fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> {
self.io.register(registry, token, interests)
}
fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> {
self.io.reregister(registry, token, interests)
}
fn deregister(&mut self, registry: &Registry) -> std::io::Result<()> {
self.io.deregister(registry)
}
}
impl std::fmt::Debug for UnixListener {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let mut builder = f.debug_struct("UnixListener");
builder.field("fd", &self.io.0);
if let Ok(addr) = self.local_addr() {
builder.field("local", &addr);
}
builder.finish()
}
}
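The path handling in uds_listener.rs can be exercised without any Winsock calls. The following is a minimal sketch, not the author's code: it folds the length check from `bind` into the conversion helper and also rejects interior NUL bytes, which is an added assumption about what a valid `sun_path` should contain. `SUN_PATH_LEN` is a name introduced here for illustration.

```rust
/// Capacity of `sun_path`, matching the `[u8; 108]` array used in the listing.
const SUN_PATH_LEN: usize = 108;

/// Same packing as the Win32 MAKEWORD macro: low byte in bits 0..8, high byte in bits 8..16.
fn make_word(lo: u8, hi: u8) -> u16 {
    u16::from(lo) | (u16::from(hi) << 8)
}

/// Copies `socket_path` into a zero-filled buffer, leaving room for the NUL terminator.
fn socket_path_to_sun_path(socket_path: &str) -> Result<[u8; SUN_PATH_LEN], String> {
    let bytes = socket_path.as_bytes();
    // One byte must stay free so the path remains NUL-terminated.
    if bytes.len() >= SUN_PATH_LEN {
        return Err(format!("socket path length {} is too long", bytes.len()));
    }
    // Assumption: an interior NUL would silently truncate the path, so reject it.
    if bytes.contains(&0) {
        return Err("socket path contains an interior NUL byte".to_string());
    }
    let mut sun_path = [0u8; SUN_PATH_LEN];
    sun_path[..bytes.len()].copy_from_slice(bytes);
    Ok(sun_path)
}

fn main() {
    println!("WSA version word: {:#06x}", make_word(2, 2));
    match socket_path_to_sun_path("demo.sock") {
        Ok(sun_path) => println!("first bytes: {:?}", &sun_path[..10]),
        Err(err) => eprintln!("error: {err}"),
    }
}
```

Returning a `Result` from the helper keeps the bounds check next to the copy, so a caller cannot skip it and trigger a slice panic.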
uds_stream.rs
use std::io::{Error, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};
use mio::{Interest, Registry, Token};
use tokio::io::ReadBuf;
use windows::Win32::Networking::WinSock;
struct IoSourcedSocket {
io: mio::io_source::IoSource<crate::uds_listener::SocketWrapper>,
}
#[inline(always)]
fn wsa_get_last_error() -> i32 {
let last_error = unsafe { WinSock::WSAGetLastError().0 };
assert_ne!(last_error, 0);
last_error
}
impl IoSourcedSocket {
fn from_socket(socket: WinSock::SOCKET) -> Self {
Self {
io: mio::io_source::IoSource::new(
crate::uds_listener::SocketWrapper(socket)
),
}
}
}
impl std::io::Read for IoSourcedSocket {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
<&Self as std::io::Read>::read(&mut &*self, buf)
}
}
impl std::io::Write for IoSourcedSocket {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
<&Self as std::io::Write>::write(&mut &*self, buf)
}
fn flush(&mut self) -> std::io::Result<()> {
<&Self as std::io::Write>::flush(&mut &*self)
}
}
// https://github.com/tokio-rs/mio/blob/master/src/sys/windows/named_pipe.rs#L510
impl<'a> std::io::Read for &'a IoSourcedSocket {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
unsafe {
let rc = WinSock::recv(self.io.0, buf, WinSock::SEND_RECV_FLAGS(0));
if rc == WinSock::SOCKET_ERROR {
Err(Error::from_raw_os_error(wsa_get_last_error()))
} else {
Ok(rc as _)
}
}
}
}
// https://github.com/tokio-rs/mio/blob/master/src/sys/windows/named_pipe.rs#L561
impl<'a> std::io::Write for &'a IoSourcedSocket {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
unsafe {
let rc = WinSock::send(self.io.0, buf, WinSock::SEND_RECV_FLAGS(0));
if rc == WinSock::SOCKET_ERROR {
Err(Error::from_raw_os_error(wsa_get_last_error()))
} else {
Ok(rc as _)
}
}
}
fn flush(&mut self) -> std::io::Result<()> {
// Winsock 2 has no flush/sync operation for sockets
Ok(())
}
}
impl mio::event::Source for IoSourcedSocket {
fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> {
self.io.register(registry, token, interests)
}
fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> {
self.io.reregister(registry, token, interests)
}
fn deregister(&mut self, registry: &Registry) -> std::io::Result<()> {
self.io.deregister(registry)
}
}
pub struct UnixStream {
io: tokio::io::poll_evented::PollEvented<IoSourcedSocket>,
}
impl UnixStream {
pub fn from_socket(socket: WinSock::SOCKET) -> std::io::Result<Self> {
Ok(Self {
io: tokio::io::poll_evented::PollEvented::new(
IoSourcedSocket::from_socket(socket)
)?,
})
}
fn local_addr(&self) -> anyhow::Result<WinSock::SOCKADDR_UN> {
let mut addr = WinSock::SOCKADDR_UN::default();
let mut addr_len = std::mem::size_of::<WinSock::SOCKADDR_UN>() as i32;
unsafe {
let rc = WinSock::getsockname(
self.io.io.0,
std::ptr::addr_of_mut!(addr) as _,
std::ptr::addr_of_mut!(addr_len),
);
if rc == WinSock::SOCKET_ERROR {
return Err(anyhow::anyhow!("getsockname() error {}", wsa_get_last_error()));
}
}
Ok(addr)
}
}
impl tonic::transport::server::Connected for UnixStream {
type ConnectInfo = Option<WinSock::SOCKADDR_UN>;
fn connect_info(&self) -> Self::ConnectInfo {
// Report the socket address when it can be queried, otherwise None
self.local_addr().ok()
}
}
impl tokio::io::AsyncRead for UnixStream {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
unsafe { self.io.poll_read(cx, buf) }
}
}
impl tokio::io::AsyncWrite for UnixStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
self.io.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize, Error>> {
self.io.poll_write_vectored(cx, bufs)
}
}
uds.rs
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use windows::Win32::Networking::WinSock;
pub struct UnixListener {
io: tokio::io::poll_evented::PollEvented<crate::uds_listener::UnixListener>,
}
impl UnixListener {
pub fn bind<P: AsRef<Path>>(path: P, backlog: i32) -> anyhow::Result<Self> {
let listener = crate::uds_listener::UnixListener::bind(path, backlog)?;
let io = tokio::io::poll_evented::PollEvented::new(listener)?;
Ok(Self {
io
})
}
fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<std::io::Result<(crate::uds_stream::UnixStream, WinSock::SOCKADDR_UN)>> {
let (stream, addr) = futures::ready!(self.io.registration().poll_read_io(cx, || self.io.accept()))?;
Poll::Ready(Ok((stream, addr)))
}
}
pub struct UnixListenerStream {
inner: UnixListener,
}
impl UnixListenerStream {
pub fn new(listener: UnixListener) -> Self {
Self { inner: listener }
}
pub fn into_inner(self) -> UnixListener {
self.inner
}
}
impl futures::stream::Stream for UnixListenerStream {
type Item = std::io::Result<crate::uds_stream::UnixStream>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match futures::ready!(self.inner.poll_accept(cx)) {
Ok((stream, _sock_addr)) => {
Poll::Ready(Some(Ok(stream)))
}
Err(err) => {
Poll::Ready(Some(Err(err)))
}
}
}
}
The code does not compile because of what look like version incompatibilities between tonic and tokio (e.g. rustc complains that the AsyncRead/AsyncWrite traits are not implemented for UnixStream, even though they are), and I don't know why.
Can anyone help with this?
How can I use a custom tokio/mio (with some internal structs/functions exported) alongside tonic to get this working? The main difficulty is that tonic depends on tokio; I pointed tonic's tokio dependency at my own tokio crate, but it still does not compile.
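One common cause of "trait is not implemented" errors for a type that visibly implements the trait is that two copies of the same crate end up in the dependency graph: UnixStream implements one copy's AsyncRead while tonic requires the other copy's. Whether that is what happens here is an assumption; `cargo tree -d` lists duplicated packages and can confirm or rule it out. The sketch below reproduces the effect with two hypothetical modules, `tokio_a` and `tokio_b`, standing in for the two copies.

```rust
// Two modules stand in for two copies of the same crate in one dependency graph.
// The traits are textually identical, but the compiler treats them as unrelated.
mod tokio_a {
    pub trait AsyncRead {
        fn name(&self) -> &'static str;
    }
}

#[allow(dead_code)]
mod tokio_b {
    pub trait AsyncRead {
        fn name(&self) -> &'static str;
    }
}

struct UnixStream;

// The stream implements the trait from copy A only.
impl tokio_a::AsyncRead for UnixStream {
    fn name(&self) -> &'static str {
        "tokio_a::AsyncRead"
    }
}

// A consumer built against copy A accepts the stream.
fn serve_a<T: tokio_a::AsyncRead>(io: &T) -> &'static str {
    io.name()
}

// A consumer built against copy B (tonic, in this analogy) does not.
#[allow(dead_code)]
fn serve_b<T: tokio_b::AsyncRead>(io: &T) -> &'static str {
    io.name()
}

fn main() {
    println!("{}", serve_a(&UnixStream));
    // Uncommenting the next line fails with error[E0277]: the trait bound
    // `UnixStream: tokio_b::AsyncRead` is not satisfied.
    // serve_b(&UnixStream);
}
```

If duplication is the cause, a `[patch.crates-io]` entry in the workspace Cargo.toml is the usual way to make every dependent, tonic included, resolve tokio to the same modified copy.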
I'd like to know how to wrap a named pipe into a stream too. For UDS, UnixListenerStream exists, but I don't see anything equivalent for named pipes.
In https://github.com/tokio-rs/tokio/issues/6591#issuecomment-2134427618, a tokio maintainer suggested using the async-stream crate:
Frankly, the easiest way here is that you just use the async-stream crate to create the stream, rather than add something more complex to Tokio.
Unfortunately, I can't find any implementation of this.
Hi all.
I'd like to know how I could implement an incoming stream with async read/write for tokio::net::windows::named_pipe.
I'm new to tokio and tonic, and my use case requires me to use Windows named pipes as the gRPC backend transport.
I know the named pipe stream has to implement the futures::stream::Stream trait (i.e. the poll_next method), but I have no idea how to write working code. Do you have any hints for this?
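The shape of such a `poll_next` can be sketched with the standard library alone. Everything named here is a stand-in: the local `Stream` trait mirrors `futures::stream::Stream`, `PipeConn` plays the role of a connected pipe instance, and `connect_instance` is a hypothetical helper that, with tokio, would presumably create an instance via `ServerOptions` and await `NamedPipeServer::connect`. The idea is to keep one boxed "wait for a client" future in the stream, poll it, and replace it with a fresh one each time a client connects.

```rust
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Stand-in for `futures::stream::Stream` so the sketch builds with std only.
trait Stream {
    type Item;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
}

/// Hypothetical connected pipe instance, identified by a counter.
#[derive(Debug, PartialEq)]
struct PipeConn(u32);

type ConnectFuture = Pin<Box<dyn Future<Output = io::Result<PipeConn>>>>;

/// Hypothetical helper: create pipe instance `id` and wait for a client.
/// Here it completes immediately; a real one would await the OS.
fn connect_instance(id: u32) -> ConnectFuture {
    Box::pin(async move { Ok::<_, io::Error>(PipeConn(id)) })
}

/// Stream of incoming connections that always keeps one instance listening.
struct PipeIncoming {
    next_id: u32,
    pending: ConnectFuture,
}

impl PipeIncoming {
    fn new() -> Self {
        Self { next_id: 1, pending: connect_instance(0) }
    }
}

impl Stream for PipeIncoming {
    type Item = io::Result<PipeConn>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Poll the in-flight connect; stay pending until a client arrives.
        let result = match self.pending.as_mut().poll(cx) {
            Poll::Ready(result) => result,
            Poll::Pending => return Poll::Pending,
        };
        // Start waiting on a fresh instance before handing out the connected one.
        let id = self.next_id;
        self.next_id += 1;
        self.pending = connect_instance(id);
        Poll::Ready(Some(result))
    }
}

/// Waker that does nothing, enough to drive the sketch by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

fn main() {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut incoming = PipeIncoming::new();
    for _ in 0..3 {
        if let Poll::Ready(Some(Ok(conn))) = Pin::new(&mut incoming).poll_next(&mut cx) {
            println!("accepted {:?}", conn);
        }
    }
}
```

The stream never yields `None`, so a server built on it keeps accepting until it is dropped. The async-stream crate mentioned in the thread would generate an equivalent state machine from a loop with `yield`.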