quininer / tokio-rustls

Asynchronous TLS/SSL streams for Tokio using Rustls.

Add a bidirectional `TlsStream` wrapper #46

Closed: djc closed this issue 4 years ago

djc commented 4 years ago

For some projects, it is convenient to have a single type that abstracts over the client and server versions of `TlsStream`. It would also make adding TLS support easier: code written against `TcpStream` deals with one type for both roles, whereas this crate has two different `TlsStream` types.
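For comparison, plain TCP already has this property. The following std-only sketch is a hypothetical illustration, not part of the proposal (`loopback_pair` and `relay` are made-up helper names). It shows that the client end from `TcpStream::connect` and the server end from `TcpListener::accept` have the same concrete type, so one helper serves both roles. Having two TLS stream types breaks that symmetry.

```rust
use std::io::{self, Read, Write};
use std::net::{TcpListener, TcpStream};

/// Opens a loopback connection and returns (client end, server end).
/// `TcpStream::connect` and `TcpListener::accept` both yield a plain
/// `TcpStream`, so the two roles share a single concrete type.
fn loopback_pair() -> io::Result<(TcpStream, TcpStream)> {
    // Port 0 lets the OS pick any free port.
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let client = TcpStream::connect(listener.local_addr()?)?;
    let (server, _peer_addr) = listener.accept()?;
    Ok((client, server))
}

/// Role-agnostic helper: it does not know or care which side `from`
/// and `to` came from, because both are `TcpStream`.
fn relay(from: &mut TcpStream, to: &mut TcpStream, msg: &[u8]) -> io::Result<Vec<u8>> {
    from.write_all(msg)?;
    from.flush()?;
    let mut buf = vec![0u8; msg.len()];
    to.read_exact(&mut buf)?;
    Ok(buf)
}
```

Because both ends are `TcpStream`, they can also live side by side in one `Vec<TcpStream>` or struct field without generics.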

The code below seems useful for making this more ergonomic, at what should be a very minor performance cost (adding `#[inline]` attributes could reduce it further):

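```rust
use std::io;

use futures::Poll; // futures 0.1
use tokio_io::{AsyncRead, AsyncWrite}; // tokio-io 0.1
// Assumed: the crate's role-specific stream types, renamed for clarity.
use tokio_rustls::client::TlsStream as TlsClientStream;
use tokio_rustls::server::TlsStream as TlsServerStream;

/// Either side of a TLS connection, so callers can name a single type.
enum TlsStream<T> {
    Client(TlsClientStream<T>),
    Server(TlsServerStream<T>),
}

impl<T> TlsStream<T> {
    /// Returns the underlying transport (e.g. the `TcpStream`).
    fn tcp(&self) -> &T {
        use self::TlsStream::*;
        match self {
            // `get_ref` pairs the transport with the TLS session; `.0` is the transport.
            Client(io) => io.get_ref().0,
            Server(io) => io.get_ref().0,
        }
    }
}

impl<T> From<TlsClientStream<T>> for TlsStream<T> {
    fn from(s: TlsClientStream<T>) -> Self {
        Self::Client(s)
    }
}

impl<T> From<TlsServerStream<T>> for TlsStream<T> {
    fn from(s: TlsServerStream<T>) -> Self {
        Self::Server(s)
    }
}

// `AsyncRead` already implies `io::Read`, so no extra bound is needed.
impl<T> io::Read for TlsStream<T>
where
    T: AsyncRead + AsyncWrite,
{
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
        use self::TlsStream::*;
        match self {
            Client(io) => io.read(buf),
            Server(io) => io.read(buf),
        }
    }
}

// `AsyncWrite` already implies `io::Write`.
impl<T> io::Write for TlsStream<T>
where
    T: AsyncRead + AsyncWrite,
{
    fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
        use self::TlsStream::*;
        match self {
            Client(io) => io.write(buf),
            Server(io) => io.write(buf),
        }
    }

    fn flush(&mut self) -> Result<(), io::Error> {
        use self::TlsStream::*;
        match self {
            Client(io) => io.flush(),
            Server(io) => io.flush(),
        }
    }
}

// The default methods suffice; reads go through the `io::Read` impl above.
impl<T> AsyncRead for TlsStream<T> where T: AsyncRead + AsyncWrite {}

impl<T> AsyncWrite for TlsStream<T>
where
    T: AsyncRead + AsyncWrite,
{
    // Delegate so the wrapped stream can close its TLS session cleanly.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        use self::TlsStream::*;
        match self {
            Client(io) => io.shutdown(),
            Server(io) => io.shutdown(),
        }
    }
}
```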
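The dispatch pattern itself does not depend on TLS or Tokio. The following std-only sketch assumes hypothetical `ClientSide`/`ServerSide` wrappers around an in-memory `Cursor` as stand-ins for the two TLS stream types. It shows the same enum-plus-`From` structure with synchronous `Read`/`Write`, including the `#[inline]` hints mentioned above.

```rust
use std::io::{self, Cursor, Read, Write};

// Hypothetical stand-ins for the client and server TLS stream types.
// Each wraps an in-memory buffer so the sketch runs without any TLS.
pub struct ClientSide(pub Cursor<Vec<u8>>);
pub struct ServerSide(pub Cursor<Vec<u8>>);

impl Read for ClientSide {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.0.read(buf)
    }
}

impl Write for ClientSide {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.0.write(buf)
    }
    fn flush(&mut self) -> io::Result<()> {
        self.0.flush()
    }
}

impl Read for ServerSide {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.0.read(buf)
    }
}

impl Write for ServerSide {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.0.write(buf)
    }
    fn flush(&mut self) -> io::Result<()> {
        self.0.flush()
    }
}

/// One type covering both roles, mirroring the proposed enum.
pub enum Stream {
    Client(ClientSide),
    Server(ServerSide),
}

impl From<ClientSide> for Stream {
    fn from(s: ClientSide) -> Self {
        Stream::Client(s)
    }
}

impl From<ServerSide> for Stream {
    fn from(s: ServerSide) -> Self {
        Stream::Server(s)
    }
}

// Each method forwards to whichever variant is present. `#[inline]` is only
// a hint; it mainly makes these small forwarders eligible for cross-crate inlining.
impl Read for Stream {
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            Stream::Client(s) => s.read(buf),
            Stream::Server(s) => s.read(buf),
        }
    }
}

impl Write for Stream {
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self {
            Stream::Client(s) => s.write(buf),
            Stream::Server(s) => s.write(buf),
        }
    }

    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        match self {
            Stream::Client(s) => s.flush(),
            Stream::Server(s) => s.flush(),
        }
    }
}
```

The per-call overhead is one `match` on the variant, with no heap allocation or virtual call, which fits the "very minor performance cost" estimate. This is a reasoned expectation, not a measured result.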
quininer commented 4 years ago

It makes sense. Are you willing to open a PR?

djc commented 4 years ago

Done in #47.