mirror of
https://github.com/nestriness/cdc-file-transfer.git
synced 2026-01-30 14:35:37 +02:00
[common] Add support for ClientSocket on Linux (#97)
Uses the abstractions written for ServerSocket in ClientSocket, so that it builds on Linux. Also adds a method to poll for connections and uses that in cdc_rsync. Similar code will be used in cdc_fuse_fs to wait for a connection in a future CL.
This commit is contained in:
11
common/BUILD
11
common/BUILD
@@ -62,7 +62,10 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "client_socket",
|
||||
srcs = ["client_socket.cc"],
|
||||
srcs = [
|
||||
"client_socket.cc",
|
||||
"socket_internal.h",
|
||||
],
|
||||
hdrs = ["client_socket.h"],
|
||||
linkopts = select({
|
||||
"//tools:windows": [
|
||||
@@ -70,7 +73,6 @@ cc_library(
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
target_compatible_with = ["@platforms//os:windows"],
|
||||
deps = [
|
||||
":log",
|
||||
":socket",
|
||||
@@ -473,7 +475,10 @@ cc_test(
|
||||
|
||||
cc_library(
|
||||
name = "server_socket",
|
||||
srcs = ["server_socket.cc"],
|
||||
srcs = [
|
||||
"server_socket.cc",
|
||||
"socket_internal.h",
|
||||
],
|
||||
hdrs = ["server_socket.h"],
|
||||
linkopts = select({
|
||||
"//tools:windows": [
|
||||
|
||||
@@ -14,13 +14,12 @@
|
||||
|
||||
#include "common/client_socket.h"
|
||||
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "common/log.h"
|
||||
#include "common/socket_internal.h"
|
||||
#include "common/status.h"
|
||||
#include "common/stopwatch.h"
|
||||
#include "common/util.h"
|
||||
|
||||
namespace cdc_ft {
|
||||
@@ -29,9 +28,9 @@ namespace {
|
||||
// Creates a status with the given |message| and the last WSA error.
|
||||
// Assigns Tag::kSocketEof for WSAECONNRESET errors.
|
||||
absl::Status MakeSocketStatus(const char* message) {
|
||||
const int err = WSAGetLastError();
|
||||
absl::Status status = MakeStatus("%s: %s", message, Util::GetWin32Error(err));
|
||||
if (err == WSAECONNRESET) {
|
||||
const int err = GetLastError();
|
||||
absl::Status status = MakeStatus("%s: %s", message, GetErrorStr(err));
|
||||
if (err == kErrConnReset) {
|
||||
status = SetTag(status, Tag::kSocketEof);
|
||||
}
|
||||
return status;
|
||||
@@ -40,18 +39,39 @@ absl::Status MakeSocketStatus(const char* message) {
|
||||
} // namespace
|
||||
|
||||
struct ClientSocketInfo {
|
||||
SOCKET socket;
|
||||
SocketType socket;
|
||||
|
||||
ClientSocketInfo() : socket(INVALID_SOCKET) {}
|
||||
ClientSocketInfo() : socket(kInvalidSocket) {}
|
||||
};
|
||||
|
||||
ClientSocket::ClientSocket() = default;
|
||||
|
||||
ClientSocket::~ClientSocket() { Disconnect(); }
|
||||
|
||||
// static
|
||||
absl::Status ClientSocket::WaitForConnection(int port, absl::Duration timeout) {
|
||||
assert(port != 0);
|
||||
Stopwatch timeout_timer;
|
||||
ClientSocket socket;
|
||||
for (;;) {
|
||||
absl::Status status = socket.Connect(port);
|
||||
if (status.ok()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
if (timeout_timer.Elapsed() > timeout) {
|
||||
return SetTag(
|
||||
absl::DeadlineExceededError("Timeout while connecting to server"),
|
||||
Tag::kConnectionTimeout);
|
||||
}
|
||||
|
||||
Util::Sleep(50);
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status ClientSocket::Connect(int port) {
|
||||
addrinfo hints;
|
||||
ZeroMemory(&hints, sizeof(hints));
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_INET;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
hints.ai_protocol = IPPROTO_TCP;
|
||||
@@ -63,26 +83,26 @@ absl::Status ClientSocket::Connect(int port) {
|
||||
if (result != 0) {
|
||||
return MakeStatus("getaddrinfo() failed: %i", result);
|
||||
}
|
||||
AddrInfoReleaser releaser(addr_infos);
|
||||
|
||||
socket_info_ = std::make_unique<ClientSocketInfo>();
|
||||
int count = 0;
|
||||
for (addrinfo* curr = addr_infos; curr; curr = curr->ai_next, count++) {
|
||||
socket_info_->socket =
|
||||
socket(curr->ai_family, curr->ai_socktype, curr->ai_protocol);
|
||||
if (socket_info_->socket == INVALID_SOCKET) {
|
||||
if (socket_info_->socket == kInvalidSocket) {
|
||||
LOG_DEBUG("socket() failed for addr_info %i: %s", count,
|
||||
Util::GetWin32Error(WSAGetLastError()));
|
||||
GetLastErrorStr());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Connect to server.
|
||||
result = connect(socket_info_->socket, curr->ai_addr,
|
||||
static_cast<int>(curr->ai_addrlen));
|
||||
if (result == SOCKET_ERROR) {
|
||||
if (result == kSocketError) {
|
||||
LOG_DEBUG("connect() failed for addr_info %i: %s", count,
|
||||
Util::GetWin32Error(WSAGetLastError()));
|
||||
closesocket(socket_info_->socket);
|
||||
socket_info_->socket = INVALID_SOCKET;
|
||||
GetLastErrorStr());
|
||||
Close(&socket_info_->socket);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -90,9 +110,7 @@ absl::Status ClientSocket::Connect(int port) {
|
||||
break;
|
||||
}
|
||||
|
||||
freeaddrinfo(addr_infos);
|
||||
|
||||
if (socket_info_->socket == INVALID_SOCKET) {
|
||||
if (socket_info_->socket == kInvalidSocket) {
|
||||
socket_info_.reset();
|
||||
return MakeStatus("Unable to connect to port %i", port);
|
||||
}
|
||||
@@ -106,18 +124,15 @@ void ClientSocket::Disconnect() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (socket_info_->socket != INVALID_SOCKET) {
|
||||
closesocket(socket_info_->socket);
|
||||
socket_info_->socket = INVALID_SOCKET;
|
||||
}
|
||||
|
||||
Close(&socket_info_->socket);
|
||||
socket_info_.reset();
|
||||
}
|
||||
|
||||
absl::Status ClientSocket::Send(const void* buffer, size_t size) {
|
||||
int result = send(socket_info_->socket, static_cast<const char*>(buffer),
|
||||
static_cast<int>(size), /*flags */ 0);
|
||||
if (result == SOCKET_ERROR) {
|
||||
int result =
|
||||
HANDLE_EINTR(send(socket_info_->socket, static_cast<const char*>(buffer),
|
||||
static_cast<int>(size), /*flags */ 0));
|
||||
if (result == kSocketError) {
|
||||
return MakeSocketStatus("send() failed");
|
||||
}
|
||||
|
||||
@@ -133,9 +148,10 @@ absl::Status ClientSocket::Receive(void* buffer, size_t size,
|
||||
}
|
||||
|
||||
int flags = allow_partial_read ? 0 : MSG_WAITALL;
|
||||
int bytes_read = recv(socket_info_->socket, static_cast<char*>(buffer),
|
||||
static_cast<int>(size), flags);
|
||||
if (bytes_read == SOCKET_ERROR) {
|
||||
int bytes_read =
|
||||
HANDLE_EINTR(recv(socket_info_->socket, static_cast<char*>(buffer),
|
||||
static_cast<int>(size), flags));
|
||||
if (bytes_read == kSocketError) {
|
||||
return MakeSocketStatus("recv() failed");
|
||||
}
|
||||
|
||||
@@ -154,9 +170,9 @@ absl::Status ClientSocket::Receive(void* buffer, size_t size,
|
||||
}
|
||||
|
||||
absl::Status ClientSocket::ShutdownSendingEnd() {
|
||||
int result = shutdown(socket_info_->socket, SD_SEND);
|
||||
if (result == SOCKET_ERROR) {
|
||||
return MakeSocketStatus("shutdown() failed");
|
||||
int result = shutdown(socket_info_->socket, kSendingEnd);
|
||||
if (result == kSocketError) {
|
||||
return MakeStatus("Socket shutdown failed: %s", GetLastErrorStr());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
||||
@@ -29,6 +29,10 @@ class ClientSocket : public Socket {
|
||||
ClientSocket();
|
||||
~ClientSocket();
|
||||
|
||||
// Polls until a connection to |port| succeeds.
|
||||
// Returns TimeoutError
|
||||
static absl::Status WaitForConnection(int port, absl::Duration timeout);
|
||||
|
||||
// Connects to localhost on |port|.
|
||||
absl::Status Connect(int port);
|
||||
|
||||
|
||||
@@ -15,99 +15,10 @@
|
||||
#include "common/server_socket.h"
|
||||
|
||||
#include "common/log.h"
|
||||
#include "common/platform.h"
|
||||
#include "common/socket_internal.h"
|
||||
#include "common/status.h"
|
||||
#include "common/util.h"
|
||||
|
||||
#if PLATFORM_WINDOWS
|
||||
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
|
||||
#elif PLATFORM_LINUX
|
||||
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cerrno>
|
||||
|
||||
#endif
|
||||
|
||||
namespace cdc_ft {
|
||||
namespace {
|
||||
|
||||
#if PLATFORM_WINDOWS
|
||||
|
||||
using SocketType = SOCKET;
|
||||
using SockAddrType = SOCKADDR;
|
||||
constexpr SocketType kInvalidSocket = INVALID_SOCKET;
|
||||
constexpr int kSocketError = SOCKET_ERROR;
|
||||
constexpr int kSendingEnd = SD_SEND;
|
||||
|
||||
constexpr int kErrAgain = WSAEWOULDBLOCK; // There's no EAGAIN on Windows.
|
||||
constexpr int kErrWouldBlock = WSAEWOULDBLOCK;
|
||||
constexpr int kErrAddrInUse = WSAEADDRINUSE;
|
||||
|
||||
int GetLastError() { return WSAGetLastError(); }
|
||||
std::string GetErrorStr(int err) { return Util::GetWin32Error(err); }
|
||||
void Close(SocketType* socket) {
|
||||
if (*socket != kInvalidSocket) {
|
||||
closesocket(*socket);
|
||||
*socket = kInvalidSocket;
|
||||
}
|
||||
}
|
||||
|
||||
// Not necessary on Windows.
|
||||
#define HANDLE_EINTR(x) (x)
|
||||
|
||||
#elif PLATFORM_LINUX
|
||||
|
||||
using SocketType = int;
|
||||
using SockAddrType = sockaddr;
|
||||
constexpr SocketType kInvalidSocket = -1;
|
||||
constexpr int kSocketError = -1;
|
||||
constexpr int kSendingEnd = SHUT_WR;
|
||||
|
||||
constexpr int kErrAgain = EAGAIN;
|
||||
constexpr int kErrWouldBlock = EWOULDBLOCK;
|
||||
constexpr int kErrAddrInUse = EADDRINUSE;
|
||||
|
||||
int GetLastError() { return errno; }
|
||||
std::string GetErrorStr(int err) { return strerror(err); }
|
||||
void Close(SocketType* socket) {
|
||||
if (*socket != kInvalidSocket) {
|
||||
close(*socket);
|
||||
*socket = kInvalidSocket;
|
||||
}
|
||||
}
|
||||
|
||||
// Keep re-evaluating the expression |x| while it returns EINTR.
|
||||
#define HANDLE_EINTR(x) \
|
||||
({ \
|
||||
decltype(x) eintr_wrapper_result; \
|
||||
do { \
|
||||
eintr_wrapper_result = (x); \
|
||||
} while (eintr_wrapper_result == -1 && errno == EINTR); \
|
||||
eintr_wrapper_result; \
|
||||
})
|
||||
|
||||
#endif
|
||||
|
||||
std::string GetLastErrorStr() { return GetErrorStr(GetLastError()); }
|
||||
|
||||
class AddrInfoReleaser {
|
||||
public:
|
||||
AddrInfoReleaser(addrinfo* addr_infos) : addr_infos_(addr_infos) {}
|
||||
~AddrInfoReleaser() { freeaddrinfo(addr_infos_); }
|
||||
|
||||
private:
|
||||
addrinfo* addr_infos_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
struct ServerSocketInfo {
|
||||
// Listening socket file descriptor (where new connections are accepted).
|
||||
|
||||
121
common/socket_internal.h
Normal file
121
common/socket_internal.h
Normal file
@@ -0,0 +1,121 @@
|
||||
/*
|
||||
* Copyright 2023 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef COMMON_SOCKET_INTERNAL_H_
|
||||
|
||||
#include "common/platform.h"
|
||||
#include "common/util.h"
|
||||
|
||||
#if PLATFORM_WINDOWS
|
||||
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
|
||||
#elif PLATFORM_LINUX
|
||||
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cerrno>
|
||||
|
||||
#endif
|
||||
|
||||
namespace cdc_ft {
|
||||
namespace {
|
||||
|
||||
// Platform-specific abstractions for socket classes.
|
||||
|
||||
#if PLATFORM_WINDOWS
|
||||
|
||||
using SocketType = SOCKET;
|
||||
using SockAddrType = SOCKADDR;
|
||||
constexpr SocketType kInvalidSocket = INVALID_SOCKET;
|
||||
constexpr int kSocketError = SOCKET_ERROR;
|
||||
constexpr int kSendingEnd = SD_SEND;
|
||||
|
||||
constexpr int kErrAgain = WSAEWOULDBLOCK; // There's no EAGAIN on Windows.
|
||||
constexpr int kErrWouldBlock = WSAEWOULDBLOCK;
|
||||
constexpr int kErrAddrInUse = WSAEADDRINUSE;
|
||||
constexpr int kErrConnReset = WSAECONNRESET;
|
||||
|
||||
int GetLastError() { return WSAGetLastError(); }
|
||||
|
||||
std::string GetErrorStr(int err) { return Util::GetWin32Error(err); }
|
||||
|
||||
void Close(SocketType* socket) {
|
||||
if (*socket != kInvalidSocket) {
|
||||
closesocket(*socket);
|
||||
*socket = kInvalidSocket;
|
||||
}
|
||||
}
|
||||
|
||||
// Not necessary on Windows.
|
||||
#define HANDLE_EINTR(x) (x)
|
||||
|
||||
#elif PLATFORM_LINUX
|
||||
|
||||
using SocketType = int;
|
||||
using SockAddrType = sockaddr;
|
||||
constexpr SocketType kInvalidSocket = -1;
|
||||
constexpr int kSocketError = -1;
|
||||
constexpr int kSendingEnd = SHUT_WR;
|
||||
|
||||
constexpr int kErrAgain = EAGAIN;
|
||||
constexpr int kErrWouldBlock = EWOULDBLOCK;
|
||||
constexpr int kErrAddrInUse = EADDRINUSE;
|
||||
constexpr int kErrConnReset = ECONNRESET;
|
||||
|
||||
int GetLastError() { return errno; }
|
||||
|
||||
std::string GetErrorStr(int err) { return strerror(err); }
|
||||
|
||||
void Close(SocketType* socket) {
|
||||
if (*socket != kInvalidSocket) {
|
||||
close(*socket);
|
||||
*socket = kInvalidSocket;
|
||||
}
|
||||
}
|
||||
|
||||
// Keep re-evaluating the expression |x| while it returns EINTR.
|
||||
#define HANDLE_EINTR(x) \
|
||||
({ \
|
||||
decltype(x) eintr_wrapper_result; \
|
||||
do { \
|
||||
eintr_wrapper_result = (x); \
|
||||
} while (eintr_wrapper_result == -1 && errno == EINTR); \
|
||||
eintr_wrapper_result; \
|
||||
})
|
||||
|
||||
#endif
|
||||
|
||||
std::string GetLastErrorStr() { return GetErrorStr(GetLastError()); }
|
||||
|
||||
class AddrInfoReleaser {
|
||||
public:
|
||||
AddrInfoReleaser(addrinfo* addr_infos) : addr_infos_(addr_infos) {}
|
||||
~AddrInfoReleaser() { freeaddrinfo(addr_infos_); }
|
||||
|
||||
private:
|
||||
addrinfo* addr_infos_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace cdc_ft
|
||||
|
||||
#endif // COMMON_SOCKET_INTERNAL_H_
|
||||
Reference in New Issue
Block a user