From a138fb55c40f6f53c7bec076bb97a3a70ceab2f6 Mon Sep 17 00:00:00 2001 From: Lutz Justen Date: Mon, 19 Dec 2022 23:02:36 +0100 Subject: [PATCH] [cdc_rsync] Add support for ServerSocket on Windows (#48) Makes ServerSocket multi-platform, mainly by working around some small API differences. The code is largely the same, there should be no differences on Linux. Also moves WSAStartup() and WSACleanup() up to the Socket level as static methods because it's used by both ClientSocket and ServerSocket, and because it doesn't make sense to do that in the socket class as that would prevent one from using several sockets. --- all_files.vcxitems | 1 + cdc_rsync/base/BUILD | 8 ++ cdc_rsync/base/socket.cc | 65 ++++++++++ cdc_rsync/base/socket.h | 14 +++ cdc_rsync/cdc_rsync_client.cc | 6 + cdc_rsync/cdc_rsync_client.h | 1 + cdc_rsync/client_socket.cc | 19 +-- cdc_rsync/client_socket.h | 2 +- cdc_rsync_server/BUILD | 8 +- cdc_rsync_server/cdc_rsync_server.cc | 15 ++- cdc_rsync_server/cdc_rsync_server.h | 3 + cdc_rsync_server/server_socket.cc | 178 ++++++++++++++++++--------- cdc_rsync_server/server_socket.h | 6 +- 13 files changed, 242 insertions(+), 84 deletions(-) create mode 100644 cdc_rsync/base/socket.cc diff --git a/all_files.vcxitems b/all_files.vcxitems index 240de11..1dcab1d 100644 --- a/all_files.vcxitems +++ b/all_files.vcxitems @@ -15,6 +15,7 @@ + diff --git a/cdc_rsync/base/BUILD b/cdc_rsync/base/BUILD index 655a739..d68b392 100644 --- a/cdc_rsync/base/BUILD +++ b/cdc_rsync/base/BUILD @@ -80,7 +80,15 @@ cc_library( cc_library( name = "socket", + srcs = ["socket.cc"], hdrs = ["socket.h"], + deps = [ + "//common:log", + "//common:platform", + "//common:status", + "//common:util", + "@com_google_absl//absl/status", + ], ) filegroup( diff --git a/cdc_rsync/base/socket.cc b/cdc_rsync/base/socket.cc new file mode 100644 index 0000000..e02b9ed --- /dev/null +++ b/cdc_rsync/base/socket.cc @@ -0,0 +1,65 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cdc_rsync/base/socket.h" + +#include "common/log.h" +#include "common/platform.h" +#include "common/status.h" +#include "common/util.h" + +#if PLATFORM_WINDOWS +#include +#endif + +namespace cdc_ft { + +// static +absl::Status Socket::Initialize() { +#if PLATFORM_WINDOWS + WSADATA wsaData; + const int result = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (result != 0) { + return MakeStatus("WSAStartup() failed: %s", Util::GetWin32Error(result)); + } + return absl::OkStatus(); +#elif PLATFORM_LINUX + return absl::OkStatus(); +#endif +} + +// static +absl::Status Socket::Shutdown() { +#if PLATFORM_WINDOWS + const int result = WSACleanup(); + if (result == SOCKET_ERROR) { + return MakeStatus("WSACleanup() failed: %s", + Util::GetWin32Error(WSAGetLastError())); + } + return absl::OkStatus(); +#elif PLATFORM_LINUX + return absl::OkStatus(); +#endif +} + +SocketFinalizer::~SocketFinalizer() { + absl::Status status = Socket::Shutdown(); + if (!status.ok()) { + LOG_ERROR("Socket shutdown failed: %s", status.message()) + } +}; + +} // namespace cdc_ft diff --git a/cdc_rsync/base/socket.h b/cdc_rsync/base/socket.h index c156dab..9c6c876 100644 --- a/cdc_rsync/base/socket.h +++ b/cdc_rsync/base/socket.h @@ -26,6 +26,14 @@ class Socket { Socket() = default; virtual ~Socket() = default; + // Calls WSAStartup() on Windows, no-op on Linux. + // Must be called before using sockets. + static absl::Status Initialize(); + + // Calls WSACleanup() on Windows, no-op on Linux. + // Must be called after using sockets. + static absl::Status Shutdown(); + // Send data to the socket. virtual absl::Status Send(const void* buffer, size_t size) = 0; @@ -40,6 +48,12 @@ class Socket { size_t* bytes_received) = 0; }; +// Convenience class that calls Shutdown() on destruction. Logs on errors. +class SocketFinalizer { + public: + ~SocketFinalizer(); +}; + } // namespace cdc_ft #endif // CDC_RSYNC_BASE_SOCKET_H_ diff --git a/cdc_rsync/cdc_rsync_client.cc b/cdc_rsync/cdc_rsync_client.cc index 1dccd6c..525890c 100644 --- a/cdc_rsync/cdc_rsync_client.cc +++ b/cdc_rsync/cdc_rsync_client.cc @@ -263,6 +263,12 @@ absl::Status CdcRsyncClient::StartServer() { return SetTag(MakeStatus("Redeploy server"), Tag::kDeployServer); } + status = Socket::Initialize(); + if (!status.ok()) { + return WrapStatus(status, "Failed to initialize sockets"); + } + socket_finalizer_ = std::make_unique(); + assert(is_server_listening_); status = socket_.Connect(port); if (!status.ok()) { diff --git a/cdc_rsync/cdc_rsync_client.h b/cdc_rsync/cdc_rsync_client.h index 9e24b78..055fb69 100644 --- a/cdc_rsync/cdc_rsync_client.h +++ b/cdc_rsync/cdc_rsync_client.h @@ -123,6 +123,7 @@ class CdcRsyncClient { WinProcessFactory process_factory_; RemoteUtil remote_util_; PortManager port_manager_; + std::unique_ptr socket_finalizer_; ClientSocket socket_; MessagePump message_pump_{&socket_, MessagePump::PacketReceivedDelegate()}; ConsoleProgressPrinter printer_; diff --git a/cdc_rsync/client_socket.cc b/cdc_rsync/client_socket.cc index c124ed7..873d3f4 100644 --- a/cdc_rsync/client_socket.cc +++ b/cdc_rsync/client_socket.cc @@ -39,10 +39,10 @@ absl::Status MakeSocketStatus(const char* message) { } // namespace -struct SocketInfo { +struct ClientSocketInfo { SOCKET socket; - SocketInfo() : socket(INVALID_SOCKET) {} + ClientSocketInfo() : socket(INVALID_SOCKET) {} }; ClientSocket::ClientSocket() = default; @@ -50,12 +50,6 @@ ClientSocket::ClientSocket() = default; ClientSocket::~ClientSocket() { Disconnect(); } absl::Status ClientSocket::Connect(int port) { - WSADATA wsaData; - int result = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (result != 0) { - return MakeStatus("WSAStartup() failed: %i", result); - } - addrinfo hints; ZeroMemory(&hints, sizeof(hints)); hints.ai_family = AF_INET; @@ -64,14 +58,13 @@ absl::Status ClientSocket::Connect(int port) { // Resolve the server address and port. addrinfo* addr_infos = nullptr; - result = getaddrinfo("localhost", std::to_string(port).c_str(), &hints, - &addr_infos); + int result = getaddrinfo("localhost", std::to_string(port).c_str(), &hints, + &addr_infos); if (result != 0) { - WSACleanup(); return MakeStatus("getaddrinfo() failed: %i", result); } - socket_info_ = std::make_unique(); + socket_info_ = std::make_unique(); int count = 0; for (addrinfo* curr = addr_infos; curr; curr = curr->ai_next, count++) { socket_info_->socket = @@ -101,7 +94,6 @@ absl::Status ClientSocket::Connect(int port) { if (socket_info_->socket == INVALID_SOCKET) { socket_info_.reset(); - WSACleanup(); return MakeStatus("Unable to connect to port %i", port); } @@ -120,7 +112,6 @@ void ClientSocket::Disconnect() { } socket_info_.reset(); - WSACleanup(); } absl::Status ClientSocket::Send(const void* buffer, size_t size) { diff --git a/cdc_rsync/client_socket.h b/cdc_rsync/client_socket.h index ec6eb91..9be7835 100644 --- a/cdc_rsync/client_socket.h +++ b/cdc_rsync/client_socket.h @@ -45,7 +45,7 @@ class ClientSocket : public Socket { size_t* bytes_received) override; private: - std::unique_ptr socket_info_; + std::unique_ptr socket_info_; }; } // namespace cdc_ft diff --git a/cdc_rsync_server/BUILD b/cdc_rsync_server/BUILD index 5b80201..e6c9551 100644 --- a/cdc_rsync_server/BUILD +++ b/cdc_rsync_server/BUILD @@ -127,11 +127,17 @@ cc_library( name = "server_socket", srcs = ["server_socket.cc"], hdrs = ["server_socket.h"], - target_compatible_with = ["@platforms//os:linux"], + linkopts = select({ + "//tools:windows": [ + "/DEFAULTLIB:Ws2_32.lib", # Sockets, e.g. recv, send, WSA*. + ], + "//conditions:default": [], + }), deps = [ "//cdc_rsync/base:socket", "//common:log", "//common:status", + "//common:util", "@com_google_absl//absl/status", ], ) diff --git a/cdc_rsync_server/cdc_rsync_server.cc b/cdc_rsync_server/cdc_rsync_server.cc index b6f9e04..f7a01f2 100644 --- a/cdc_rsync_server/cdc_rsync_server.cc +++ b/cdc_rsync_server/cdc_rsync_server.cc @@ -148,10 +148,7 @@ PathFilter::Rule::Type ToInternalType( CdcRsyncServer::CdcRsyncServer() = default; -CdcRsyncServer::~CdcRsyncServer() { - message_pump_.reset(); - socket_.reset(); -} +CdcRsyncServer::~CdcRsyncServer() = default; bool CdcRsyncServer::CheckComponents( const std::vector& components) { @@ -173,8 +170,14 @@ bool CdcRsyncServer::CheckComponents( } absl::Status CdcRsyncServer::Run(int port) { + absl::Status status = Socket::Initialize(); + if (!status.ok()) { + return WrapStatus(status, "Failed to initialize sockets"); + } + socket_finalizer_ = std::make_unique(); + socket_ = std::make_unique(); - absl::Status status = socket_->StartListening(port); + status = socket_->StartListening(port); if (!status.ok()) { return WrapStatus(status, "Failed to start listening on port %i", port); } @@ -563,7 +566,7 @@ absl::Status CdcRsyncServer::HandleSendMissingFileData() { // Verify that there is no directory existing with the same name. if (path::Exists(filepath) && path::DirExists(filepath)) { assert(!diff_.extraneous_dirs.empty()); - absl::Status status = path::RemoveFile(filepath); + status = path::RemoveFile(filepath); if (!status.ok()) { return WrapStatus( status, "Failed to remove folder '%s' before creating file '%s'", diff --git a/cdc_rsync_server/cdc_rsync_server.h b/cdc_rsync_server/cdc_rsync_server.h index 0a58549..59c66fd 100644 --- a/cdc_rsync_server/cdc_rsync_server.h +++ b/cdc_rsync_server/cdc_rsync_server.h @@ -32,6 +32,7 @@ namespace cdc_ft { class MessagePump; class ServerSocket; +class SocketFinalizer; class CdcRsyncServer { public: @@ -90,6 +91,8 @@ class CdcRsyncServer { // Used to toggle decompression. void Thread_OnPackageReceived(PacketType type); + // The order determines the correct destruction order, so keep it! + std::unique_ptr socket_finalizer_; std::unique_ptr socket_; std::unique_ptr message_pump_; diff --git a/cdc_rsync_server/server_socket.cc b/cdc_rsync_server/server_socket.cc index 228d31c..f76be6e 100644 --- a/cdc_rsync_server/server_socket.cc +++ b/cdc_rsync_server/server_socket.cc @@ -14,20 +14,72 @@ #include "cdc_rsync_server/server_socket.h" +#include "common/log.h" +#include "common/platform.h" +#include "common/status.h" +#include "common/util.h" + +#if PLATFORM_WINDOWS + +#include + +#elif PLATFORM_LINUX + #include #include #include #include -#include "common/log.h" -#include "common/status.h" +#endif namespace cdc_ft { - namespace { -int kInvalidFd = -1; +#if PLATFORM_WINDOWS + +using SocketType = SOCKET; +using SockAddrType = SOCKADDR; +constexpr SocketType kInvalidSocket = INVALID_SOCKET; +constexpr int kSocketError = SOCKET_ERROR; +constexpr int kSendingEnd = SD_SEND; + +constexpr int kErrAgain = WSAEWOULDBLOCK; // There's no EAGAIN on Windows. +constexpr int kErrWouldBlock = WSAEWOULDBLOCK; +constexpr int kErrAddrInUse = WSAEADDRINUSE; + +int GetLastError() { return WSAGetLastError(); } +std::string GetErrorStr(int err) { return Util::GetWin32Error(err); } +void Close(SocketType* socket) { + if (*socket != kInvalidSocket) { + closesocket(*socket); + *socket = kInvalidSocket; + } +} + +// Not necessary on Windows. +#define HANDLE_EINTR(x) (x) + +#elif PLATFORM_LINUX + +using SocketType = int; +using SockAddrType = sockaddr; +constexpr SocketType kInvalidSocket = -1; +constexpr int kSocketError = -1; +constexpr int kSendingEnd = SHUT_WR; + +constexpr int kErrAgain = EAGAIN; +constexpr int kErrWouldBlock = EWOULDBLOCK; +constexpr int kErrAddrInUse = EADDRINUSE; + +int GetLastError() { return errno; } +std::string GetErrorStr(int err) { return strerror(err); } +void Close(SocketType* socket) { + if (*socket != kInvalidSocket) { + close(*socket); + *socket = kInvalidSocket; + } +} // Keep re-evaluating the expression |x| while it returns EINTR. #define HANDLE_EINTR(x) \ @@ -39,10 +91,22 @@ int kInvalidFd = -1; eintr_wrapper_result; \ }) +#endif + +std::string GetLastErrorStr() { return GetErrorStr(GetLastError()); } + } // namespace +struct ServerSocketInfo { + // Listening socket file descriptor (where new connections are accepted). + SocketType listen_sock = kInvalidSocket; + + // Connection socket file descriptor (where data is sent to/received from). + SocketType conn_sock = kInvalidSocket; +}; + ServerSocket::ServerSocket() - : Socket(), listen_sockfd_(kInvalidFd), conn_sockfd_(kInvalidFd) {} + : Socket(), socket_info_(std::make_unique()) {} ServerSocket::~ServerSocket() { Disconnect(); @@ -50,25 +114,26 @@ ServerSocket::~ServerSocket() { } absl::Status ServerSocket::StartListening(int port) { - if (listen_sockfd_ != kInvalidFd) { + if (socket_info_->listen_sock != kInvalidSocket) { return MakeStatus("Already listening"); } LOG_DEBUG("Open socket"); - listen_sockfd_ = socket(AF_INET, SOCK_STREAM, 0); - if (listen_sockfd_ < 0) { - listen_sockfd_ = kInvalidFd; - return MakeStatus("socket() failed: %s", strerror(errno)); + socket_info_->listen_sock = socket(AF_INET, SOCK_STREAM, 0); + if (socket_info_->listen_sock == kInvalidSocket) { + return MakeStatus("Creating listen socket failed: %s", GetLastErrorStr()); } // If the program terminates abnormally, the socket might remain in a // TIME_WAIT state and report "address already in use" on bind(). Setting // SO_REUSEADDR works around that. See // https://hea-www.harvard.edu/~fine/Tech/addrinuse.html - int enable = 1; - if (setsockopt(listen_sockfd_, SOL_SOCKET, SO_REUSEADDR, &enable, - sizeof(enable)) < 0) { - LOG_DEBUG("setsockopt() failed"); + const int enable = 1; + int result = + setsockopt(socket_info_->listen_sock, SOL_SOCKET, SO_REUSEADDR, + reinterpret_cast(&enable), sizeof(enable)); + if (result == kSocketError) { + LOG_DEBUG("Enabling address reusal failed"); } LOG_DEBUG("Bind socket"); @@ -77,46 +142,47 @@ absl::Status ServerSocket::StartListening(int port) { serv_addr.sin_family = AF_INET; serv_addr.sin_addr.s_addr = INADDR_ANY; serv_addr.sin_port = htons(port); - if (bind(listen_sockfd_, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) < - 0) { + + result = bind(socket_info_->listen_sock, + reinterpret_cast(&serv_addr), + sizeof(serv_addr)); + if (result == kSocketError) { + int err = GetLastError(); absl::Status status = - MakeStatus("bind() to port %i failed: %s", port, strerror(errno)); - if (errno == EADDRINUSE) { + MakeStatus("Binding to port %i failed: %s", port, GetErrorStr(err)); + if (err == kErrAddrInUse) { // Happens when two instances are run at the same time. Help callers to // print reasonable errors. status = SetTag(status, Tag::kAddressInUse); } - close(listen_sockfd_); - listen_sockfd_ = kInvalidFd; - + Close(&socket_info_->listen_sock); return status; } LOG_DEBUG("Listen"); - listen(listen_sockfd_, 1); + result = listen(socket_info_->listen_sock, 1); + if (result == kSocketError) { + int err = GetLastError(); + Close(&socket_info_->listen_sock); + return MakeStatus("Listening to socket failed: %s", GetErrorStr(err)); + } + return absl::OkStatus(); } void ServerSocket::StopListening() { - if (listen_sockfd_ != kInvalidFd) { - close(listen_sockfd_); - listen_sockfd_ = kInvalidFd; - } - + Close(&socket_info_->listen_sock); LOG_INFO("Stopped listening."); } absl::Status ServerSocket::WaitForConnection() { - if (conn_sockfd_ != kInvalidFd) { + if (socket_info_->conn_sock != kInvalidSocket) { return MakeStatus("Already connected"); } - sockaddr_in cli_addr; - socklen_t cli_len = sizeof(cli_addr); - conn_sockfd_ = accept(listen_sockfd_, (struct sockaddr*)&cli_addr, &cli_len); - if (conn_sockfd_ < 0) { - conn_sockfd_ = kInvalidFd; - return MakeStatus("accept() failed: %s", strerror(errno)); + socket_info_->conn_sock = accept(socket_info_->listen_sock, nullptr, nullptr); + if (socket_info_->conn_sock == kInvalidSocket) { + return MakeStatus("Accepting connection failed: %s", GetLastErrorStr()); } LOG_DEBUG("Client connected"); @@ -124,39 +190,36 @@ absl::Status ServerSocket::WaitForConnection() { } void ServerSocket::Disconnect() { - if (conn_sockfd_ != kInvalidFd) { - close(conn_sockfd_); - conn_sockfd_ = kInvalidFd; - } - + Close(&socket_info_->conn_sock); LOG_INFO("Disconnected"); } absl::Status ServerSocket::ShutdownSendingEnd() { - int result = shutdown(conn_sockfd_, SHUT_WR); - if (result != 0) { - return MakeStatus("shutdown() failed: %s", strerror(errno)); + int result = shutdown(socket_info_->conn_sock, kSendingEnd); + if (result == kSocketError) { + return MakeStatus("Socket shutdown failed: %s", GetLastErrorStr()); } return absl::OkStatus(); } absl::Status ServerSocket::Send(const void* buffer, size_t size) { - const uint8_t* curr_ptr = reinterpret_cast(buffer); - ssize_t bytes_left = size; + const char* curr_ptr = reinterpret_cast(buffer); + assert(size <= INT_MAX); + int bytes_left = static_cast(size); while (bytes_left > 0) { - ssize_t bytes_written = - HANDLE_EINTR(send(conn_sockfd_, curr_ptr, bytes_left, /*flags*/ 0)); + int bytes_written = HANDLE_EINTR( + send(socket_info_->conn_sock, curr_ptr, bytes_left, /*flags*/ 0)); if (bytes_written < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { + const int err = GetLastError(); + if (err == kErrAgain || err == kErrWouldBlock) { // Shouldn't happen as the socket should be blocking. LOG_DEBUG("Socket would block"); continue; } - return MakeStatus("write() to fd %i failed: %s", conn_sockfd_, - strerror(errno)); + return MakeStatus("Sending to socket failed: %s", GetErrorStr(err)); } bytes_left -= bytes_written; @@ -173,21 +236,22 @@ absl::Status ServerSocket::Receive(void* buffer, size_t size, return absl::OkStatus(); } - uint8_t* curr_ptr = reinterpret_cast(buffer); - ssize_t bytes_left = size; + char* curr_ptr = static_cast(buffer); + assert(size <= INT_MAX); + int bytes_left = size; while (bytes_left > 0) { - ssize_t bytes_read = - HANDLE_EINTR(recv(conn_sockfd_, curr_ptr, bytes_left, /*flags*/ 0)); + int bytes_read = HANDLE_EINTR( + recv(socket_info_->conn_sock, curr_ptr, bytes_left, /*flags*/ 0)); if (bytes_read < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { + const int err = GetLastError(); + if (err == kErrAgain || err == kErrWouldBlock) { // Shouldn't happen as the socket should be blocking. LOG_DEBUG("Socket would block"); continue; } - return MakeStatus("recv() from fd %i failed: %s", conn_sockfd_, - strerror(errno)); + return MakeStatus("Receiving from socket failed: %s", GetErrorStr(err)); } bytes_left -= bytes_read; @@ -196,7 +260,7 @@ absl::Status ServerSocket::Receive(void* buffer, size_t size, if (bytes_read == 0) { // EOF. Make sure we're not in the middle of a message. - if (bytes_left < static_cast(size)) { + if (bytes_left < static_cast(size)) { return MakeStatus("EOF after partial read"); } diff --git a/cdc_rsync_server/server_socket.h b/cdc_rsync_server/server_socket.h index 063ed56..636b949 100644 --- a/cdc_rsync_server/server_socket.h +++ b/cdc_rsync_server/server_socket.h @@ -50,11 +50,7 @@ class ServerSocket : public Socket { size_t* bytes_received) override; private: - // Listening socket file descriptor (where new connections are accepted). - int listen_sockfd_; - - // Connection socket file descriptor (where data is sent to/received from). - int conn_sockfd_; + std::unique_ptr socket_info_; }; } // namespace cdc_ft