mirror of
https://github.com/nestriness/cdc-file-transfer.git
synced 2026-01-30 14:45:37 +02:00
Adds support for local syncs of files and folders on the same Windows machine, e.g. cdc_rsync C:\source C:\dest. The two main changes are - Skip the check whether the port is available remotely with PortManager. - Do not deploy cdc_rsync_server. - Run cdc_rsync_server directly, not through an SSH tunnel. The current implementation is not optimal as it starts cdc_rsync_server as a separate process and communicates to it via a TCP port.
323 lines
12 KiB
C++
323 lines
12 KiB
C++
// Copyright 2022 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "common/port_manager.h"
|
|
|
|
#define WIN32_LEAN_AND_MEAN
|
|
#include <windows.h>
|
|
|
|
#include <map>
|
|
|
|
#include "absl/strings/str_split.h"
|
|
#include "common/log.h"
|
|
#include "common/process.h"
|
|
#include "common/remote_util.h"
|
|
#include "common/status.h"
|
|
#include "common/status_macros.h"
|
|
#include "common/stopwatch.h"
|
|
#include "common/util.h"
|
|
|
|
namespace cdc_ft {
|
|
|
|
class SharedMemory {
|
|
public:
|
|
// Creates a new shared memory instance with given |name| and |size| in bytes.
|
|
// Different instances with matching names reference the same piece of memory,
|
|
// even if they belong to different processes. If shared memory with the given
|
|
// |name| already exists, the existing memory is referenced. Otherwise, a new
|
|
// piece of memory is allocated and zero-initialized.
|
|
SharedMemory(std::string name, size_t size)
|
|
: name_(std::move(name)), size_(size) {}
|
|
|
|
absl::StatusOr<void*> Get() {
|
|
// Already initialized?
|
|
if (shared_mem_) return shared_mem_;
|
|
assert(!map_file_handle_);
|
|
|
|
LARGE_INTEGER size;
|
|
size.QuadPart = size_;
|
|
map_file_handle_ = CreateFileMapping(
|
|
INVALID_HANDLE_VALUE, // use paging file
|
|
nullptr, // default security
|
|
PAGE_READWRITE, // read/write access
|
|
size.HighPart, // maximum object size (high-order DWORD)
|
|
size.LowPart, // maximum object size (low-order DWORD)
|
|
Util::Utf8ToWideStr(name_).c_str()); // name of mapping object
|
|
|
|
if (!map_file_handle_) {
|
|
return MakeStatus("Failed to create file mapping object: %s",
|
|
Util::GetLastWin32Error());
|
|
}
|
|
|
|
// The shared memory holds the timestamps when the ports were reserved.
|
|
shared_mem_ = MapViewOfFile(map_file_handle_, // handle to map object
|
|
FILE_MAP_ALL_ACCESS, // read/write permission
|
|
0, 0, size.QuadPart);
|
|
|
|
if (!shared_mem_) {
|
|
std::string errorMessage = Util::GetLastWin32Error();
|
|
CloseHandle(map_file_handle_);
|
|
map_file_handle_ = nullptr;
|
|
return MakeStatus("Failed to map view of file: %s", errorMessage);
|
|
}
|
|
|
|
return shared_mem_;
|
|
}
|
|
|
|
~SharedMemory() {
|
|
if (shared_mem_) {
|
|
UnmapViewOfFile(shared_mem_);
|
|
shared_mem_ = nullptr;
|
|
}
|
|
|
|
if (map_file_handle_) {
|
|
CloseHandle(map_file_handle_);
|
|
map_file_handle_ = nullptr;
|
|
}
|
|
}
|
|
|
|
private:
|
|
std::string name_;
|
|
size_t size_;
|
|
HANDLE map_file_handle_ = nullptr;
|
|
void* shared_mem_ = nullptr;
|
|
};
|
|
|
|
PortManager::PortManager(std::string name, int first_port, int last_port,
|
|
ProcessFactory* process_factory,
|
|
RemoteUtil* remote_util, SystemClock* system_clock,
|
|
SteadyClock* steady_clock)
|
|
: first_port_(first_port),
|
|
last_port_(last_port),
|
|
process_factory_(process_factory),
|
|
remote_util_(remote_util),
|
|
system_clock_(system_clock),
|
|
steady_clock_(steady_clock),
|
|
shared_mem_(std::make_unique<SharedMemory>(
|
|
std::move(name), (last_port - first_port + 1) * sizeof(time_t))) {
|
|
assert(last_port_ >= first_port_);
|
|
}
|
|
|
|
PortManager::~PortManager() {
|
|
std::vector<int> ports_copy;
|
|
ports_copy.insert(ports_copy.end(), reserved_ports_.begin(),
|
|
reserved_ports_.end());
|
|
for (int port : ports_copy) {
|
|
absl::Status status = ReleasePort(port);
|
|
if (!status.ok()) {
|
|
LOG_WARNING("Failed to release port %d: %s", port, status.ToString());
|
|
}
|
|
}
|
|
}
|
|
|
|
absl::StatusOr<int> PortManager::ReservePort(int remote_timeout_sec) {
|
|
// Find available port on workstation.
|
|
std::unordered_set<int> local_ports;
|
|
ASSIGN_OR_RETURN(local_ports,
|
|
FindAvailableLocalPorts(first_port_, last_port_, "127.0.0.1",
|
|
process_factory_),
|
|
"Failed to find available ports on workstation");
|
|
|
|
// Find available port on remote instance.
|
|
std::unordered_set<int> remote_ports = local_ports;
|
|
if (remote_util_ != nullptr) {
|
|
ASSIGN_OR_RETURN(remote_ports,
|
|
FindAvailableRemotePorts(
|
|
first_port_, last_port_, "0.0.0.0", process_factory_,
|
|
remote_util_, remote_timeout_sec, steady_clock_),
|
|
"Failed to find available ports on instance");
|
|
}
|
|
|
|
// Fetch shared memory.
|
|
void* mem;
|
|
ASSIGN_OR_RETURN(mem, shared_mem_->Get(), "Failed to get shared memory");
|
|
time_t* port_timestamps = static_cast<time_t*>(mem);
|
|
|
|
// Put ports into a multimap to iterate in LRU order.
|
|
int num_ports = last_port_ - first_port_ + 1;
|
|
std::multimap<time_t, int> ports_to_index;
|
|
for (int n = 0; n < num_ports; ++n) {
|
|
ports_to_index.insert({port_timestamps[n], n});
|
|
}
|
|
|
|
// Iterate over the ports, unused first (timestamp 0), the rest in LRU order.
|
|
// The ones with timestamps != 0 might either be stuck (e.g. process crashed
|
|
// and did not release port) or still in use.
|
|
const time_t now = std::chrono::system_clock::to_time_t(system_clock_->Now());
|
|
for (const auto& [port_timestamp, n] : ports_to_index) {
|
|
// Note that some other process might have hijacked the port in the
|
|
// meantime, hence do an InterlockedCompareExchange.
|
|
volatile time_t* ts_ptr = &port_timestamps[n];
|
|
static_assert(sizeof(time_t) == sizeof(uint64_t), "time_t must be 64 bit");
|
|
assert((reinterpret_cast<uintptr_t>(ts_ptr) & 7) == 0);
|
|
if (InterlockedCompareExchange64(ts_ptr, now, port_timestamp) ==
|
|
port_timestamp) {
|
|
int port = first_port_ + n;
|
|
LOG_DEBUG("Trying to reserve port %i", port);
|
|
|
|
// We have reserved this port. Double-check that it's actually not in use
|
|
// on both the workstation and the server.
|
|
if (local_ports.find(port) == local_ports.end()) {
|
|
LOG_DEBUG("Port %i not available on workstation", port);
|
|
InterlockedCompareExchange64(ts_ptr, now, port_timestamp);
|
|
continue;
|
|
}
|
|
if (remote_ports.find(port) == remote_ports.end()) {
|
|
LOG_DEBUG("Port %i not available on instance", port);
|
|
InterlockedCompareExchange64(ts_ptr, now, port_timestamp);
|
|
continue;
|
|
}
|
|
|
|
LOG_DEBUG("Port %i is available on workstation and instance", port);
|
|
reserved_ports_.insert(port);
|
|
return port;
|
|
}
|
|
}
|
|
|
|
return absl::ResourceExhaustedError(absl::StrFormat(
|
|
"No port available in range [%i, %i]", first_port_, last_port_));
|
|
}
|
|
|
|
absl::Status PortManager::ReleasePort(int port) {
|
|
if (reserved_ports_.find(port) == reserved_ports_.end())
|
|
return absl::OkStatus();
|
|
void* mem;
|
|
ASSIGN_OR_RETURN(mem, shared_mem_->Get(), "Failed to get shared memory");
|
|
time_t* port_timestamps = static_cast<time_t*>(mem);
|
|
volatile time_t* ts_ptr = &port_timestamps[port - first_port_];
|
|
InterlockedExchange64(ts_ptr, 0);
|
|
reserved_ports_.erase(port);
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
// static
|
|
absl::StatusOr<std::unordered_set<int>> PortManager::FindAvailableLocalPorts(
|
|
int first_port, int last_port, const char* ip,
|
|
ProcessFactory* process_factory) {
|
|
// -a to get the connection and ports the computer is listening on.
|
|
// -n to get numerical addresses to avoid the overhead of determining names.
|
|
// -p tcp to limit the output to TCPv4 connections.
|
|
// TODO: Use Windows API instead of netstat.
|
|
ProcessStartInfo start_info;
|
|
start_info.command = "netstat -a -n -p tcp";
|
|
start_info.name = "netstat";
|
|
start_info.flags = ProcessFlags::kNoWindow;
|
|
|
|
std::string output;
|
|
start_info.stdout_handler = [&output](const char* data, size_t data_size) {
|
|
output.append(data, data_size);
|
|
return absl::OkStatus();
|
|
};
|
|
std::string errors;
|
|
start_info.stderr_handler = [&errors](const char* data, size_t data_size) {
|
|
errors.append(data, data_size);
|
|
return absl::OkStatus();
|
|
};
|
|
|
|
absl::Status status = process_factory->Run(start_info);
|
|
if (!status.ok()) {
|
|
return WrapStatus(status, "Failed to run netstat:\n%s", errors);
|
|
}
|
|
|
|
LOG_DEBUG("netstat (workstation) output:\n%s", output);
|
|
return FindAvailablePorts(first_port, last_port, output, ip);
|
|
}
|
|
|
|
// static
|
|
absl::StatusOr<std::unordered_set<int>> PortManager::FindAvailableRemotePorts(
|
|
int first_port, int last_port, const char* ip,
|
|
ProcessFactory* process_factory, RemoteUtil* remote_util, int timeout_sec,
|
|
SteadyClock* steady_clock) {
|
|
// --numeric to get numerical addresses.
|
|
// --listening to get only listening sockets.
|
|
// --tcp to get only TCP connections.
|
|
std::string remote_command = "netstat --numeric --listening --tcp";
|
|
ProcessStartInfo start_info =
|
|
remote_util->BuildProcessStartInfoForSsh(remote_command);
|
|
start_info.name = "netstat";
|
|
start_info.flags = ProcessFlags::kNoWindow;
|
|
|
|
std::string output;
|
|
start_info.stdout_handler = [&output](const char* data, size_t data_size) {
|
|
output.append(data, data_size);
|
|
return absl::OkStatus();
|
|
};
|
|
std::string errors;
|
|
start_info.stderr_handler = [&errors](const char* data, size_t data_size) {
|
|
errors.append(data, data_size);
|
|
return absl::OkStatus();
|
|
};
|
|
|
|
std::unique_ptr<Process> process = process_factory->Create(start_info);
|
|
absl::Status status = process->Start();
|
|
if (!status.ok()) return WrapStatus(status, "Failed to start netstat");
|
|
|
|
Stopwatch timeout_timer(steady_clock);
|
|
bool is_timeout = false;
|
|
auto detect_timeout = [&timeout_timer, timeout_sec, &is_timeout]() {
|
|
is_timeout = timeout_timer.ElapsedSeconds() > timeout_sec;
|
|
return is_timeout;
|
|
};
|
|
status = process->RunUntil(detect_timeout);
|
|
if (!status.ok()) return WrapStatus(status, "Failed to run netstat process");
|
|
if (is_timeout)
|
|
return absl::DeadlineExceededError("Timeout while running netstat");
|
|
|
|
uint32_t exit_code = process->ExitCode();
|
|
if (exit_code != 0) {
|
|
return MakeStatus("netstat process exited with code %u:\n%s", exit_code,
|
|
errors);
|
|
}
|
|
|
|
LOG_DEBUG("netstat (instance) output:\n%s", output);
|
|
return FindAvailablePorts(first_port, last_port, output, ip);
|
|
}
|
|
|
|
// static
|
|
absl::StatusOr<std::unordered_set<int>> PortManager::FindAvailablePorts(
|
|
int first_port, int last_port, const std::string& netstat_output,
|
|
const char* ip) {
|
|
std::unordered_set<int> available_ports;
|
|
std::vector<std::string> lines;
|
|
for (const auto& line : absl::StrSplit(netstat_output, '\n')) {
|
|
if (absl::StrContains(line, ip)) {
|
|
lines.push_back(std::string(line));
|
|
}
|
|
}
|
|
|
|
for (int port = first_port; port <= last_port; ++port) {
|
|
bool port_occupied = false;
|
|
std::string portToken = absl::StrFormat("%s:%i", ip, port);
|
|
for (const std::string& line : lines) {
|
|
// Ports in the TIME_WAIT state can be reused. It is common that ports
|
|
// stay in this state for O(minutes).
|
|
if (absl::StrContains(line, portToken) &&
|
|
!absl::StrContains(line, "TIME_WAIT")) {
|
|
port_occupied = true;
|
|
break;
|
|
}
|
|
}
|
|
if (!port_occupied) available_ports.insert(port);
|
|
}
|
|
|
|
if (available_ports.empty()) {
|
|
return absl::ResourceExhaustedError(absl::StrFormat(
|
|
"No port available in range [%i, %i]", first_port, last_port));
|
|
}
|
|
|
|
return available_ports;
|
|
}
|
|
|
|
} // namespace cdc_ft
|