diff --git a/cdc_rsync_server/BUILD b/cdc_rsync_server/BUILD index 8b69391..28f22dd 100644 --- a/cdc_rsync_server/BUILD +++ b/cdc_rsync_server/BUILD @@ -140,6 +140,7 @@ cc_library( "//common:status", "//common:util", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", ], ) diff --git a/cdc_rsync_server/cdc_rsync_server.cc b/cdc_rsync_server/cdc_rsync_server.cc index 95d21c1..922b12a 100644 --- a/cdc_rsync_server/cdc_rsync_server.cc +++ b/cdc_rsync_server/cdc_rsync_server.cc @@ -24,6 +24,7 @@ #include "common/log.h" #include "common/path.h" #include "common/status.h" +#include "common/status_macros.h" #include "common/stopwatch.h" #include "common/threadpool.h" #include "common/util.h" @@ -295,10 +296,11 @@ absl::Status CdcRsyncServer::Run(int port) { socket_finalizer_ = std::make_unique(); socket_ = std::make_unique(); - status = socket_->StartListening(port); - if (!status.ok()) { - return WrapStatus(status, "Failed to start listening on port %i", port); - } + int new_port; + ASSIGN_OR_RETURN(new_port, socket_->StartListening(port), + "Failed to start listening on port %i", port); + assert(port != 0); + assert(port == new_port); LOG_INFO("cdc_rsync_server listening on port %i", port); // This is the marker for the client, so it knows it can connect. diff --git a/cdc_rsync_server/server_socket.cc b/cdc_rsync_server/server_socket.cc index f76be6e..e5bc202 100644 --- a/cdc_rsync_server/server_socket.cc +++ b/cdc_rsync_server/server_socket.cc @@ -22,11 +22,14 @@ #if PLATFORM_WINDOWS #include +#include #elif PLATFORM_LINUX +#include #include #include +#include #include #include @@ -95,6 +98,15 @@ void Close(SocketType* socket) { std::string GetLastErrorStr() { return GetErrorStr(GetLastError()); } +class AddrInfoReleaser { + public: + AddrInfoReleaser(addrinfo* addr_infos) : addr_infos_(addr_infos) {} + ~AddrInfoReleaser() { freeaddrinfo(addr_infos_); } + + private: + addrinfo* addr_infos_; +}; + } // namespace struct ServerSocketInfo { @@ -113,15 +125,57 @@ ServerSocket::~ServerSocket() { StopListening(); } -absl::Status ServerSocket::StartListening(int port) { +absl::StatusOr ServerSocket::StartListening(int port) { if (socket_info_->listen_sock != kInvalidSocket) { return MakeStatus("Already listening"); } - LOG_DEBUG("Open socket"); - socket_info_->listen_sock = socket(AF_INET, SOCK_STREAM, 0); + // Find addrinfos suitable for listening via IPV4 and IPV6. + addrinfo hints; + addrinfo* addr_infos = nullptr; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + // AI_PASSIVE indicates that the addresses are used with bind(). The returned + // addresses will be the unspecified addresses for each family. + hints.ai_flags = AI_NUMERICHOST | AI_PASSIVE; + int result = getaddrinfo(/*address=*/nullptr, std::to_string(port).c_str(), + &hints, &addr_infos); + if (result != 0) { + return MakeStatus("Getting address infos failed: %s", GetLastErrorStr()); + } + AddrInfoReleaser releaser(addr_infos); + + // Prefer IPV6 sockets. They can also accept IPV4 connections. + for (addrinfo* curr = addr_infos; curr; curr = curr->ai_next) { + if (curr->ai_family == PF_INET6) { + return StartListeningInternal(port, curr); + } + } + + // Fall back to IPV4 sockets. + for (addrinfo* curr = addr_infos; curr; curr = curr->ai_next) { + if (curr->ai_family == PF_INET) { + return StartListeningInternal(port, curr); + } + } + + return MakeStatus("No IPV4 and IPV6 network addresses available"); +} + +absl::StatusOr ServerSocket::StartListeningInternal(int port, + addrinfo* addr) { + assert(addr->ai_family == PF_INET || addr->ai_family == PF_INET6); + const char* family = addr->ai_family == PF_INET ? "IPV4" : "IPV6"; + + // Open a socket with the correct address family for this address. + LOG_DEBUG("Open %s listen socket", family); + socket_info_->listen_sock = + socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); if (socket_info_->listen_sock == kInvalidSocket) { - return MakeStatus("Creating listen socket failed: %s", GetLastErrorStr()); + return MakeStatus("Creating %s listen socket failed: %s", family, + GetLastErrorStr()); } // If the program terminates abnormally, the socket might remain in a @@ -136,16 +190,21 @@ absl::Status ServerSocket::StartListening(int port) { LOG_DEBUG("Enabling address reusal failed"); } - LOG_DEBUG("Bind socket"); - sockaddr_in serv_addr; - memset(&serv_addr, 0, sizeof(serv_addr)); - serv_addr.sin_family = AF_INET; - serv_addr.sin_addr.s_addr = INADDR_ANY; - serv_addr.sin_port = htons(port); + // Allow ipv4 connections on the ipv6 socket. By default, ipv6 sockets only + // allow ipv4 connections on Windows. + if (addr->ai_family == PF_INET6) { + const int disable = 0; + result = + setsockopt(socket_info_->listen_sock, IPPROTO_IPV6, IPV6_V6ONLY, + reinterpret_cast(&disable), sizeof(disable)); + if (result == kSocketError) { + LOG_DEBUG("Disabling IPV6-only failed"); + } + } - result = bind(socket_info_->listen_sock, - reinterpret_cast(&serv_addr), - sizeof(serv_addr)); + LOG_DEBUG("Bind socket"); + result = bind(socket_info_->listen_sock, addr->ai_addr, + static_cast(addr->ai_addrlen)); if (result == kSocketError) { int err = GetLastError(); absl::Status status = @@ -159,6 +218,21 @@ absl::Status ServerSocket::StartListening(int port) { return status; } + if (port == 0) { + // Find out which port was auto-selected. + socklen_t len = addr->ai_addrlen; + result = getsockname(socket_info_->listen_sock, addr->ai_addr, &len); + if (result == kSocketError) { + Close(&socket_info_->listen_sock); + return MakeStatus("Getting port failed: %s", GetLastErrorStr()); + } + if (addr->ai_family == PF_INET) { + port = ntohs(reinterpret_cast(addr->ai_addr)->sin_port); + } else if (addr->ai_family == PF_INET6) { + port = ntohs(reinterpret_cast(addr->ai_addr)->sin6_port); + } + } + LOG_DEBUG("Listen"); result = listen(socket_info_->listen_sock, 1); if (result == kSocketError) { @@ -167,7 +241,7 @@ absl::Status ServerSocket::StartListening(int port) { return MakeStatus("Listening to socket failed: %s", GetErrorStr(err)); } - return absl::OkStatus(); + return port; } void ServerSocket::StopListening() { @@ -179,6 +253,9 @@ absl::Status ServerSocket::WaitForConnection() { if (socket_info_->conn_sock != kInvalidSocket) { return MakeStatus("Already connected"); } + if (socket_info_->listen_sock == kInvalidSocket) { + return MakeStatus("Not listening"); + } socket_info_->conn_sock = accept(socket_info_->listen_sock, nullptr, nullptr); if (socket_info_->conn_sock == kInvalidSocket) { diff --git a/cdc_rsync_server/server_socket.h b/cdc_rsync_server/server_socket.h index 636b949..17bcb40 100644 --- a/cdc_rsync_server/server_socket.h +++ b/cdc_rsync_server/server_socket.h @@ -18,8 +18,11 @@ #define CDC_RSYNC_SERVER_SERVER_SOCKET_H_ #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "cdc_rsync/base/socket.h" +struct addrinfo; + namespace cdc_ft { class ServerSocket : public Socket { @@ -28,7 +31,9 @@ class ServerSocket : public Socket { ~ServerSocket(); // Starts listening for connections on |port|. - absl::Status StartListening(int port); + // Passing 0 as port will bind to any available port. + // Returns the port that was bound to. + absl::StatusOr StartListening(int port); // Stops listening for connections. No-op if already stopped/never started. void StopListening(); @@ -50,6 +55,11 @@ class ServerSocket : public Socket { size_t* bytes_received) override; private: + // Called by StartListening() for a specific IPV4 or IPV6 |addr_info|. + // Passing 0 as port will bind to any available port. + // Returns the port that was bound to. + absl::StatusOr StartListeningInternal(int port, addrinfo* addr); + std::unique_ptr socket_info_; };