Files
netris-cdc-file-transfer/common/port_manager_win.cc
Lutz Justen d8c2b5906e [cdc_stream] [cdc_rsync] Add --forward-port flag (#45)
Adds a flag to set the SSH forwarding port or port range used for
'cdc_stream start-service' and 'cdc_rsync'.

If a single number is passed, e.g. --forward-port 12345, then this
port is used without checking availability of local and remote ports.
If the port is taken, this results in an error when trying to connect.
Note that this restricts the number of connections that stream can
make to one.

If a range is passed, e.g. --forward-port 45000-46000, the tools
search for available ports locally and remotely in that range. This is
more robust, but a bit slower due to the extra overhead.

Optimizes port_manager_win as it was very slow for a large port range.
It's still not optimal, but scanning 30k ports now takes well under
one second.

Fixes #12
2022-12-19 10:04:36 +01:00

321 lines
12 KiB
C++

// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/port_manager.h"
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include <map>
#include "absl/strings/str_split.h"
#include "common/log.h"
#include "common/process.h"
#include "common/remote_util.h"
#include "common/status.h"
#include "common/status_macros.h"
#include "common/stopwatch.h"
#include "common/util.h"
namespace cdc_ft {
// Wrapper around a named Win32 file mapping backed by the paging file.
// The mapping and its view are created lazily on the first call to Get() and
// released in the destructor. The class owns raw handles, so it is neither
// copyable nor assignable.
class SharedMemory {
 public:
  // Creates a new shared memory instance with given |name| and |size| in bytes.
  // Different instances with matching names reference the same piece of memory,
  // even if they belong to different processes. If shared memory with the given
  // |name| already exists, the existing memory is referenced. Otherwise, a new
  // piece of memory is allocated and zero-initialized.
  SharedMemory(std::string name, size_t size)
      : name_(std::move(name)), size_(size) {}

  // Copying would lead to the same handles being closed twice.
  SharedMemory(const SharedMemory&) = delete;
  SharedMemory& operator=(const SharedMemory&) = delete;

  // Returns a pointer to the start of the shared memory, creating the file
  // mapping and mapping a view of it on first use. On failure, an error status
  // is returned and no handle is kept, so a later call tries again.
  absl::StatusOr<void*> Get() {
    // Already initialized?
    if (shared_mem_) return shared_mem_;
    assert(!map_file_handle_);

    LARGE_INTEGER size;
    size.QuadPart = size_;
    // The name is converted to a wide string, so call the wide-character
    // variant explicitly instead of relying on the UNICODE macro.
    map_file_handle_ = CreateFileMappingW(
        INVALID_HANDLE_VALUE,                 // use paging file
        nullptr,                              // default security
        PAGE_READWRITE,                       // read/write access
        size.HighPart,  // maximum object size (high-order DWORD)
        size.LowPart,   // maximum object size (low-order DWORD)
        Util::Utf8ToWideStr(name_).c_str());  // name of mapping object
    if (!map_file_handle_) {
      return MakeStatus("Failed to create file mapping object: %s",
                        Util::GetLastWin32Error());
    }

    // Map the whole object with read/write access.
    shared_mem_ = MapViewOfFile(map_file_handle_,     // handle to map object
                                FILE_MAP_ALL_ACCESS,  // read/write permission
                                0, 0, size.QuadPart);
    if (!shared_mem_) {
      // Capture the error before CloseHandle() can overwrite it.
      std::string errorMessage = Util::GetLastWin32Error();
      CloseHandle(map_file_handle_);
      map_file_handle_ = nullptr;
      return MakeStatus("Failed to map view of file: %s", errorMessage);
    }
    return shared_mem_;
  }

  // Unmaps the view (if any) and closes the mapping handle (if any).
  ~SharedMemory() {
    if (shared_mem_) {
      UnmapViewOfFile(shared_mem_);
      shared_mem_ = nullptr;
    }
    if (map_file_handle_) {
      CloseHandle(map_file_handle_);
      map_file_handle_ = nullptr;
    }
  }

 private:
  std::string name_;
  size_t size_;
  HANDLE map_file_handle_ = nullptr;
  void* shared_mem_ = nullptr;
};
// Manages the inclusive port range [first_port, last_port]. The shared memory
// named |name| stores one time_t per port in that range; it is only mapped
// lazily on first use.
PortManager::PortManager(std::string name, int first_port, int last_port,
                         ProcessFactory* process_factory,
                         RemoteUtil* remote_util, SystemClock* system_clock,
                         SteadyClock* steady_clock)
    : first_port_(first_port),
      last_port_(last_port),
      process_factory_(process_factory),
      remote_util_(remote_util),
      system_clock_(system_clock),
      steady_clock_(steady_clock),
      shared_mem_(std::make_unique<SharedMemory>(
          std::move(name),
          static_cast<size_t>(last_port - first_port + 1) * sizeof(time_t))) {
  // Both ends are inclusive, so the range must contain at least one port.
  assert(last_port_ >= first_port_);
}
// Releases every port this instance still holds. Failures are only logged.
PortManager::~PortManager() {
  // ReleasePort() erases entries from reserved_ports_, so iterate over a
  // snapshot rather than the set itself.
  const std::vector<int> ports_to_release(reserved_ports_.begin(),
                                          reserved_ports_.end());
  for (const int port : ports_to_release) {
    const absl::Status release_status = ReleasePort(port);
    if (release_status.ok()) continue;
    LOG_WARNING("Failed to release port %d: %s", port,
                release_status.ToString());
  }
}
// Reserves a port that is available both on the workstation and on the remote
// instance. |remote_timeout_sec| bounds the remote netstat call.
//
// Reservations are shared across processes through |shared_mem_|, which holds
// one timestamp per port: 0 for unused, otherwise the time of the last
// reservation. Returns ResourceExhaustedError if no port could be reserved.
absl::StatusOr<int> PortManager::ReservePort(int remote_timeout_sec) {
  static_assert(sizeof(time_t) == sizeof(uint64_t), "time_t must be 64 bit");

  // Find available ports on the workstation.
  std::unordered_set<int> local_ports;
  ASSIGN_OR_RETURN(local_ports,
                   FindAvailableLocalPorts(first_port_, last_port_, "127.0.0.1",
                                           process_factory_),
                   "Failed to find available ports on workstation");

  // Find available ports on the remote instance.
  std::unordered_set<int> remote_ports;
  ASSIGN_OR_RETURN(remote_ports,
                   FindAvailableRemotePorts(first_port_, last_port_, "0.0.0.0",
                                            process_factory_, remote_util_,
                                            remote_timeout_sec, steady_clock_),
                   "Failed to find available ports on instance");

  // Fetch shared memory.
  void* mem;
  ASSIGN_OR_RETURN(mem, shared_mem_->Get(), "Failed to get shared memory");
  time_t* port_timestamps = static_cast<time_t*>(mem);

  // Put ports into a multimap to iterate in LRU order.
  const int num_ports = last_port_ - first_port_ + 1;
  std::multimap<time_t, int> ports_to_index;
  for (int n = 0; n < num_ports; ++n) {
    ports_to_index.insert({port_timestamps[n], n});
  }

  // Iterate over the ports, unused first (timestamp 0), the rest in LRU order.
  // The ones with timestamps != 0 might either be stuck (e.g. process crashed
  // and did not release port) or still in use.
  const time_t now = std::chrono::system_clock::to_time_t(system_clock_->Now());
  for (const auto& [port_timestamp, n] : ports_to_index) {
    const int port = first_port_ + n;

    // Check availability before touching shared memory, so that a port in use
    // is never marked as reserved and nothing has to be rolled back.
    if (local_ports.find(port) == local_ports.end()) {
      LOG_DEBUG("Port %i not available on workstation", port);
      continue;
    }
    if (remote_ports.find(port) == remote_ports.end()) {
      LOG_DEBUG("Port %i not available on instance", port);
      continue;
    }

    LOG_DEBUG("Trying to reserve port %i", port);
    // Some other process might have reserved the port since the timestamps
    // were read above, hence claim it atomically. The exchange only succeeds
    // if the slot still holds |port_timestamp|.
    volatile time_t* ts_ptr = &port_timestamps[n];
    assert((reinterpret_cast<uintptr_t>(ts_ptr) & 7) == 0);
    if (InterlockedCompareExchange64(ts_ptr, now, port_timestamp) !=
        port_timestamp) {
      LOG_DEBUG("Port %i was reserved concurrently by another process", port);
      continue;
    }

    LOG_DEBUG("Port %i is available on workstation and instance", port);
    reserved_ports_.insert(port);
    return port;
  }
  return absl::ResourceExhaustedError(absl::StrFormat(
      "No port available in range [%i, %i]", first_port_, last_port_));
}
// Releases |port| if this instance reserved it; otherwise does nothing.
// Resetting the port's shared timestamp to 0 marks it as unused again.
absl::Status PortManager::ReleasePort(int port) {
  const bool is_reserved = reserved_ports_.find(port) != reserved_ports_.end();
  if (!is_reserved) return absl::OkStatus();

  void* mem;
  ASSIGN_OR_RETURN(mem, shared_mem_->Get(), "Failed to get shared memory");
  volatile time_t* const ts_ptr =
      static_cast<time_t*>(mem) + (port - first_port_);
  InterlockedExchange64(ts_ptr, 0);

  reserved_ports_.erase(port);
  return absl::OkStatus();
}
// static
// Runs netstat on the workstation and returns the ports in
// [first_port, last_port] that are not occupied on |ip|.
absl::StatusOr<std::unordered_set<int>> PortManager::FindAvailableLocalPorts(
    int first_port, int last_port, const char* ip,
    ProcessFactory* process_factory) {
  // netstat flags:
  //   -a      show connections and the ports the computer is listening on,
  //   -n      numerical addresses (avoids the overhead of resolving names),
  //   -p tcp  limit the output to TCPv4 connections.
  // TODO: Use Windows API instead of netstat.
  std::string stdout_text;
  std::string stderr_text;

  ProcessStartInfo start_info;
  start_info.command = "netstat -a -n -p tcp";
  start_info.name = "netstat";
  start_info.flags = ProcessFlags::kNoWindow;
  start_info.stdout_handler = [&stdout_text](const char* data,
                                             size_t data_size) {
    stdout_text.append(data, data_size);
    return absl::OkStatus();
  };
  start_info.stderr_handler = [&stderr_text](const char* data,
                                             size_t data_size) {
    stderr_text.append(data, data_size);
    return absl::OkStatus();
  };

  const absl::Status run_status = process_factory->Run(start_info);
  if (!run_status.ok()) {
    return WrapStatus(run_status, "Failed to run netstat:\n%s", stderr_text);
  }

  LOG_DEBUG("netstat (workstation) output:\n%s", stdout_text);
  return FindAvailablePorts(first_port, last_port, stdout_text, ip);
}
// static
// Runs netstat on the remote instance via SSH and returns the ports in
// [first_port, last_port] that are not occupied on |ip|. Returns
// DeadlineExceededError if netstat runs for more than |timeout_sec| seconds.
absl::StatusOr<std::unordered_set<int>> PortManager::FindAvailableRemotePorts(
    int first_port, int last_port, const char* ip,
    ProcessFactory* process_factory, RemoteUtil* remote_util, int timeout_sec,
    SteadyClock* steady_clock) {
  // netstat flags:
  //   --numeric    numerical addresses,
  //   --listening  only listening sockets,
  //   --tcp        only TCP connections.
  std::string remote_command = "netstat --numeric --listening --tcp";
  std::string stdout_text;
  std::string stderr_text;

  ProcessStartInfo start_info =
      remote_util->BuildProcessStartInfoForSsh(remote_command);
  start_info.name = "netstat";
  start_info.flags = ProcessFlags::kNoWindow;
  start_info.stdout_handler = [&stdout_text](const char* data,
                                             size_t data_size) {
    stdout_text.append(data, data_size);
    return absl::OkStatus();
  };
  start_info.stderr_handler = [&stderr_text](const char* data,
                                             size_t data_size) {
    stderr_text.append(data, data_size);
    return absl::OkStatus();
  };

  std::unique_ptr<Process> process = process_factory->Create(start_info);
  const absl::Status start_status = process->Start();
  if (!start_status.ok()) {
    return WrapStatus(start_status, "Failed to start netstat");
  }

  // The predicate records in |is_timeout| whether the time budget was
  // exceeded, so the timeout can be told apart after RunUntil() returns.
  Stopwatch timeout_timer(steady_clock);
  bool is_timeout = false;
  auto detect_timeout = [&is_timeout, &timeout_timer, timeout_sec]() {
    is_timeout = timeout_timer.ElapsedSeconds() > timeout_sec;
    return is_timeout;
  };
  const absl::Status run_status = process->RunUntil(detect_timeout);
  if (!run_status.ok()) {
    return WrapStatus(run_status, "Failed to run netstat process");
  }
  if (is_timeout) {
    return absl::DeadlineExceededError("Timeout while running netstat");
  }

  const uint32_t exit_code = process->ExitCode();
  if (exit_code != 0) {
    return MakeStatus("netstat process exited with code %u:\n%s", exit_code,
                      stderr_text);
  }

  LOG_DEBUG("netstat (instance) output:\n%s", stdout_text);
  return FindAvailablePorts(first_port, last_port, stdout_text, ip);
}
// static
// Parses |netstat_output| and returns the ports in [first_port, last_port] that
// do not appear as "<ip>:<port>" on any line (ignoring TIME_WAIT lines).
// Returns ResourceExhaustedError if no port in the range is available.
absl::StatusOr<std::unordered_set<int>> PortManager::FindAvailablePorts(
    int first_port, int last_port, const std::string& netstat_output,
    const char* ip) {
  // Valid TCP ports have at most 5 decimal digits.
  constexpr int kMaxPortDigits = 5;

  // Collect occupied ports in a single pass over the output. This keeps the
  // cost linear in the output size plus the range size, instead of scanning
  // every line once per port.
  const std::string ip_prefix = absl::StrFormat("%s:", ip);
  std::unordered_set<int> occupied_ports;
  for (absl::string_view line : absl::StrSplit(netstat_output, '\n')) {
    // Ports in the TIME_WAIT state can be reused. It is common that ports
    // stay in this state for O(minutes).
    if (absl::StrContains(line, "TIME_WAIT")) continue;

    for (size_t pos = line.find(ip_prefix); pos != absl::string_view::npos;
         pos = line.find(ip_prefix, pos + 1)) {
      // Parse the complete digit run after "<ip>:" so that, e.g., a socket on
      // port 45001 does not also mark port 4500 as occupied.
      size_t end = pos + ip_prefix.size();
      int port = 0;
      int num_digits = 0;
      while (end < line.size() && line[end] >= '0' && line[end] <= '9') {
        if (num_digits < kMaxPortDigits) port = port * 10 + (line[end] - '0');
        ++num_digits;
        ++end;
      }
      if (num_digits > 0 && num_digits <= kMaxPortDigits) {
        occupied_ports.insert(port);
      }
    }
  }

  std::unordered_set<int> available_ports;
  for (int port = first_port; port <= last_port; ++port) {
    if (occupied_ports.find(port) == occupied_ports.end()) {
      available_ports.insert(port);
    }
  }
  if (available_ports.empty()) {
    return absl::ResourceExhaustedError(absl::StrFormat(
        "No port available in range [%i, %i]", first_port, last_port));
  }
  return available_ports;
}
} // namespace cdc_ft