diff --git a/cdc_rsync/cdc_rsync_client.cc b/cdc_rsync/cdc_rsync_client.cc index 56ef71a..2e9f1c2 100644 --- a/cdc_rsync/cdc_rsync_client.cc +++ b/cdc_rsync/cdc_rsync_client.cc @@ -300,24 +300,9 @@ absl::Status CdcRsyncClient::StartServer(const ServerArch& arch) { "Failed to start cdc_rsync_server process"); } - // Poll the connection to the socket with port |local_port| until the port - // forwarding connection is set up. - timeout_timer.Reset(); - for (;;) { - assert(local_port != 0); - status = socket_.Connect(local_port); - if (status.ok()) { - break; - } - - if (timeout_timer.ElapsedSeconds() > options_.connection_timeout_sec) { - return SetTag( - absl::DeadlineExceededError("Timeout while connecting to server"), - Tag::kConnectionTimeout); - } - - Util::Sleep(10); - } + // Wait for port forwarding to be up. + RETURN_IF_ERROR(ClientSocket::WaitForConnection( + local_port, absl::Seconds(options_.connection_timeout_sec))); server_process_ = std::move(srv_process); port_forwarding_process_ = std::move(fwd_process); diff --git a/common/BUILD b/common/BUILD index 0de62eb..640e7ba 100644 --- a/common/BUILD +++ b/common/BUILD @@ -62,7 +62,10 @@ cc_library( cc_library( name = "client_socket", - srcs = ["client_socket.cc"], + srcs = [ + "client_socket.cc", + "socket_internal.h", + ], hdrs = ["client_socket.h"], linkopts = select({ "//tools:windows": [ @@ -70,7 +73,6 @@ cc_library( ], "//conditions:default": [], }), - target_compatible_with = ["@platforms//os:windows"], deps = [ ":log", ":socket", @@ -473,7 +475,10 @@ cc_test( cc_library( name = "server_socket", - srcs = ["server_socket.cc"], + srcs = [ + "server_socket.cc", + "socket_internal.h", + ], hdrs = ["server_socket.h"], linkopts = select({ "//tools:windows": [ diff --git a/common/client_socket.cc b/common/client_socket.cc index b4ce172..eee21fc 100644 --- a/common/client_socket.cc +++ b/common/client_socket.cc @@ -14,13 +14,12 @@ #include "common/client_socket.h" -#include -#include - #include #include "common/log.h" +#include "common/socket_internal.h" #include "common/status.h" +#include "common/stopwatch.h" #include "common/util.h" namespace cdc_ft { @@ -29,9 +28,9 @@ namespace { // Creates a status with the given |message| and the last WSA error. // Assigns Tag::kSocketEof for WSAECONNRESET errors. absl::Status MakeSocketStatus(const char* message) { - const int err = WSAGetLastError(); - absl::Status status = MakeStatus("%s: %s", message, Util::GetWin32Error(err)); - if (err == WSAECONNRESET) { + const int err = GetLastError(); + absl::Status status = MakeStatus("%s: %s", message, GetErrorStr(err)); + if (err == kErrConnReset) { status = SetTag(status, Tag::kSocketEof); } return status; @@ -40,18 +39,39 @@ absl::Status MakeSocketStatus(const char* message) { } // namespace struct ClientSocketInfo { - SOCKET socket; + SocketType socket; - ClientSocketInfo() : socket(INVALID_SOCKET) {} + ClientSocketInfo() : socket(kInvalidSocket) {} }; ClientSocket::ClientSocket() = default; ClientSocket::~ClientSocket() { Disconnect(); } +// static +absl::Status ClientSocket::WaitForConnection(int port, absl::Duration timeout) { + assert(port != 0); + Stopwatch timeout_timer; + ClientSocket socket; + for (;;) { + absl::Status status = socket.Connect(port); + if (status.ok()) { + return absl::OkStatus(); + } + + if (timeout_timer.Elapsed() > timeout) { + return SetTag( + absl::DeadlineExceededError("Timeout while connecting to server"), + Tag::kConnectionTimeout); + } + + Util::Sleep(50); + } +} + absl::Status ClientSocket::Connect(int port) { addrinfo hints; - ZeroMemory(&hints, sizeof(hints)); + memset(&hints, 0, sizeof(hints)); hints.ai_family = AF_INET; hints.ai_socktype = SOCK_STREAM; hints.ai_protocol = IPPROTO_TCP; @@ -63,26 +83,26 @@ absl::Status ClientSocket::Connect(int port) { if (result != 0) { return MakeStatus("getaddrinfo() failed: %i", result); } + AddrInfoReleaser releaser(addr_infos); socket_info_ = std::make_unique(); int count = 0; for (addrinfo* curr = addr_infos; curr; curr = curr->ai_next, count++) { socket_info_->socket = socket(curr->ai_family, curr->ai_socktype, curr->ai_protocol); - if (socket_info_->socket == INVALID_SOCKET) { + if (socket_info_->socket == kInvalidSocket) { LOG_DEBUG("socket() failed for addr_info %i: %s", count, - Util::GetWin32Error(WSAGetLastError())); + GetLastErrorStr()); continue; } // Connect to server. result = connect(socket_info_->socket, curr->ai_addr, static_cast(curr->ai_addrlen)); - if (result == SOCKET_ERROR) { + if (result == kSocketError) { LOG_DEBUG("connect() failed for addr_info %i: %s", count, - Util::GetWin32Error(WSAGetLastError())); - closesocket(socket_info_->socket); - socket_info_->socket = INVALID_SOCKET; + GetLastErrorStr()); + Close(&socket_info_->socket); continue; } @@ -90,9 +110,7 @@ absl::Status ClientSocket::Connect(int port) { break; } - freeaddrinfo(addr_infos); - - if (socket_info_->socket == INVALID_SOCKET) { + if (socket_info_->socket == kInvalidSocket) { socket_info_.reset(); return MakeStatus("Unable to connect to port %i", port); } @@ -106,18 +124,15 @@ void ClientSocket::Disconnect() { return; } - if (socket_info_->socket != INVALID_SOCKET) { - closesocket(socket_info_->socket); - socket_info_->socket = INVALID_SOCKET; - } - + Close(&socket_info_->socket); socket_info_.reset(); } absl::Status ClientSocket::Send(const void* buffer, size_t size) { - int result = send(socket_info_->socket, static_cast(buffer), - static_cast(size), /*flags */ 0); - if (result == SOCKET_ERROR) { + int result = + HANDLE_EINTR(send(socket_info_->socket, static_cast(buffer), + static_cast(size), /*flags */ 0)); + if (result == kSocketError) { return MakeSocketStatus("send() failed"); } @@ -133,9 +148,10 @@ absl::Status ClientSocket::Receive(void* buffer, size_t size, } int flags = allow_partial_read ? 0 : MSG_WAITALL; - int bytes_read = recv(socket_info_->socket, static_cast(buffer), - static_cast(size), flags); - if (bytes_read == SOCKET_ERROR) { + int bytes_read = + HANDLE_EINTR(recv(socket_info_->socket, static_cast(buffer), + static_cast(size), flags)); + if (bytes_read == kSocketError) { return MakeSocketStatus("recv() failed"); } @@ -154,9 +170,9 @@ absl::Status ClientSocket::Receive(void* buffer, size_t size, } absl::Status ClientSocket::ShutdownSendingEnd() { - int result = shutdown(socket_info_->socket, SD_SEND); - if (result == SOCKET_ERROR) { - return MakeSocketStatus("shutdown() failed"); + int result = shutdown(socket_info_->socket, kSendingEnd); + if (result == kSocketError) { + return MakeStatus("Socket shutdown failed: %s", GetLastErrorStr()); } return absl::OkStatus(); diff --git a/common/client_socket.h b/common/client_socket.h index ae3bc1d..9f46b45 100644 --- a/common/client_socket.h +++ b/common/client_socket.h @@ -29,6 +29,10 @@ class ClientSocket : public Socket { ClientSocket(); ~ClientSocket(); + // Polls until a connection to |port| succeeds. + // Returns TimeoutError + static absl::Status WaitForConnection(int port, absl::Duration timeout); + // Connects to localhost on |port|. absl::Status Connect(int port); diff --git a/common/server_socket.cc b/common/server_socket.cc index 940a68d..d17593e 100644 --- a/common/server_socket.cc +++ b/common/server_socket.cc @@ -15,99 +15,10 @@ #include "common/server_socket.h" #include "common/log.h" -#include "common/platform.h" +#include "common/socket_internal.h" #include "common/status.h" -#include "common/util.h" - -#if PLATFORM_WINDOWS - -#include -#include - -#elif PLATFORM_LINUX - -#include -#include -#include -#include -#include - -#include - -#endif namespace cdc_ft { -namespace { - -#if PLATFORM_WINDOWS - -using SocketType = SOCKET; -using SockAddrType = SOCKADDR; -constexpr SocketType kInvalidSocket = INVALID_SOCKET; -constexpr int kSocketError = SOCKET_ERROR; -constexpr int kSendingEnd = SD_SEND; - -constexpr int kErrAgain = WSAEWOULDBLOCK; // There's no EAGAIN on Windows. -constexpr int kErrWouldBlock = WSAEWOULDBLOCK; -constexpr int kErrAddrInUse = WSAEADDRINUSE; - -int GetLastError() { return WSAGetLastError(); } -std::string GetErrorStr(int err) { return Util::GetWin32Error(err); } -void Close(SocketType* socket) { - if (*socket != kInvalidSocket) { - closesocket(*socket); - *socket = kInvalidSocket; - } -} - -// Not necessary on Windows. -#define HANDLE_EINTR(x) (x) - -#elif PLATFORM_LINUX - -using SocketType = int; -using SockAddrType = sockaddr; -constexpr SocketType kInvalidSocket = -1; -constexpr int kSocketError = -1; -constexpr int kSendingEnd = SHUT_WR; - -constexpr int kErrAgain = EAGAIN; -constexpr int kErrWouldBlock = EWOULDBLOCK; -constexpr int kErrAddrInUse = EADDRINUSE; - -int GetLastError() { return errno; } -std::string GetErrorStr(int err) { return strerror(err); } -void Close(SocketType* socket) { - if (*socket != kInvalidSocket) { - close(*socket); - *socket = kInvalidSocket; - } -} - -// Keep re-evaluating the expression |x| while it returns EINTR. -#define HANDLE_EINTR(x) \ - ({ \ - decltype(x) eintr_wrapper_result; \ - do { \ - eintr_wrapper_result = (x); \ - } while (eintr_wrapper_result == -1 && errno == EINTR); \ - eintr_wrapper_result; \ - }) - -#endif - -std::string GetLastErrorStr() { return GetErrorStr(GetLastError()); } - -class AddrInfoReleaser { - public: - AddrInfoReleaser(addrinfo* addr_infos) : addr_infos_(addr_infos) {} - ~AddrInfoReleaser() { freeaddrinfo(addr_infos_); } - - private: - addrinfo* addr_infos_; -}; - -} // namespace struct ServerSocketInfo { // Listening socket file descriptor (where new connections are accepted). diff --git a/common/socket_internal.h b/common/socket_internal.h new file mode 100644 index 0000000..043b210 --- /dev/null +++ b/common/socket_internal.h @@ -0,0 +1,121 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_SOCKET_INTERNAL_H_ + +#include "common/platform.h" +#include "common/util.h" + +#if PLATFORM_WINDOWS + +#include +#include + +#elif PLATFORM_LINUX + +#include +#include +#include +#include +#include + +#include + +#endif + +namespace cdc_ft { +namespace { + +// Platform-specific abstractions for socket classes. + +#if PLATFORM_WINDOWS + +using SocketType = SOCKET; +using SockAddrType = SOCKADDR; +constexpr SocketType kInvalidSocket = INVALID_SOCKET; +constexpr int kSocketError = SOCKET_ERROR; +constexpr int kSendingEnd = SD_SEND; + +constexpr int kErrAgain = WSAEWOULDBLOCK; // There's no EAGAIN on Windows. +constexpr int kErrWouldBlock = WSAEWOULDBLOCK; +constexpr int kErrAddrInUse = WSAEADDRINUSE; +constexpr int kErrConnReset = WSAECONNRESET; + +int GetLastError() { return WSAGetLastError(); } + +std::string GetErrorStr(int err) { return Util::GetWin32Error(err); } + +void Close(SocketType* socket) { + if (*socket != kInvalidSocket) { + closesocket(*socket); + *socket = kInvalidSocket; + } +} + +// Not necessary on Windows. +#define HANDLE_EINTR(x) (x) + +#elif PLATFORM_LINUX + +using SocketType = int; +using SockAddrType = sockaddr; +constexpr SocketType kInvalidSocket = -1; +constexpr int kSocketError = -1; +constexpr int kSendingEnd = SHUT_WR; + +constexpr int kErrAgain = EAGAIN; +constexpr int kErrWouldBlock = EWOULDBLOCK; +constexpr int kErrAddrInUse = EADDRINUSE; +constexpr int kErrConnReset = ECONNRESET; + +int GetLastError() { return errno; } + +std::string GetErrorStr(int err) { return strerror(err); } + +void Close(SocketType* socket) { + if (*socket != kInvalidSocket) { + close(*socket); + *socket = kInvalidSocket; + } +} + +// Keep re-evaluating the expression |x| while it returns EINTR. +#define HANDLE_EINTR(x) \ + ({ \ + decltype(x) eintr_wrapper_result; \ + do { \ + eintr_wrapper_result = (x); \ + } while (eintr_wrapper_result == -1 && errno == EINTR); \ + eintr_wrapper_result; \ + }) + +#endif + +std::string GetLastErrorStr() { return GetErrorStr(GetLastError()); } + +class AddrInfoReleaser { + public: + AddrInfoReleaser(addrinfo* addr_infos) : addr_infos_(addr_infos) {} + ~AddrInfoReleaser() { freeaddrinfo(addr_infos_); } + + private: + addrinfo* addr_infos_; +}; + +} // namespace +} // namespace cdc_ft + +#endif // COMMON_SOCKET_INTERNAL_H_