[common] Add support for ClientSocket on Linux (#97)

Uses the abstractions written for ServerSocket in ClientSocket, so
that it builds on Linux. Also adds a method to poll for connections
and uses that in cdc_rsync. Similar code will be used in cdc_fuse_fs
to wait for a connection in a future CL.
This commit is contained in:
Lutz Justen
2023-05-27 21:18:13 +02:00
committed by GitHub
parent 26ff93489e
commit 678ee0ffaf
6 changed files with 185 additions and 143 deletions

View File

@@ -300,24 +300,9 @@ absl::Status CdcRsyncClient::StartServer(const ServerArch& arch) {
"Failed to start cdc_rsync_server process"); "Failed to start cdc_rsync_server process");
} }
// Poll the connection to the socket with port |local_port| until the port // Wait for port forwarding to be up.
// forwarding connection is set up. RETURN_IF_ERROR(ClientSocket::WaitForConnection(
timeout_timer.Reset(); local_port, absl::Seconds(options_.connection_timeout_sec)));
for (;;) {
assert(local_port != 0);
status = socket_.Connect(local_port);
if (status.ok()) {
break;
}
if (timeout_timer.ElapsedSeconds() > options_.connection_timeout_sec) {
return SetTag(
absl::DeadlineExceededError("Timeout while connecting to server"),
Tag::kConnectionTimeout);
}
Util::Sleep(10);
}
server_process_ = std::move(srv_process); server_process_ = std::move(srv_process);
port_forwarding_process_ = std::move(fwd_process); port_forwarding_process_ = std::move(fwd_process);

View File

@@ -62,7 +62,10 @@ cc_library(
cc_library( cc_library(
name = "client_socket", name = "client_socket",
srcs = ["client_socket.cc"], srcs = [
"client_socket.cc",
"socket_internal.h",
],
hdrs = ["client_socket.h"], hdrs = ["client_socket.h"],
linkopts = select({ linkopts = select({
"//tools:windows": [ "//tools:windows": [
@@ -70,7 +73,6 @@ cc_library(
], ],
"//conditions:default": [], "//conditions:default": [],
}), }),
target_compatible_with = ["@platforms//os:windows"],
deps = [ deps = [
":log", ":log",
":socket", ":socket",
@@ -473,7 +475,10 @@ cc_test(
cc_library( cc_library(
name = "server_socket", name = "server_socket",
srcs = ["server_socket.cc"], srcs = [
"server_socket.cc",
"socket_internal.h",
],
hdrs = ["server_socket.h"], hdrs = ["server_socket.h"],
linkopts = select({ linkopts = select({
"//tools:windows": [ "//tools:windows": [

View File

@@ -14,13 +14,12 @@
#include "common/client_socket.h" #include "common/client_socket.h"
#include <winsock2.h>
#include <ws2tcpip.h>
#include <cassert> #include <cassert>
#include "common/log.h" #include "common/log.h"
#include "common/socket_internal.h"
#include "common/status.h" #include "common/status.h"
#include "common/stopwatch.h"
#include "common/util.h" #include "common/util.h"
namespace cdc_ft { namespace cdc_ft {
@@ -29,9 +28,9 @@ namespace {
// Creates a status with the given |message| and the last WSA error. // Creates a status with the given |message| and the last WSA error.
// Assigns Tag::kSocketEof for WSAECONNRESET errors. // Assigns Tag::kSocketEof for WSAECONNRESET errors.
absl::Status MakeSocketStatus(const char* message) { absl::Status MakeSocketStatus(const char* message) {
const int err = WSAGetLastError(); const int err = GetLastError();
absl::Status status = MakeStatus("%s: %s", message, Util::GetWin32Error(err)); absl::Status status = MakeStatus("%s: %s", message, GetErrorStr(err));
if (err == WSAECONNRESET) { if (err == kErrConnReset) {
status = SetTag(status, Tag::kSocketEof); status = SetTag(status, Tag::kSocketEof);
} }
return status; return status;
@@ -40,18 +39,39 @@ absl::Status MakeSocketStatus(const char* message) {
} // namespace } // namespace
struct ClientSocketInfo { struct ClientSocketInfo {
SOCKET socket; SocketType socket;
ClientSocketInfo() : socket(INVALID_SOCKET) {} ClientSocketInfo() : socket(kInvalidSocket) {}
}; };
ClientSocket::ClientSocket() = default; ClientSocket::ClientSocket() = default;
ClientSocket::~ClientSocket() { Disconnect(); } ClientSocket::~ClientSocket() { Disconnect(); }
// static
absl::Status ClientSocket::WaitForConnection(int port, absl::Duration timeout) {
assert(port != 0);
Stopwatch timeout_timer;
ClientSocket socket;
for (;;) {
absl::Status status = socket.Connect(port);
if (status.ok()) {
return absl::OkStatus();
}
if (timeout_timer.Elapsed() > timeout) {
return SetTag(
absl::DeadlineExceededError("Timeout while connecting to server"),
Tag::kConnectionTimeout);
}
Util::Sleep(50);
}
}
absl::Status ClientSocket::Connect(int port) { absl::Status ClientSocket::Connect(int port) {
addrinfo hints; addrinfo hints;
ZeroMemory(&hints, sizeof(hints)); memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET; hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM; hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP; hints.ai_protocol = IPPROTO_TCP;
@@ -63,26 +83,26 @@ absl::Status ClientSocket::Connect(int port) {
if (result != 0) { if (result != 0) {
return MakeStatus("getaddrinfo() failed: %i", result); return MakeStatus("getaddrinfo() failed: %i", result);
} }
AddrInfoReleaser releaser(addr_infos);
socket_info_ = std::make_unique<ClientSocketInfo>(); socket_info_ = std::make_unique<ClientSocketInfo>();
int count = 0; int count = 0;
for (addrinfo* curr = addr_infos; curr; curr = curr->ai_next, count++) { for (addrinfo* curr = addr_infos; curr; curr = curr->ai_next, count++) {
socket_info_->socket = socket_info_->socket =
socket(curr->ai_family, curr->ai_socktype, curr->ai_protocol); socket(curr->ai_family, curr->ai_socktype, curr->ai_protocol);
if (socket_info_->socket == INVALID_SOCKET) { if (socket_info_->socket == kInvalidSocket) {
LOG_DEBUG("socket() failed for addr_info %i: %s", count, LOG_DEBUG("socket() failed for addr_info %i: %s", count,
Util::GetWin32Error(WSAGetLastError())); GetLastErrorStr());
continue; continue;
} }
// Connect to server. // Connect to server.
result = connect(socket_info_->socket, curr->ai_addr, result = connect(socket_info_->socket, curr->ai_addr,
static_cast<int>(curr->ai_addrlen)); static_cast<int>(curr->ai_addrlen));
if (result == SOCKET_ERROR) { if (result == kSocketError) {
LOG_DEBUG("connect() failed for addr_info %i: %s", count, LOG_DEBUG("connect() failed for addr_info %i: %s", count,
Util::GetWin32Error(WSAGetLastError())); GetLastErrorStr());
closesocket(socket_info_->socket); Close(&socket_info_->socket);
socket_info_->socket = INVALID_SOCKET;
continue; continue;
} }
@@ -90,9 +110,7 @@ absl::Status ClientSocket::Connect(int port) {
break; break;
} }
freeaddrinfo(addr_infos); if (socket_info_->socket == kInvalidSocket) {
if (socket_info_->socket == INVALID_SOCKET) {
socket_info_.reset(); socket_info_.reset();
return MakeStatus("Unable to connect to port %i", port); return MakeStatus("Unable to connect to port %i", port);
} }
@@ -106,18 +124,15 @@ void ClientSocket::Disconnect() {
return; return;
} }
if (socket_info_->socket != INVALID_SOCKET) { Close(&socket_info_->socket);
closesocket(socket_info_->socket);
socket_info_->socket = INVALID_SOCKET;
}
socket_info_.reset(); socket_info_.reset();
} }
absl::Status ClientSocket::Send(const void* buffer, size_t size) { absl::Status ClientSocket::Send(const void* buffer, size_t size) {
int result = send(socket_info_->socket, static_cast<const char*>(buffer), int result =
static_cast<int>(size), /*flags */ 0); HANDLE_EINTR(send(socket_info_->socket, static_cast<const char*>(buffer),
if (result == SOCKET_ERROR) { static_cast<int>(size), /*flags */ 0));
if (result == kSocketError) {
return MakeSocketStatus("send() failed"); return MakeSocketStatus("send() failed");
} }
@@ -133,9 +148,10 @@ absl::Status ClientSocket::Receive(void* buffer, size_t size,
} }
int flags = allow_partial_read ? 0 : MSG_WAITALL; int flags = allow_partial_read ? 0 : MSG_WAITALL;
int bytes_read = recv(socket_info_->socket, static_cast<char*>(buffer), int bytes_read =
static_cast<int>(size), flags); HANDLE_EINTR(recv(socket_info_->socket, static_cast<char*>(buffer),
if (bytes_read == SOCKET_ERROR) { static_cast<int>(size), flags));
if (bytes_read == kSocketError) {
return MakeSocketStatus("recv() failed"); return MakeSocketStatus("recv() failed");
} }
@@ -154,9 +170,9 @@ absl::Status ClientSocket::Receive(void* buffer, size_t size,
} }
absl::Status ClientSocket::ShutdownSendingEnd() { absl::Status ClientSocket::ShutdownSendingEnd() {
int result = shutdown(socket_info_->socket, SD_SEND); int result = shutdown(socket_info_->socket, kSendingEnd);
if (result == SOCKET_ERROR) { if (result == kSocketError) {
return MakeSocketStatus("shutdown() failed"); return MakeStatus("Socket shutdown failed: %s", GetLastErrorStr());
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@@ -29,6 +29,10 @@ class ClientSocket : public Socket {
ClientSocket(); ClientSocket();
~ClientSocket(); ~ClientSocket();
// Polls until a connection to |port| succeeds.
// Returns TimeoutError
static absl::Status WaitForConnection(int port, absl::Duration timeout);
// Connects to localhost on |port|. // Connects to localhost on |port|.
absl::Status Connect(int port); absl::Status Connect(int port);

View File

@@ -15,99 +15,10 @@
#include "common/server_socket.h" #include "common/server_socket.h"
#include "common/log.h" #include "common/log.h"
#include "common/platform.h" #include "common/socket_internal.h"
#include "common/status.h" #include "common/status.h"
#include "common/util.h"
#if PLATFORM_WINDOWS
#include <winsock2.h>
#include <ws2tcpip.h>
#elif PLATFORM_LINUX
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <cerrno>
#endif
namespace cdc_ft { namespace cdc_ft {
namespace {
#if PLATFORM_WINDOWS
using SocketType = SOCKET;
using SockAddrType = SOCKADDR;
constexpr SocketType kInvalidSocket = INVALID_SOCKET;
constexpr int kSocketError = SOCKET_ERROR;
constexpr int kSendingEnd = SD_SEND;
constexpr int kErrAgain = WSAEWOULDBLOCK; // There's no EAGAIN on Windows.
constexpr int kErrWouldBlock = WSAEWOULDBLOCK;
constexpr int kErrAddrInUse = WSAEADDRINUSE;
int GetLastError() { return WSAGetLastError(); }
std::string GetErrorStr(int err) { return Util::GetWin32Error(err); }
void Close(SocketType* socket) {
if (*socket != kInvalidSocket) {
closesocket(*socket);
*socket = kInvalidSocket;
}
}
// Not necessary on Windows.
#define HANDLE_EINTR(x) (x)
#elif PLATFORM_LINUX
using SocketType = int;
using SockAddrType = sockaddr;
constexpr SocketType kInvalidSocket = -1;
constexpr int kSocketError = -1;
constexpr int kSendingEnd = SHUT_WR;
constexpr int kErrAgain = EAGAIN;
constexpr int kErrWouldBlock = EWOULDBLOCK;
constexpr int kErrAddrInUse = EADDRINUSE;
int GetLastError() { return errno; }
std::string GetErrorStr(int err) { return strerror(err); }
void Close(SocketType* socket) {
if (*socket != kInvalidSocket) {
close(*socket);
*socket = kInvalidSocket;
}
}
// Keep re-evaluating the expression |x| while it returns EINTR.
#define HANDLE_EINTR(x) \
({ \
decltype(x) eintr_wrapper_result; \
do { \
eintr_wrapper_result = (x); \
} while (eintr_wrapper_result == -1 && errno == EINTR); \
eintr_wrapper_result; \
})
#endif
std::string GetLastErrorStr() { return GetErrorStr(GetLastError()); }
class AddrInfoReleaser {
public:
AddrInfoReleaser(addrinfo* addr_infos) : addr_infos_(addr_infos) {}
~AddrInfoReleaser() { freeaddrinfo(addr_infos_); }
private:
addrinfo* addr_infos_;
};
} // namespace
struct ServerSocketInfo { struct ServerSocketInfo {
// Listening socket file descriptor (where new connections are accepted). // Listening socket file descriptor (where new connections are accepted).

121
common/socket_internal.h Normal file
View File

@@ -0,0 +1,121 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef COMMON_SOCKET_INTERNAL_H_
#include "common/platform.h"
#include "common/util.h"
#if PLATFORM_WINDOWS
#include <winsock2.h>
#include <ws2tcpip.h>
#elif PLATFORM_LINUX
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <cerrno>
#endif
namespace cdc_ft {
namespace {
// Platform-specific abstractions for socket classes.
#if PLATFORM_WINDOWS
using SocketType = SOCKET;
using SockAddrType = SOCKADDR;
constexpr SocketType kInvalidSocket = INVALID_SOCKET;
constexpr int kSocketError = SOCKET_ERROR;
constexpr int kSendingEnd = SD_SEND;
constexpr int kErrAgain = WSAEWOULDBLOCK; // There's no EAGAIN on Windows.
constexpr int kErrWouldBlock = WSAEWOULDBLOCK;
constexpr int kErrAddrInUse = WSAEADDRINUSE;
constexpr int kErrConnReset = WSAECONNRESET;
int GetLastError() { return WSAGetLastError(); }
std::string GetErrorStr(int err) { return Util::GetWin32Error(err); }
void Close(SocketType* socket) {
if (*socket != kInvalidSocket) {
closesocket(*socket);
*socket = kInvalidSocket;
}
}
// Not necessary on Windows.
#define HANDLE_EINTR(x) (x)
#elif PLATFORM_LINUX
using SocketType = int;
using SockAddrType = sockaddr;
constexpr SocketType kInvalidSocket = -1;
constexpr int kSocketError = -1;
constexpr int kSendingEnd = SHUT_WR;
constexpr int kErrAgain = EAGAIN;
constexpr int kErrWouldBlock = EWOULDBLOCK;
constexpr int kErrAddrInUse = EADDRINUSE;
constexpr int kErrConnReset = ECONNRESET;
int GetLastError() { return errno; }
std::string GetErrorStr(int err) { return strerror(err); }
void Close(SocketType* socket) {
if (*socket != kInvalidSocket) {
close(*socket);
*socket = kInvalidSocket;
}
}
// Keep re-evaluating the expression |x| while it returns EINTR.
#define HANDLE_EINTR(x) \
({ \
decltype(x) eintr_wrapper_result; \
do { \
eintr_wrapper_result = (x); \
} while (eintr_wrapper_result == -1 && errno == EINTR); \
eintr_wrapper_result; \
})
#endif
std::string GetLastErrorStr() { return GetErrorStr(GetLastError()); }
class AddrInfoReleaser {
public:
AddrInfoReleaser(addrinfo* addr_infos) : addr_infos_(addr_infos) {}
~AddrInfoReleaser() { freeaddrinfo(addr_infos_); }
private:
addrinfo* addr_infos_;
};
} // namespace
} // namespace cdc_ft
#endif // COMMON_SOCKET_INTERNAL_H_