mirror of
https://github.com/nestriness/cdc-file-transfer.git
synced 2026-02-06 07:05:37 +02:00
[cdc_rsync] Add support for ServerSocket on Windows (#48)
Makes ServerSocket multi-platform, mainly by working around some small API differences. The code is largely the same, there should be no differences on Linux. Also moves WSAStartup() and WSACleanup() up to the Socket level as static methods because it's used by both ClientSocket and ServerSocket, and because it doesn't make sense to do that in the socket class as that would prevent one from using several sockets.
This commit is contained in:
@@ -14,20 +14,72 @@
|
||||
|
||||
#include "cdc_rsync_server/server_socket.h"
|
||||
|
||||
#include "common/log.h"
|
||||
#include "common/platform.h"
|
||||
#include "common/status.h"
|
||||
#include "common/util.h"
|
||||
|
||||
#if PLATFORM_WINDOWS
|
||||
|
||||
#include <winsock2.h>
|
||||
|
||||
#elif PLATFORM_LINUX
|
||||
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cerrno>
|
||||
|
||||
#include "common/log.h"
|
||||
#include "common/status.h"
|
||||
#endif
|
||||
|
||||
namespace cdc_ft {
|
||||
|
||||
namespace {
|
||||
|
||||
int kInvalidFd = -1;
|
||||
#if PLATFORM_WINDOWS
|
||||
|
||||
using SocketType = SOCKET;
|
||||
using SockAddrType = SOCKADDR;
|
||||
constexpr SocketType kInvalidSocket = INVALID_SOCKET;
|
||||
constexpr int kSocketError = SOCKET_ERROR;
|
||||
constexpr int kSendingEnd = SD_SEND;
|
||||
|
||||
constexpr int kErrAgain = WSAEWOULDBLOCK; // There's no EAGAIN on Windows.
|
||||
constexpr int kErrWouldBlock = WSAEWOULDBLOCK;
|
||||
constexpr int kErrAddrInUse = WSAEADDRINUSE;
|
||||
|
||||
int GetLastError() { return WSAGetLastError(); }
|
||||
std::string GetErrorStr(int err) { return Util::GetWin32Error(err); }
|
||||
void Close(SocketType* socket) {
|
||||
if (*socket != kInvalidSocket) {
|
||||
closesocket(*socket);
|
||||
*socket = kInvalidSocket;
|
||||
}
|
||||
}
|
||||
|
||||
// Not necessary on Windows.
|
||||
#define HANDLE_EINTR(x) (x)
|
||||
|
||||
#elif PLATFORM_LINUX
|
||||
|
||||
using SocketType = int;
|
||||
using SockAddrType = sockaddr;
|
||||
constexpr SocketType kInvalidSocket = -1;
|
||||
constexpr int kSocketError = -1;
|
||||
constexpr int kSendingEnd = SHUT_WR;
|
||||
|
||||
constexpr int kErrAgain = EAGAIN;
|
||||
constexpr int kErrWouldBlock = EWOULDBLOCK;
|
||||
constexpr int kErrAddrInUse = EADDRINUSE;
|
||||
|
||||
int GetLastError() { return errno; }
|
||||
std::string GetErrorStr(int err) { return strerror(err); }
|
||||
void Close(SocketType* socket) {
|
||||
if (*socket != kInvalidSocket) {
|
||||
close(*socket);
|
||||
*socket = kInvalidSocket;
|
||||
}
|
||||
}
|
||||
|
||||
// Keep re-evaluating the expression |x| while it returns EINTR.
|
||||
#define HANDLE_EINTR(x) \
|
||||
@@ -39,10 +91,22 @@ int kInvalidFd = -1;
|
||||
eintr_wrapper_result; \
|
||||
})
|
||||
|
||||
#endif
|
||||
|
||||
std::string GetLastErrorStr() { return GetErrorStr(GetLastError()); }
|
||||
|
||||
} // namespace
|
||||
|
||||
struct ServerSocketInfo {
|
||||
// Listening socket file descriptor (where new connections are accepted).
|
||||
SocketType listen_sock = kInvalidSocket;
|
||||
|
||||
// Connection socket file descriptor (where data is sent to/received from).
|
||||
SocketType conn_sock = kInvalidSocket;
|
||||
};
|
||||
|
||||
ServerSocket::ServerSocket()
|
||||
: Socket(), listen_sockfd_(kInvalidFd), conn_sockfd_(kInvalidFd) {}
|
||||
: Socket(), socket_info_(std::make_unique<ServerSocketInfo>()) {}
|
||||
|
||||
ServerSocket::~ServerSocket() {
|
||||
Disconnect();
|
||||
@@ -50,25 +114,26 @@ ServerSocket::~ServerSocket() {
|
||||
}
|
||||
|
||||
absl::Status ServerSocket::StartListening(int port) {
|
||||
if (listen_sockfd_ != kInvalidFd) {
|
||||
if (socket_info_->listen_sock != kInvalidSocket) {
|
||||
return MakeStatus("Already listening");
|
||||
}
|
||||
|
||||
LOG_DEBUG("Open socket");
|
||||
listen_sockfd_ = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (listen_sockfd_ < 0) {
|
||||
listen_sockfd_ = kInvalidFd;
|
||||
return MakeStatus("socket() failed: %s", strerror(errno));
|
||||
socket_info_->listen_sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (socket_info_->listen_sock == kInvalidSocket) {
|
||||
return MakeStatus("Creating listen socket failed: %s", GetLastErrorStr());
|
||||
}
|
||||
|
||||
// If the program terminates abnormally, the socket might remain in a
|
||||
// TIME_WAIT state and report "address already in use" on bind(). Setting
|
||||
// SO_REUSEADDR works around that. See
|
||||
// https://hea-www.harvard.edu/~fine/Tech/addrinuse.html
|
||||
int enable = 1;
|
||||
if (setsockopt(listen_sockfd_, SOL_SOCKET, SO_REUSEADDR, &enable,
|
||||
sizeof(enable)) < 0) {
|
||||
LOG_DEBUG("setsockopt() failed");
|
||||
const int enable = 1;
|
||||
int result =
|
||||
setsockopt(socket_info_->listen_sock, SOL_SOCKET, SO_REUSEADDR,
|
||||
reinterpret_cast<const char*>(&enable), sizeof(enable));
|
||||
if (result == kSocketError) {
|
||||
LOG_DEBUG("Enabling address reusal failed");
|
||||
}
|
||||
|
||||
LOG_DEBUG("Bind socket");
|
||||
@@ -77,46 +142,47 @@ absl::Status ServerSocket::StartListening(int port) {
|
||||
serv_addr.sin_family = AF_INET;
|
||||
serv_addr.sin_addr.s_addr = INADDR_ANY;
|
||||
serv_addr.sin_port = htons(port);
|
||||
if (bind(listen_sockfd_, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) <
|
||||
0) {
|
||||
|
||||
result = bind(socket_info_->listen_sock,
|
||||
reinterpret_cast<const SockAddrType*>(&serv_addr),
|
||||
sizeof(serv_addr));
|
||||
if (result == kSocketError) {
|
||||
int err = GetLastError();
|
||||
absl::Status status =
|
||||
MakeStatus("bind() to port %i failed: %s", port, strerror(errno));
|
||||
if (errno == EADDRINUSE) {
|
||||
MakeStatus("Binding to port %i failed: %s", port, GetErrorStr(err));
|
||||
if (err == kErrAddrInUse) {
|
||||
// Happens when two instances are run at the same time. Help callers to
|
||||
// print reasonable errors.
|
||||
status = SetTag(status, Tag::kAddressInUse);
|
||||
}
|
||||
close(listen_sockfd_);
|
||||
listen_sockfd_ = kInvalidFd;
|
||||
|
||||
Close(&socket_info_->listen_sock);
|
||||
return status;
|
||||
}
|
||||
|
||||
LOG_DEBUG("Listen");
|
||||
listen(listen_sockfd_, 1);
|
||||
result = listen(socket_info_->listen_sock, 1);
|
||||
if (result == kSocketError) {
|
||||
int err = GetLastError();
|
||||
Close(&socket_info_->listen_sock);
|
||||
return MakeStatus("Listening to socket failed: %s", GetErrorStr(err));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void ServerSocket::StopListening() {
|
||||
if (listen_sockfd_ != kInvalidFd) {
|
||||
close(listen_sockfd_);
|
||||
listen_sockfd_ = kInvalidFd;
|
||||
}
|
||||
|
||||
Close(&socket_info_->listen_sock);
|
||||
LOG_INFO("Stopped listening.");
|
||||
}
|
||||
|
||||
absl::Status ServerSocket::WaitForConnection() {
|
||||
if (conn_sockfd_ != kInvalidFd) {
|
||||
if (socket_info_->conn_sock != kInvalidSocket) {
|
||||
return MakeStatus("Already connected");
|
||||
}
|
||||
|
||||
sockaddr_in cli_addr;
|
||||
socklen_t cli_len = sizeof(cli_addr);
|
||||
conn_sockfd_ = accept(listen_sockfd_, (struct sockaddr*)&cli_addr, &cli_len);
|
||||
if (conn_sockfd_ < 0) {
|
||||
conn_sockfd_ = kInvalidFd;
|
||||
return MakeStatus("accept() failed: %s", strerror(errno));
|
||||
socket_info_->conn_sock = accept(socket_info_->listen_sock, nullptr, nullptr);
|
||||
if (socket_info_->conn_sock == kInvalidSocket) {
|
||||
return MakeStatus("Accepting connection failed: %s", GetLastErrorStr());
|
||||
}
|
||||
|
||||
LOG_DEBUG("Client connected");
|
||||
@@ -124,39 +190,36 @@ absl::Status ServerSocket::WaitForConnection() {
|
||||
}
|
||||
|
||||
void ServerSocket::Disconnect() {
|
||||
if (conn_sockfd_ != kInvalidFd) {
|
||||
close(conn_sockfd_);
|
||||
conn_sockfd_ = kInvalidFd;
|
||||
}
|
||||
|
||||
Close(&socket_info_->conn_sock);
|
||||
LOG_INFO("Disconnected");
|
||||
}
|
||||
|
||||
absl::Status ServerSocket::ShutdownSendingEnd() {
|
||||
int result = shutdown(conn_sockfd_, SHUT_WR);
|
||||
if (result != 0) {
|
||||
return MakeStatus("shutdown() failed: %s", strerror(errno));
|
||||
int result = shutdown(socket_info_->conn_sock, kSendingEnd);
|
||||
if (result == kSocketError) {
|
||||
return MakeStatus("Socket shutdown failed: %s", GetLastErrorStr());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ServerSocket::Send(const void* buffer, size_t size) {
|
||||
const uint8_t* curr_ptr = reinterpret_cast<const uint8_t*>(buffer);
|
||||
ssize_t bytes_left = size;
|
||||
const char* curr_ptr = reinterpret_cast<const char*>(buffer);
|
||||
assert(size <= INT_MAX);
|
||||
int bytes_left = static_cast<int>(size);
|
||||
while (bytes_left > 0) {
|
||||
ssize_t bytes_written =
|
||||
HANDLE_EINTR(send(conn_sockfd_, curr_ptr, bytes_left, /*flags*/ 0));
|
||||
int bytes_written = HANDLE_EINTR(
|
||||
send(socket_info_->conn_sock, curr_ptr, bytes_left, /*flags*/ 0));
|
||||
|
||||
if (bytes_written < 0) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
const int err = GetLastError();
|
||||
if (err == kErrAgain || err == kErrWouldBlock) {
|
||||
// Shouldn't happen as the socket should be blocking.
|
||||
LOG_DEBUG("Socket would block");
|
||||
continue;
|
||||
}
|
||||
|
||||
return MakeStatus("write() to fd %i failed: %s", conn_sockfd_,
|
||||
strerror(errno));
|
||||
return MakeStatus("Sending to socket failed: %s", GetErrorStr(err));
|
||||
}
|
||||
|
||||
bytes_left -= bytes_written;
|
||||
@@ -173,21 +236,22 @@ absl::Status ServerSocket::Receive(void* buffer, size_t size,
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
uint8_t* curr_ptr = reinterpret_cast<uint8_t*>(buffer);
|
||||
ssize_t bytes_left = size;
|
||||
char* curr_ptr = static_cast<char*>(buffer);
|
||||
assert(size <= INT_MAX);
|
||||
int bytes_left = size;
|
||||
while (bytes_left > 0) {
|
||||
ssize_t bytes_read =
|
||||
HANDLE_EINTR(recv(conn_sockfd_, curr_ptr, bytes_left, /*flags*/ 0));
|
||||
int bytes_read = HANDLE_EINTR(
|
||||
recv(socket_info_->conn_sock, curr_ptr, bytes_left, /*flags*/ 0));
|
||||
|
||||
if (bytes_read < 0) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
const int err = GetLastError();
|
||||
if (err == kErrAgain || err == kErrWouldBlock) {
|
||||
// Shouldn't happen as the socket should be blocking.
|
||||
LOG_DEBUG("Socket would block");
|
||||
continue;
|
||||
}
|
||||
|
||||
return MakeStatus("recv() from fd %i failed: %s", conn_sockfd_,
|
||||
strerror(errno));
|
||||
return MakeStatus("Receiving from socket failed: %s", GetErrorStr(err));
|
||||
}
|
||||
|
||||
bytes_left -= bytes_read;
|
||||
@@ -196,7 +260,7 @@ absl::Status ServerSocket::Receive(void* buffer, size_t size,
|
||||
|
||||
if (bytes_read == 0) {
|
||||
// EOF. Make sure we're not in the middle of a message.
|
||||
if (bytes_left < static_cast<ssize_t>(size)) {
|
||||
if (bytes_left < static_cast<int>(size)) {
|
||||
return MakeStatus("EOF after partial read");
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user