llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org
Other
28.78k stars 11.9k forks source link

[llvm][Support] Implement raw_socket_stream::read with optional timeout #92308

Closed cpsughrue closed 3 months ago

cpsughrue commented 5 months ago

This PR implements raw_socket_stream::read, which overloads the base class method raw_fd_stream::read. raw_socket_stream::read provides a way to time out the underlying ::read. The timeout functionality was not added to raw_fd_stream::read, both to avoid needlessly increasing compile times and to allow convenient code reuse with ListeningSocket::accept, which also requires timeout functionality. This PR supports the module build daemon and will help guarantee that it never becomes a zombie process.

github-actions[bot] commented 5 months ago

✅ With the latest revision this PR passed the C/C++ code formatter.

llvmbot commented 5 months ago

@llvm/pr-subscribers-llvm-support

Author: Connor Sughrue (cpsughrue)

Changes --- Full diff: https://github.com/llvm/llvm-project/pull/92308.diff 3 Files Affected: - (modified) llvm/include/llvm/Support/raw_socket_stream.h (+18-3) - (modified) llvm/lib/Support/raw_socket_stream.cpp (+70-23) - (modified) llvm/unittests/Support/raw_socket_stream_test.cpp (+96-11) ``````````diff diff --git a/llvm/include/llvm/Support/raw_socket_stream.h b/llvm/include/llvm/Support/raw_socket_stream.h index bddd47eb75e1a..225980cb28a42 100644 --- a/llvm/include/llvm/Support/raw_socket_stream.h +++ b/llvm/include/llvm/Support/raw_socket_stream.h @@ -92,10 +92,11 @@ class ListeningSocket { /// Accepts an incoming connection on the listening socket. This method can /// optionally either block until a connection is available or timeout after a /// specified amount of time has passed. By default the method will block - /// until the socket has recieved a connection. + /// until the socket has recieved a connection. If the accept timesout this + /// method will return std::errc:timed_out /// /// \param Timeout An optional timeout duration in milliseconds. Setting - /// Timeout to -1 causes accept to block indefinitely + /// Timeout to a negative number causes ::accept to block indefinitely /// Expected> accept(std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1)); @@ -124,11 +125,25 @@ class raw_socket_stream : public raw_fd_stream { public: raw_socket_stream(int SocketFD); + ~raw_socket_stream(); + /// Create a \p raw_socket_stream connected to the UNIX domain socket at \p /// SocketPath. static Expected> createConnectedUnix(StringRef SocketPath); - ~raw_socket_stream(); + + /// Attempt to read from the raw_socket_stream's file descriptor. This method + /// can optionally either block until data is read or an error has occurred or + /// timeout after a specified amount of time has passed. By default the method + /// will block until the socket has read data or encountered an error. 
If the + /// read timesout this method will return std::errc:timed_out + /// + /// \param Timeout An optional timeout duration in milliseconds + /// \param Ptr The start of the buffer that will hold any read data + /// \param Size The number of bytes to be read + /// + Expected readFromSocket( + std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1)); }; } // end namespace llvm diff --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp index 549d537709bf2..063f6fc366da9 100644 --- a/llvm/lib/Support/raw_socket_stream.cpp +++ b/llvm/lib/Support/raw_socket_stream.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #ifndef _WIN32 @@ -177,22 +178,31 @@ Expected ListeningSocket::createUnix(StringRef SocketPath, #endif // _WIN32 } -Expected> -ListeningSocket::accept(std::chrono::milliseconds Timeout) { - - struct pollfd FDs[2]; - FDs[0].events = POLLIN; +// If a file descriptor being monitored by poll is closed by another thread, the +// result is unspecified. In the case poll does not unblock and return when +// ActiveFD is closed you can provide another file descriptor via CancelFD that +// when written to will cause poll to return. Typically CancelFD is the read end +// of a unidirectional pipe. 
+static llvm::Error manageTimeout(std::chrono::milliseconds Timeout, + std::function getActiveFD, + std::optional CancelFD = std::nullopt) { + struct pollfd FD[2]; + FD[0].events = POLLIN; #ifdef _WIN32 - SOCKET WinServerSock = _get_osfhandle(FD); - FDs[0].fd = WinServerSock; + SOCKET WinServerSock = _get_osfhandle(getActiveFD()); + FD[0].fd = WinServerSock; #else - FDs[0].fd = FD; + FD[0].fd = getActiveFD(); #endif - FDs[1].events = POLLIN; - FDs[1].fd = PipeFD[0]; + uint8_t FDCount = 1; + if (CancelFD.has_value()) { + FD[1].events = POLLIN; + FD[1].fd = CancelFD.value(); + FDCount++; + } - // Keep track of how much time has passed in case poll is interupted by a - // signal and needs to be recalled + // Keep track of how much time has passed in case ::poll or WSAPoll are + // interupted by a signal and need to be recalled int RemainingTime = Timeout.count(); std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0); int PollStatus = -1; @@ -200,20 +210,20 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) { while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) { if (Timeout.count() != -1) RemainingTime -= ElapsedTime.count(); - auto Start = std::chrono::steady_clock::now(); + #ifdef _WIN32 - PollStatus = WSAPoll(FDs, 2, RemainingTime); + PollStatus = WSAPoll(FD, FDCount, RemainingTime); #else - PollStatus = ::poll(FDs, 2, RemainingTime); + PollStatus = ::poll(FD, FDCount, RemainingTime); #endif - // If FD equals -1 then ListeningSocket::shutdown has been called and it is - // appropriate to return operation_canceled - if (FD.load() == -1) + + // If ActiveFD equals -1 or CancelFD has data to be read then the operation + // has been canceled by another thread + if (getActiveFD() == -1 || FD[1].revents & POLLIN) return llvm::make_error( std::make_error_code(std::errc::operation_canceled), "Accept canceled"); - #if _WIN32 if (PollStatus == SOCKET_ERROR) { #else @@ -222,14 +232,14 @@ 
ListeningSocket::accept(std::chrono::milliseconds Timeout) { std::error_code PollErrCode = getLastSocketErrorCode(); // Ignore EINTR (signal occured before any request event) and retry if (PollErrCode != std::errc::interrupted) - return llvm::make_error(PollErrCode, "FD poll failed"); + return llvm::make_error(PollErrCode, "poll failed"); } if (PollStatus == 0) return llvm::make_error( std::make_error_code(std::errc::timed_out), "No client requests within timeout window"); - if (FDs[0].revents & POLLNVAL) + if (FD[0].revents & POLLNVAL) return llvm::make_error( std::make_error_code(std::errc::bad_file_descriptor)); @@ -237,10 +247,19 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) { ElapsedTime += std::chrono::duration_cast(Stop - Start); } + return llvm::Error::success(); +} + +Expected> +ListeningSocket::accept(std::chrono::milliseconds Timeout) { + auto getActiveFD = [this]() -> int { return FD; }; + llvm::Error TimeoutErr = manageTimeout(Timeout, getActiveFD, PipeFD[0]); + if (TimeoutErr) + return TimeoutErr; int AcceptFD; #ifdef _WIN32 - SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL); + SOCKET WinAcceptSock = ::accept(_get_osfhandle(FD), NULL, NULL); AcceptFD = _open_osfhandle(WinAcceptSock, 0); #else AcceptFD = ::accept(FD, NULL, NULL); @@ -295,6 +314,8 @@ ListeningSocket::~ListeningSocket() { raw_socket_stream::raw_socket_stream(int SocketFD) : raw_fd_stream(SocketFD, true) {} +raw_socket_stream::~raw_socket_stream() {} + Expected> raw_socket_stream::createConnectedUnix(StringRef SocketPath) { #ifdef _WIN32 @@ -306,4 +327,30 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) { return std::make_unique(*FD); } -raw_socket_stream::~raw_socket_stream() {} +Expected +raw_socket_stream::readFromSocket(std::chrono::milliseconds Timeout) { + auto getActiveFD = [this]() -> int { return this->get_fd(); }; + llvm::Error TimeoutErr = manageTimeout(Timeout, getActiveFD); + if (TimeoutErr) + return TimeoutErr; + + std::vector Buffer; 
+ constexpr ssize_t TmpBufferSize = 1024; + char TmpBuffer[TmpBufferSize]; + + while (true) { + std::memset(TmpBuffer, 0, TmpBufferSize); + ssize_t BytesRead = this->read(TmpBuffer, TmpBufferSize); + if (BytesRead == -1) + return llvm::make_error(this->error(), "read failed"); + else if (BytesRead == 0) + break; + else + Buffer.insert(Buffer.end(), TmpBuffer, TmpBuffer + BytesRead); + // All available bytes have been read. Another call to read will block + if (BytesRead < TmpBufferSize) + break; + } + + return std::string(Buffer.begin(), Buffer.end()); +} diff --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp index c4e8cfbbe7e6a..1b8f85f88f1af 100644 --- a/llvm/unittests/Support/raw_socket_stream_test.cpp +++ b/llvm/unittests/Support/raw_socket_stream_test.cpp @@ -58,21 +58,106 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) { Client << "01234567"; Client.flush(); - char Bytes[8]; - ssize_t BytesRead = Server.read(Bytes, 8); + llvm::Expected MaybeText = Server.readFromSocket(); + ASSERT_THAT_EXPECTED(MaybeText, llvm::Succeeded()); + ASSERT_EQ("01234567", *MaybeText); +} + +TEST(raw_socket_streamTest, LARGE_READ) { + if (!hasUnixSocketSupport()) + GTEST_SKIP(); + + SmallString<100> SocketPath; + llvm::sys::fs::createUniquePath("large_read.sock", SocketPath, true); + + // Make sure socket file does not exist. 
May still be there from the last test + std::remove(SocketPath.c_str()); + + Expected MaybeServerListener = + ListeningSocket::createUnix(SocketPath); + ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded()); + ListeningSocket ServerListener = std::move(*MaybeServerListener); + + Expected> MaybeClient = + raw_socket_stream::createConnectedUnix(SocketPath); + ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded()); + raw_socket_stream &Client = **MaybeClient; + + Expected> MaybeServer = + ServerListener.accept(); + ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded()); + raw_socket_stream &Server = **MaybeServer; + + // raw_socket_stream::readFromSocket pre-allocates a buffer 1024 bytes large. + // Test to make sure readFromSocket can handle messages larger then size of + // pre-allocated block + constexpr int TextLength = 1342; + constexpr char Text[TextLength] = + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do " + "eiusmod tempor incididunt ut labore et dolore magna aliqua. Vel orci " + "porta non pulvinar neque laoreet suspendisse interdum consectetur. " + "Nulla facilisi etiam dignissim diam quis. Porttitor massa id neque " + "aliquam vestibulum morbi blandit cursus. Purus viverra accumsan in " + "nisl. Nunc non blandit massa enim nec dui nunc mattis enim. Rhoncus " + "dolor purus non enim praesent elementum facilisis leo. Parturient " + "montes nascetur ridiculus mus mauris. Urna condimentum mattis " + "pellentesque id nibh tortor id aliquet lectus. Orci eu lobortis " + "elementum nibh. Sagittis eu volutpat odio facilisis. Molestie a " + "iaculis at erat pellentesque adipiscing. Tincidunt augue interdum " + "velit euismod in pellentesque massa placerat. Cras ornare arcu dui " + "vivamus arcu felis bibendum ut tristique. Tellus elementum sagittis " + "vitae et leo duis. Scelerisque fermentum dui faucibus in ornare " + "quam. Ipsum a arcu cursus vitae congue. Sit amet nisl suscipit " + "adipiscing. Sociis natoque penatibus et magnis. 
Cras semper auctor " + "neque vitae tempus quam pellentesque. Neque gravida in fermentum et " + "sollicitudin ac orci phasellus egestas. Vitae suscipit tellus mauris " + "a diam maecenas sed. Lectus arcu bibendum at varius vel pharetra. " + "Dignissim sodales ut eu sem integer vitae justo. Id cursus metus " + "aliquam eleifend mi."; + Client << Text; + Client.flush(); + + llvm::Expected MaybeText = Server.readFromSocket(); + ASSERT_THAT_EXPECTED(MaybeText, llvm::Succeeded()); + ASSERT_EQ(Text, *MaybeText); +} - std::string string(Bytes, 8); +TEST(raw_socket_streamTest, READ_WITH_TIMEOUT) { + if (!hasUnixSocketSupport()) + GTEST_SKIP(); + + SmallString<100> SocketPath; + llvm::sys::fs::createUniquePath("read_with_timeout.sock", SocketPath, true); - ASSERT_EQ(8, BytesRead); - ASSERT_EQ("01234567", string); + // Make sure socket file does not exist. May still be there from the last test + std::remove(SocketPath.c_str()); + + Expected MaybeServerListener = + ListeningSocket::createUnix(SocketPath); + ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded()); + ListeningSocket ServerListener = std::move(*MaybeServerListener); + + Expected> MaybeClient = + raw_socket_stream::createConnectedUnix(SocketPath); + ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded()); + + Expected> MaybeServer = + ServerListener.accept(); + ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded()); + raw_socket_stream &Server = **MaybeServer; + + llvm::Expected MaybeBytesRead = + Server.readFromSocket(std::chrono::milliseconds(100)); + ASSERT_EQ(llvm::errorToErrorCode(MaybeBytesRead.takeError()), + std::errc::timed_out); } -TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) { +TEST(raw_socket_streamTest, ACCEPT_WITH_TIMEOUT) { if (!hasUnixSocketSupport()) GTEST_SKIP(); SmallString<100> SocketPath; - llvm::sys::fs::createUniquePath("timout_provided.sock", SocketPath, true); + llvm::sys::fs::createUniquePath("accept_with_timeout.sock", SocketPath, true); // Make sure socket file does not exist. 
May still be there from the last test std::remove(SocketPath.c_str()); @@ -82,19 +167,19 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) { ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded()); ListeningSocket ServerListener = std::move(*MaybeServerListener); - std::chrono::milliseconds Timeout = std::chrono::milliseconds(100); Expected> MaybeServer = - ServerListener.accept(Timeout); + ServerListener.accept(std::chrono::milliseconds(100)); ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()), std::errc::timed_out); } -TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) { +TEST(raw_socket_streamTest, ACCEPT_WITH_SHUTDOWN) { if (!hasUnixSocketSupport()) GTEST_SKIP(); SmallString<100> SocketPath; - llvm::sys::fs::createUniquePath("fd_closed.sock", SocketPath, true); + llvm::sys::fs::createUniquePath("accept_with_shutdown.sock", SocketPath, + true); // Make sure socket file does not exist. May still be there from the last test std::remove(SocketPath.c_str()); ``````````