diff --git a/all_files.vcxitems b/all_files.vcxitems index 75efda4..27a5a22 100644 --- a/all_files.vcxitems +++ b/all_files.vcxitems @@ -41,6 +41,8 @@ + + @@ -159,6 +161,7 @@ + diff --git a/cdc_rsync/BUILD b/cdc_rsync/BUILD index e285750..4c648e0 100644 --- a/cdc_rsync/BUILD +++ b/cdc_rsync/BUILD @@ -179,6 +179,7 @@ cc_library( srcs = ["server_arch.cc"], hdrs = ["server_arch.h"], deps = [ + "//common:arch_type", "//common:path", "//common:remote_util", "@com_google_absl//absl/strings", diff --git a/cdc_rsync/cdc_rsync_client.cc b/cdc_rsync/cdc_rsync_client.cc index e7798b7..e823d65 100644 --- a/cdc_rsync/cdc_rsync_client.cc +++ b/cdc_rsync/cdc_rsync_client.cc @@ -127,16 +127,35 @@ CdcRsyncClient::~CdcRsyncClient() { } absl::Status CdcRsyncClient::Run() { - int port; - ASSIGN_OR_RETURN(port, FindAvailablePort(), "Failed to find available port"); + // If |remote_util_| is not set, it's a local sync. Otherwise, guess the + // architecture of the device that runs cdc_rsync_server from the destination + // path, e.g. "C:\path\to\dest" strongly indicates Windows. + ServerArch server_arch = !remote_util_ + ? ServerArch::DetectFromLocalDevice() + : ServerArch::GuessFromDestination(destination_); - // If |remote_util_| is not set, it's a local sync. - ServerArch::Type arch_type = - remote_util_ ? ServerArch::Detect(destination_) : ServerArch::LocalType(); - ServerArch server_arch(arch_type); + int port; + ASSIGN_OR_RETURN(port, FindAvailablePort(&server_arch), + "Failed to find available port"); // Start the server process. absl::Status status = StartServer(port, server_arch); + if (HasTag(status, Tag::kDeployServer) && server_arch.IsGuess() && + server_exit_code_ != kServerExitCodeOutOfDate) { + // Server couldn't be run, e.g. not found or failed to start. + // Check whether we guessed the arch type wrong and try again. + // Note that in case of a local sync, or if the server actively reported + // that it's out-dated, there's no need to detect the arch. + const ArchType old_type = server_arch.GetType(); + ASSIGN_OR_RETURN(server_arch, + ServerArch::DetectFromRemoteDevice(remote_util_.get())); + if (server_arch.GetType() != old_type) { + LOG_DEBUG("Guessed server arch type wrong, guessed %s, actual %s.", + GetArchTypeStr(old_type), server_arch.GetTypeStr()); + status = StartServer(port, server_arch); + } + } + if (HasTag(status, Tag::kDeployServer)) { // Gamelet components are not deployed or out-dated. Deploy and retry. status = DeployServer(server_arch); @@ -176,15 +195,17 @@ absl::Status CdcRsyncClient::Run() { return status; } -absl::StatusOr CdcRsyncClient::FindAvailablePort() { +absl::StatusOr CdcRsyncClient::FindAvailablePort(ServerArch* server_arch) { // Find available local and remote ports for port forwarding. // If only one port is in the given range, try that without checking. if (options_.forward_port_first >= options_.forward_port_last) { return options_.forward_port_first; } - absl::StatusOr port = - port_manager_->ReservePort(options_.connection_timeout_sec); + assert(server_arch); + absl::StatusOr port = port_manager_->ReservePort( + options_.connection_timeout_sec, server_arch->GetType()); + if (absl::IsDeadlineExceeded(port.status())) { // Server didn't respond in time. return SetTag(port.status(), Tag::kConnectionTimeout); @@ -193,6 +214,21 @@ absl::StatusOr CdcRsyncClient::FindAvailablePort() { // Port in use. return SetTag(port.status(), Tag::kAddressInUse); } + + // If |server_arch| was guessed, calling netstat might have failed because + // the arch was wrong. Properly detect it and try again if it changed. + if (!port.ok() && server_arch->IsGuess()) { + const ArchType old_type = server_arch->GetType(); + ASSIGN_OR_RETURN(*server_arch, + ServerArch::DetectFromRemoteDevice(remote_util_.get())); + assert(!server_arch->IsGuess()); + if (server_arch->GetType() != old_type) { + LOG_DEBUG("Guessed server arch type wrong, guessed %s, actual %s.", + GetArchTypeStr(old_type), server_arch->GetTypeStr()); + return FindAvailablePort(server_arch); + } + } + return port; } diff --git a/cdc_rsync/cdc_rsync_client.h b/cdc_rsync/cdc_rsync_client.h index 96e9953..0f1c415 100644 --- a/cdc_rsync/cdc_rsync_client.h +++ b/cdc_rsync/cdc_rsync_client.h @@ -79,7 +79,9 @@ class CdcRsyncClient { private: // Finds available local and remote ports for port forwarding. - absl::StatusOr FindAvailablePort(); + // May update |server_arch| by properly detecting the architecture and retry + // if the architecture was guessed, i.e. if |server_arch|->IsGuess() is true. + absl::StatusOr FindAvailablePort(ServerArch* server_arch); // Starts the server process. If the method returns a status with tag // |kTagDeployServer|, Run() calls DeployServer() and tries again. diff --git a/cdc_rsync/server_arch.cc b/cdc_rsync/server_arch.cc index 731c0c4..694c231 100644 --- a/cdc_rsync/server_arch.cc +++ b/cdc_rsync/server_arch.cc @@ -20,60 +20,152 @@ #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "common/path.h" -#include "common/platform.h" #include "common/remote_util.h" +#include "common/status_macros.h" #include "common/util.h" namespace cdc_ft { +namespace { -constexpr char kErrorFailedToGetKnownFolderPath[] = - "error_failed_to_get_known_folder_path"; constexpr char kErrorArchTypeUnhandled[] = "arch_type_unhandled"; +constexpr char kUnsupportedArchErrorFmt[] = + "Unsupported remote device architecture '%s'. If you think this is a " + "bug, or if this combination should be supported, please file a bug at " + "https://github.com/google/cdc-file-transfer."; + +absl::StatusOr GetArchTypeFromUname(const std::string& uname_out) { + // uname_out is "KERNEL MACHINE" + // Possible values for KERNEL: Linux (not sure what else). + // Possible values for MACHINE: + // https://stackoverflow.com/questions/45125516/possible-values-for-uname-m + // Relevant for us: x86_64, aarch64. + if (absl::StartsWith(uname_out, "Linux ")) { + // Linux kernel. Check CPU type. + if (absl::StrContains(uname_out, "x86_64")) { + return ArchType::kLinux_x86_64; + } + } + + if (absl::StartsWith(uname_out, "MSYS_")) { + // Windows machine that happens to have Cygwin/MSYS on it. Check CPU type. + if (absl::StrContains(uname_out, "x86_64")) { + return ArchType::kWindows_x86_64; + } + } + + return absl::UnimplementedError( + absl::StrFormat(kUnsupportedArchErrorFmt, uname_out)); +} + +absl::StatusOr GetArchTypeFromWinProcArch( + const std::string& arch_out) { + // Possible values: AMD64, IA64, ARM64, x86 + if (absl::StrContains(arch_out, "AMD64")) { + return ArchType::kWindows_x86_64; + } + + return absl::UnimplementedError( + absl::StrFormat(kUnsupportedArchErrorFmt, arch_out)); +} + +} // namespace // static -ServerArch::Type ServerArch::Detect(const std::string& destination) { +ServerArch ServerArch::GuessFromDestination(const std::string& destination) { // Path starting with ~ or / -> Linux. if (absl::StartsWith(destination, "~") || absl::StartsWith(destination, "/")) { - return Type::kLinux; + LOG_DEBUG("Guessed server arch type Linux based on ~ or /"); + return ServerArch(ArchType::kLinux_x86_64, /*is_guess=*/true); } // Path starting with C: etc. -> Windows. if (!path::GetDrivePrefix(destination).empty()) { - return Type::kWindows; + LOG_DEBUG("Guessed server arch type Windows based on drive prefix"); + return ServerArch(ArchType::kWindows_x86_64, /*is_guess=*/true); } // Path with only / -> Linux. if (absl::StrContains(destination, "/") && !absl::StrContains(destination, "\\")) { - return Type::kLinux; + LOG_DEBUG("Guessed server arch type Linux based on forward slashes"); + return ServerArch(ArchType::kLinux_x86_64, /*is_guess=*/true); } // Path with only \\ -> Windows. if (absl::StrContains(destination, "\\") && !absl::StrContains(destination, "/")) { - return Type::kWindows; + LOG_DEBUG("Guessed server arch type Windows based on backslashes"); + return ServerArch(ArchType::kWindows_x86_64, /*is_guess=*/true); } // Default to Linux. - return Type::kLinux; + LOG_DEBUG("Guessed server arch type Linux as default"); + return ServerArch(ArchType::kLinux_x86_64, /*is_guess=*/true); } // static -ServerArch::Type ServerArch::LocalType() { -#if PLATFORM_WINDOWS - return ServerArch::Type::kWindows; -#elif PLATFORM_LINUX - return ServerArch::Type::kLinux; -#endif +ServerArch ServerArch::DetectFromLocalDevice() { + LOG_DEBUG("Detected local device type %s", + GetArchTypeStr(GetLocalArchType())); + return ServerArch(GetLocalArchType(), /*is_guess=*/false); +} + +// static +absl::StatusOr ServerArch::DetectFromRemoteDevice( + RemoteUtil* remote_util) { + assert(remote_util); + + // Run uname, assuming it's a Linux machine. + std::string uname_out; + std::string linux_cmd = "uname -sm"; + absl::Status status = + remote_util->RunWithCapture(linux_cmd, "uname", &uname_out, nullptr); + if (status.ok()) { + LOG_DEBUG("Uname returned '%s'", uname_out); + absl::StatusOr type = GetArchTypeFromUname(uname_out); + if (type.ok()) { + LOG_DEBUG("Detected server arch type '%s' from uname", + GetArchTypeStr(*type)); + return ServerArch(*type, /*is_guess=*/false); + } + status = type.status(); + } + LOG_DEBUG( + "Failed to detect arch type from uname; this is expected if the remote " + "machine is not Linux; will try Windows next: %s", + status.ToString()); + + // Check %PROCESSOR_ARCHITECTURE%, assuming it's a Windows machine. + // Note: That space after PROCESSOR_ARCHITECTURE is important or else Windows + // command magic interprets quotes as part of the string. + std::string arch_out; + std::string windows_cmd = + RemoteUtil::QuoteForSsh("cmd /C set PROCESSOR_ARCHITECTURE "); + status = remote_util->RunWithCapture( + windows_cmd, "set PROCESSOR_ARCHITECTURE", &arch_out, nullptr); + if (status.ok()) { + LOG_DEBUG("PROCESSOR_ARCHITECTURE is '%s'", arch_out); + absl::StatusOr type = GetArchTypeFromWinProcArch(arch_out); + if (type.ok()) { + LOG_DEBUG("Detected server arch type '%s' from PROCESSOR_ARCHITECTURE", + GetArchTypeStr(*type)); + return ServerArch(*type, /*is_guess=*/false); + } + status = type.status(); + } + LOG_DEBUG("Failed to detect arch type from PROCESSOR_ARCHITECTURE: %s", + status.ToString()); + + return absl::InternalError("Failed to detect remote architecture"); } // static std::string ServerArch::CdcRsyncFilename() { - switch (LocalType()) { - case Type::kWindows: + switch (GetLocalArchType()) { + case ArchType::kWindows_x86_64: return "cdc_rsync.exe"; - case Type::kLinux: + case ArchType::kLinux_x86_64: return "cdc_rsync"; default: assert(!kErrorArchTypeUnhandled); @@ -81,15 +173,18 @@ std::string ServerArch::CdcRsyncFilename() { } } -ServerArch::ServerArch(Type type) : type_(type) {} +ServerArch::ServerArch(ArchType type, bool is_guess) + : type_(type), is_guess_(is_guess) {} ServerArch::~ServerArch() {} +const char* ServerArch::GetTypeStr() const { return GetArchTypeStr(type_); } + std::string ServerArch::CdcServerFilename() const { switch (type_) { - case Type::kWindows: + case ArchType::kWindows_x86_64: return "cdc_rsync_server.exe"; - case Type::kLinux: + case ArchType::kLinux_x86_64: return "cdc_rsync_server"; default: assert(!kErrorArchTypeUnhandled); @@ -98,46 +193,46 @@ std::string ServerArch::CdcServerFilename() const { } std::string ServerArch::RemoteToolsBinDir() const { - switch (type_) { - case Type::kWindows: { - return "AppData\\Roaming\\cdc-file-transfer\\bin\\"; - } - case Type::kLinux: - return ".cache/cdc-file-transfer/bin/"; - default: - assert(!kErrorArchTypeUnhandled); - return std::string(); + if (IsWindowsArchType(type_)) { + return "AppData\\Roaming\\cdc-file-transfer\\bin\\"; } + + if (IsLinuxArchType(type_)) { + return ".cache/cdc-file-transfer/bin/"; + } + + assert(!kErrorArchTypeUnhandled); + return std::string(); } std::string ServerArch::GetStartServerCommand(int exit_code_not_found, const std::string& args) const { std::string server_path = RemoteToolsBinDir() + CdcServerFilename(); - switch (type_) { - case Type::kWindows: - // TODO(ljusten): On Windows, ssh does not seem to forward the Powershell - // exit code (exit_code_not_found) to the process. However, that's really - // a minor issue and means we display "Deploying server..." instead of - // "Server not deployed. Deploying..."; - return RemoteUtil::QuoteForWindows( - absl::StrFormat("powershell -Command \" " - "Set-StrictMode -Version 2; " - "$ErrorActionPreference = 'Stop'; " - "if (-not (Test-Path -Path '%s')) { " - " exit %i; " - "} " - "%s %s " - "\"", - server_path, exit_code_not_found, server_path, args)); - case Type::kLinux: - return absl::StrFormat("if [ ! -f %s ]; then exit %i; fi; %s %s", - server_path, exit_code_not_found, server_path, - args); - default: - assert(!kErrorArchTypeUnhandled); - return std::string(); + if (IsWindowsArchType(type_)) { + // TODO(ljusten): On Windows, ssh does not seem to forward the Powershell + // exit code (exit_code_not_found) to the process. However, that's really + // a minor issue and means we display "Deploying server..." instead of + // "Server not deployed. Deploying..."; + return RemoteUtil::QuoteForWindows( + absl::StrFormat("powershell -Command \" " + "Set-StrictMode -Version 2; " + "$ErrorActionPreference = 'Stop'; " + "if (-not (Test-Path -Path '%s')) { " + " exit %i; " + "} " + "%s %s " + "\"", + server_path, exit_code_not_found, server_path, args)); } + + if (IsLinuxArchType(type_)) { + return absl::StrFormat("if [ ! -f %s ]; then exit %i; fi; %s %s", + server_path, exit_code_not_found, server_path, args); + } + + assert(!kErrorArchTypeUnhandled); + return std::string(); } std::string ServerArch::GetDeploySftpCommands() const { diff --git a/cdc_rsync/server_arch.h b/cdc_rsync/server_arch.h index ab03130..206a2d7 100644 --- a/cdc_rsync/server_arch.h +++ b/cdc_rsync/server_arch.h @@ -19,29 +19,49 @@ #include +#include "absl/status/statusor.h" +#include "common/arch_type.h" + namespace cdc_ft { +class RemoteUtil; + // Abstracts all architecture specifics of cdc_rsync_server deployment. +// Comes in two flavors, "guessed" and "detected". Guesses are used as an +// optimization. For instance, if one syncs to C:\some\path, it's clearly a +// Windows machine and we can skip detection. class ServerArch { public: - enum class Type { - kLinux = 0, - kWindows = 1, - }; - - // Detects the arch type based on the destination path, e.g. path - // starting with C: indicate Windows. - static Type Detect(const std::string& destination); + // Guesses the arch type based on the destination path, e.g. path starting + // with C: indicate Windows. This is a guessed type. It may be wrong. For + // instance, if destination is just a single folder like "foo", the method + // defaults to Type::kLinux. + static ServerArch GuessFromDestination(const std::string& destination); // Returns the arch type that matches the current process's type. - static Type LocalType(); + // This is not a guessed type, it is reliable. + static ServerArch DetectFromLocalDevice(); + + // Creates an by properly detecting it on the remote device. + // This is more costly than guessing, but it is reliable. + static absl::StatusOr DetectFromRemoteDevice( + RemoteUtil* remote_util); // Returns the (local!) arch specific filename of cdc_rsync[.exe]. static std::string CdcRsyncFilename(); - ServerArch(Type type); + ServerArch(ArchType type, bool is_guess); ~ServerArch(); + // Accessor for the arch type. + ArchType GetType() const { return type_; } + + // Returns the type as a human readable string. + const char* GetTypeStr() const; + + // Returns true if the type was guessed and not detected. + bool IsGuess() const { return is_guess_; } + // Returns the arch-specific filename of cdc_rsync_server[.exe]. std::string CdcServerFilename() const; @@ -68,7 +88,8 @@ class ServerArch { std::string GetDeploySftpCommands() const; private: - Type type_; + ArchType type_; + bool is_guess_ = false; }; } // namespace cdc_ft diff --git a/cdc_rsync/server_arch_test.cc b/cdc_rsync/server_arch_test.cc index 6518082..4526af5 100644 --- a/cdc_rsync/server_arch_test.cc +++ b/cdc_rsync/server_arch_test.cc @@ -20,68 +20,83 @@ namespace cdc_ft { namespace { -constexpr auto kLinux = ServerArch::Type::kLinux; -constexpr auto kWindows = ServerArch::Type::kWindows; +constexpr auto kLinux = ArchType::kLinux_x86_64; +constexpr auto kWindows = ArchType::kWindows_x86_64; -TEST(ServerArchTest, DetectsLinuxIfPathStartsWithSlashOrTilde) { - EXPECT_EQ(ServerArch::Detect("/linux/path"), kLinux); - EXPECT_EQ(ServerArch::Detect("/linux\\path"), kLinux); - EXPECT_EQ(ServerArch::Detect("~/linux/path"), kLinux); - EXPECT_EQ(ServerArch::Detect("~/linux\\path"), kLinux); - EXPECT_EQ(ServerArch::Detect("~\\linux\\path"), kLinux); +constexpr bool kNoGuess = false; + +TEST(ServerArchTest, GuessesLinuxIfPathStartsWithSlashOrTilde) { + EXPECT_EQ(ServerArch::GuessFromDestination("/linux/path").GetType(), kLinux); + EXPECT_EQ(ServerArch::GuessFromDestination("/linux\\path").GetType(), kLinux); + EXPECT_EQ(ServerArch::GuessFromDestination("~/linux/path").GetType(), kLinux); + EXPECT_EQ(ServerArch::GuessFromDestination("~/linux\\path").GetType(), + kLinux); + EXPECT_EQ(ServerArch::GuessFromDestination("~\\linux\\path").GetType(), + kLinux); } -TEST(ServerArchTest, DetectsWindowsIfPathStartsWithDrive) { - EXPECT_EQ(ServerArch::Detect("C:\\win\\path"), kWindows); - EXPECT_EQ(ServerArch::Detect("D:win"), kWindows); - EXPECT_EQ(ServerArch::Detect("Z:\\win/path"), kWindows); +TEST(ServerArchTest, GuessesWindowsIfPathStartsWithDrive) { + EXPECT_EQ(ServerArch::GuessFromDestination("C:\\win\\path").GetType(), + kWindows); + EXPECT_EQ(ServerArch::GuessFromDestination("D:win").GetType(), kWindows); + EXPECT_EQ(ServerArch::GuessFromDestination("Z:\\win/path").GetType(), + kWindows); } -TEST(ServerArchTest, DetectsLinuxIfPathOnlyHasForwardSlashes) { - EXPECT_EQ(ServerArch::Detect("linux/path"), kLinux); +TEST(ServerArchTest, GuessesLinuxIfPathOnlyHasForwardSlashes) { + EXPECT_EQ(ServerArch::GuessFromDestination("linux/path").GetType(), kLinux); } -TEST(ServerArchTest, DetectsWindowsIfPathOnlyHasBackSlashes) { - EXPECT_EQ(ServerArch::Detect("\\win\\path"), kWindows); +TEST(ServerArchTest, GuessesWindowsIfPathOnlyHasBackSlashes) { + EXPECT_EQ(ServerArch::GuessFromDestination("\\win\\path").GetType(), + kWindows); } -TEST(ServerArchTest, DetectsLinuxByDefault) { - EXPECT_EQ(ServerArch::Detect("/mixed\\path"), kLinux); - EXPECT_EQ(ServerArch::Detect("/mixed\\path"), kLinux); - EXPECT_EQ(ServerArch::Detect("C\\linux/path"), kLinux); - EXPECT_EQ(ServerArch::Detect(""), kLinux); +TEST(ServerArchTest, GuessesLinuxByDefault) { + EXPECT_EQ(ServerArch::GuessFromDestination("/mixed\\path").GetType(), kLinux); + EXPECT_EQ(ServerArch::GuessFromDestination("/mixed\\path").GetType(), kLinux); + EXPECT_EQ(ServerArch::GuessFromDestination("C\\linux/path").GetType(), + kLinux); + EXPECT_EQ(ServerArch::GuessFromDestination("").GetType(), kLinux); +} + +TEST(ServerArchTest, IsGuess) { + EXPECT_TRUE(ServerArch::GuessFromDestination("foo").IsGuess()); + EXPECT_FALSE(ServerArch::DetectFromLocalDevice().IsGuess()); } TEST(ServerArchTest, CdcServerFilename) { - EXPECT_FALSE( - absl::StrContains(ServerArch(kLinux).CdcServerFilename(), "exe")); - EXPECT_TRUE( - absl::StrContains(ServerArch(kWindows).CdcServerFilename(), "exe")); + EXPECT_FALSE(absl::StrContains( + ServerArch(kLinux, kNoGuess).CdcServerFilename(), "exe")); + EXPECT_TRUE(absl::StrContains( + ServerArch(kWindows, kNoGuess).CdcServerFilename(), "exe")); } TEST(ServerArchTest, RemoteToolsBinDir) { - const std::string linux_dir = ServerArch(kLinux).RemoteToolsBinDir(); + const std::string linux_dir = + ServerArch(kLinux, kNoGuess).RemoteToolsBinDir(); EXPECT_TRUE(absl::StrContains(linux_dir, ".cache/")); - std::string win_dir = ServerArch(kWindows).RemoteToolsBinDir(); + std::string win_dir = ServerArch(kWindows, kNoGuess).RemoteToolsBinDir(); EXPECT_TRUE(absl::StrContains(win_dir, "AppData\\Roaming\\")); } TEST(ServerArchTest, GetStartServerCommand) { - std::string cmd = ServerArch(kWindows).GetStartServerCommand(123, "foo bar"); + std::string cmd = + ServerArch(kWindows, kNoGuess).GetStartServerCommand(123, "foo bar"); EXPECT_TRUE(absl::StrContains(cmd, "123")); EXPECT_TRUE(absl::StrContains(cmd, "cdc_rsync_server.exe foo bar")); - cmd = ServerArch(kLinux).GetStartServerCommand(123, "foo bar"); + cmd = ServerArch(kLinux, kNoGuess).GetStartServerCommand(123, "foo bar"); EXPECT_TRUE(absl::StrContains(cmd, "123")); EXPECT_TRUE(absl::StrContains(cmd, "cdc_rsync_server foo bar")); } TEST(ServerArchTest, GetDeployReplaceCommand) { - std::string cmd = ServerArch(kWindows).GetDeploySftpCommands(); + std::string cmd = ServerArch(kWindows, kNoGuess).GetDeploySftpCommands(); EXPECT_TRUE(absl::StrContains(cmd, "cdc_rsync_server.exe")); - cmd = ServerArch(kLinux).GetDeploySftpCommands(); + cmd = ServerArch(kLinux, kNoGuess).GetDeploySftpCommands(); EXPECT_TRUE(absl::StrContains(cmd, "cdc_rsync_server")); } diff --git a/cdc_stream/multi_session.cc b/cdc_stream/multi_session.cc index bf74c0b..5a0b036 100644 --- a/cdc_stream/multi_session.cc +++ b/cdc_stream/multi_session.cc @@ -441,9 +441,9 @@ absl::Status MultiSession::Initialize() { std::unordered_set ports; ASSIGN_OR_RETURN( ports, - PortManager::FindAvailableLocalPorts(cfg_.forward_port_first, - cfg_.forward_port_last, - "127.0.0.1", process_factory_), + PortManager::FindAvailableLocalPorts( + cfg_.forward_port_first, cfg_.forward_port_last, + ArchType::kWindows_x86_64, process_factory_), "Failed to find an available local port in the range [%d, %d]", cfg_.forward_port_first, cfg_.forward_port_last); assert(!ports.empty()); diff --git a/cdc_stream/session.cc b/cdc_stream/session.cc index e1237c4..81567bf 100644 --- a/cdc_stream/session.cc +++ b/cdc_stream/session.cc @@ -77,8 +77,8 @@ absl::Status Session::Start(int local_port, int first_remote_port, ASSIGN_OR_RETURN( ports, PortManager::FindAvailableRemotePorts( - first_remote_port, last_remote_port, "127.0.0.1", process_factory_, - &remote_util_, kInstanceConnectionTimeoutSec), + first_remote_port, last_remote_port, ArchType::kLinux_x86_64, + process_factory_, &remote_util_, kInstanceConnectionTimeoutSec), "Failed to find an available remote port in the range [%d, %d]", first_remote_port, last_remote_port); assert(!ports.empty()); diff --git a/common/BUILD b/common/BUILD index 1bb244f..1f0faca 100644 --- a/common/BUILD +++ b/common/BUILD @@ -18,6 +18,24 @@ cc_test( ], ) +cc_library( + name = "arch_type", + srcs = ["arch_type.cc"], + hdrs = ["arch_type.h"], + deps = [":platform"], +) + +cc_test( + name = "arch_type_test", + srcs = ["arch_type_test.cc"], + deps = [ + ":arch_type", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "buffer", srcs = ["buffer.cc"], @@ -254,6 +272,7 @@ cc_library( hdrs = ["port_manager.h"], target_compatible_with = ["@platforms//os:windows"], deps = [ + ":arch_type", ":remote_util", ":status", ":stopwatch", @@ -360,6 +379,7 @@ cc_library( srcs = ["remote_util.cc"], hdrs = ["remote_util.h"], deps = [ + ":arch_type", ":platform", ":process", ":sdk_util", diff --git a/common/arch_type.cc b/common/arch_type.cc new file mode 100644 index 0000000..3abc47d --- /dev/null +++ b/common/arch_type.cc @@ -0,0 +1,70 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/arch_type.h" + +#include + +#include "common/platform.h" + +namespace cdc_ft { + +static constexpr char kUnhandledArchType[] = "Unhandled arch type"; + +ArchType GetLocalArchType() { + // TODO(ljusten): Take CPU architecture into account. +#if PLATFORM_WINDOWS + return ArchType::kWindows_x86_64; +#elif PLATFORM_LINUX + return ArchType::kLinux_x86_64; +#endif +} + +bool IsWindowsArchType(ArchType arch_type) { + switch (arch_type) { + case ArchType::kWindows_x86_64: + return true; + case ArchType::kLinux_x86_64: + return false; + default: + assert(!kUnhandledArchType); + return false; + } +} + +bool IsLinuxArchType(ArchType arch_type) { + switch (arch_type) { + case ArchType::kWindows_x86_64: + return false; + case ArchType::kLinux_x86_64: + return true; + default: + assert(!kUnhandledArchType); + return false; + } +} + +const char* GetArchTypeStr(ArchType arch_type) { + switch (arch_type) { + case ArchType::kWindows_x86_64: + return "Windows_x86_64"; + case ArchType::kLinux_x86_64: + return "Linux_x86_64"; + default: + assert(!kUnhandledArchType); + return "Unknown"; + } +} + +} // namespace cdc_ft diff --git a/common/arch_type.h b/common/arch_type.h new file mode 100644 index 0000000..69dfbd0 --- /dev/null +++ b/common/arch_type.h @@ -0,0 +1,41 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_ARCH_TYPE_H_ +#define COMMON_ARCH_TYPE_H_ + +namespace cdc_ft { + +enum class ArchType { + kWindows_x86_64 = 0, + kLinux_x86_64 = 1, +}; + +// Returns the arch type of the current process. +ArchType GetLocalArchType(); + +// Returns true if |arch_type| is a Windows operating system. +bool IsWindowsArchType(ArchType arch_type); + +// Returns true if |arch_type| is a Linux operating system. +bool IsLinuxArchType(ArchType arch_type); + +// Returns a human readable string for |arch_type|. +const char* GetArchTypeStr(ArchType arch_type); + +} // namespace cdc_ft + +#endif // COMMON_ARCH_TYPE_H_ diff --git a/common/arch_type_test.cc b/common/arch_type_test.cc new file mode 100644 index 0000000..e67e6c9 --- /dev/null +++ b/common/arch_type_test.cc @@ -0,0 +1,49 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/arch_type.h" + +#include "absl/strings/match.h" +#include "gtest/gtest.h" + +namespace cdc_ft { +namespace { + +TEST(ArchTypeTest, GetLocalArchType) { +#if PLATFORM_WINDOWS + EXPECT_TRUE(IsPlatformWindows(GetLocalArchType())); +#elif PLATFORM_LINUX + EXPECT_TRUE(IsPlatformLinux(GetLocalArchType())); +#endif +} + +TEST(ArchTypeTest, IsWindowsArchType) { + EXPECT_TRUE(IsWindowsArchType(ArchType::kWindows_x86_64)); + EXPECT_FALSE(IsWindowsArchType(ArchType::kLinux_x86_64)); +} + +TEST(ArchTypeTest, IsLinuxArchType) { + EXPECT_FALSE(IsLinuxArchType(ArchType::kWindows_x86_64)); + EXPECT_TRUE(IsLinuxArchType(ArchType::kLinux_x86_64)); +} + +TEST(ArchTypeTest, GetArchTypeStr) { + EXPECT_TRUE( + absl::StrContains(GetArchTypeStr(ArchType::kWindows_x86_64), "Windows")); + EXPECT_TRUE( + absl::StrContains(GetArchTypeStr(ArchType::kLinux_x86_64), "Linux")); +} + +} // namespace +} // namespace cdc_ft diff --git a/common/port_manager.h b/common/port_manager.h index 95221d5..19627cd 100644 --- a/common/port_manager.h +++ b/common/port_manager.h @@ -17,12 +17,12 @@ #ifndef COMMON_PORT_MANAGER_H_ #define COMMON_PORT_MANAGER_H_ -#include - #include #include #include +#include "absl/status/statusor.h" +#include "common/arch_type.h" #include "common/clock.h" namespace cdc_ft { @@ -53,9 +53,12 @@ class PortManager { // explicitly. // |remote_timeout_sec| is the timeout for finding available ports on the // remote instance. - // Returns a DeadlineExceeded error if the timeout is exceeded. - // Returns a ResourceExhausted error if no ports are available. - absl::StatusOr ReservePort(int remote_timeout_sec); + // |remote_arch_type| is the architecture of the remote device. + // Both |remote_timeout_sec| and |remote_arch_type| are ignored if + // |remote_util| is nullptr. Returns a DeadlineExceeded error if the timeout + // is exceeded. Returns a ResourceExhausted error if no ports are available. + absl::StatusOr ReservePort(int remote_timeout_sec, + ArchType remote_arch_type); // Releases a reserved port. absl::Status ReleasePort(int port); @@ -66,34 +69,34 @@ class PortManager { // Finds available ports in the range [first_port, last_port] for port // forwarding on the local workstation. - // |ip| is the IP address to filter by. + // |arch_type| is the architecture of the local device. // |process_factory| is used to create a netstat process. // Returns ResourceExhaustedError if no port is available. static absl::StatusOr> FindAvailableLocalPorts( - int first_port, int last_port, const char* ip, + int first_port, int last_port, ArchType arch_type, ProcessFactory* process_factory); // Finds available ports in the range [first_port, last_port] for port // forwarding on the instance. - // |ip| is the IP address to filter by. + // |arch_type| is the architecture of the remote device. // |process_factory| is used to create a netstat process. // |remote_util| is used to connect to the instance. // |timeout_sec| is the connection timeout in seconds. // Returns a DeadlineExceeded error if the timeout is exceeded. // Returns ResourceExhaustedError if no port is available. static absl::StatusOr> FindAvailableRemotePorts( - int first_port, int last_port, const char* ip, + int first_port, int last_port, ArchType arch_type, ProcessFactory* process_factory, RemoteUtil* remote_util, int timeout_sec, SteadyClock* steady_clock = DefaultSteadyClock::GetInstance()); private: // Returns a list of available ports in the range [|first_port|, |last_port|] - // from the given |netstat_output|. |ip| is the IP address to look for, e.g. - // "127.0.0.1". + // from the given |netstat_output|. + // |arch_type| is the architecture of the device where netstat was called. // Returns ResourceExhaustedError if no port is available. static absl::StatusOr> FindAvailablePorts( int first_port, int last_port, const std::string& netstat_output, - const char* ip); + ArchType arch_type); int first_port_; int last_port_; diff --git a/common/port_manager_test.cc b/common/port_manager_test.cc index be2c9f5..03e750d 100644 --- a/common/port_manager_test.cc +++ b/common/port_manager_test.cc @@ -25,7 +25,6 @@ namespace cdc_ft { namespace { -constexpr int kSshPort = 12345; constexpr char kUserHost[] = "user@1.2.3.4"; constexpr char kGuid[] = "f77bcdfe-368c-4c45-9f01-230c5e7e2132"; @@ -35,12 +34,12 @@ constexpr int kNumPorts = kLastPort - kFirstPort + 1; constexpr int kTimeoutSec = 1; -constexpr char kLocalNetstat[] = "netstat -a -n -p tcp"; -constexpr char kRemoteNetstat[] = "netstat --numeric --listening --tcp"; +constexpr char kWindowsNetstat[] = "netstat -a -n -p tcp"; +constexpr char kLinuxNetstat[] = "netstat --numeric --listening --tcp"; -constexpr char kLocalNetstatOutFmt[] = +constexpr char kWindowsNetstatOutFmt[] = "TCP 127.0.0.1:50000 127.0.0.1:%i ESTABLISHED"; -constexpr char kRemoteNetstatOutFmt[] = +constexpr char kLinuxNetstatOutFmt[] = "tcp 0 0 0.0.0.0:%i 0.0.0.0:* LISTEN"; class PortManagerTest : public ::testing::Test { @@ -67,10 +66,11 @@ class PortManagerTest : public ::testing::Test { }; TEST_F(PortManagerTest, ReservePortSuccess) { - process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); - process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); + process_factory_.SetProcessOutput(kWindowsNetstat, "", "", 0); + process_factory_.SetProcessOutput(kLinuxNetstat, "", "", 0); - absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); + absl::StatusOr port = + port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64); ASSERT_OK(port); EXPECT_EQ(*port, kFirstPort); } @@ -78,12 +78,13 @@ TEST_F(PortManagerTest, ReservePortSuccess) { TEST_F(PortManagerTest, ReservePortAllLocalPortsTaken) { std::string local_netstat_out = ""; for (int port = kFirstPort; port <= kLastPort; ++port) { - local_netstat_out += absl::StrFormat(kLocalNetstatOutFmt, port); + local_netstat_out += absl::StrFormat(kWindowsNetstatOutFmt, port); } - process_factory_.SetProcessOutput(kLocalNetstat, local_netstat_out, "", 0); - process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); + process_factory_.SetProcessOutput(kWindowsNetstat, local_netstat_out, "", 0); + process_factory_.SetProcessOutput(kLinuxNetstat, "", "", 0); - absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); + absl::StatusOr port = + port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64); EXPECT_TRUE(absl::IsResourceExhausted(port.status())); EXPECT_TRUE( absl::StrContains(port.status().message(), "No port available in range")); @@ -92,22 +93,24 @@ TEST_F(PortManagerTest, ReservePortAllLocalPortsTaken) { TEST_F(PortManagerTest, ReservePortAllRemotePortsTaken) { std::string remote_netstat_out = ""; for (int port = kFirstPort; port <= kLastPort; ++port) { - remote_netstat_out += absl::StrFormat(kRemoteNetstatOutFmt, port); + remote_netstat_out += absl::StrFormat(kLinuxNetstatOutFmt, port); } - process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); - process_factory_.SetProcessOutput(kRemoteNetstat, remote_netstat_out, "", 0); + process_factory_.SetProcessOutput(kWindowsNetstat, "", "", 0); + process_factory_.SetProcessOutput(kLinuxNetstat, remote_netstat_out, "", 0); - absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); + absl::StatusOr port = + port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64); EXPECT_TRUE(absl::IsResourceExhausted(port.status())); EXPECT_TRUE( absl::StrContains(port.status().message(), "No port available in range")); } TEST_F(PortManagerTest, ReservePortLocalNetstatFails) { - process_factory_.SetProcessOutput(kLocalNetstat, "", "", 1); - process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); + process_factory_.SetProcessOutput(kWindowsNetstat, "", "", 1); + process_factory_.SetProcessOutput(kLinuxNetstat, "", "", 0); - absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); + absl::StatusOr port = + port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64); EXPECT_NOT_OK(port); EXPECT_TRUE( absl::StrContains(port.status().message(), @@ -115,21 +118,23 @@ TEST_F(PortManagerTest, ReservePortLocalNetstatFails) { } TEST_F(PortManagerTest, ReservePortRemoteNetstatFails) { - process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); - process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 1); + process_factory_.SetProcessOutput(kWindowsNetstat, "", "", 0); + process_factory_.SetProcessOutput(kLinuxNetstat, "", "", 1); - absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); + absl::StatusOr port = + port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64); EXPECT_NOT_OK(port); EXPECT_TRUE(absl::StrContains(port.status().message(), "Failed to find available ports on instance")); } TEST_F(PortManagerTest, ReservePortRemoteNetstatTimesOut) { - process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); - process_factory_.SetProcessNeverExits(kRemoteNetstat); + process_factory_.SetProcessOutput(kWindowsNetstat, "", "", 0); + process_factory_.SetProcessNeverExits(kLinuxNetstat); steady_clock_.AutoAdvance(kTimeoutSec * 2 * 1000); - absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); + absl::StatusOr port = + port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64); EXPECT_NOT_OK(port); EXPECT_TRUE(absl::IsDeadlineExceeded(port.status())); EXPECT_TRUE(absl::StrContains(port.status().message(), @@ -137,8 +142,8 @@ TEST_F(PortManagerTest, ReservePortRemoteNetstatTimesOut) { } TEST_F(PortManagerTest, ReservePortMultipleInstances) { - process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); - process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); + process_factory_.SetProcessOutput(kWindowsNetstat, "", "", 0); + process_factory_.SetProcessOutput(kLinuxNetstat, "", "", 0); PortManager port_manager2(kGuid, kFirstPort, kLastPort, &process_factory_, &remote_util_); @@ -146,55 +151,79 @@ TEST_F(PortManagerTest, ReservePortMultipleInstances) { // Port managers use shared memory, so different instances know about each // other. This would even work if |port_manager_| and |port_manager2| belonged // to different processes, but we don't test that here. - EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec), kFirstPort + 0); - EXPECT_EQ(*port_manager2.ReservePort(kTimeoutSec), kFirstPort + 1); - EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec), kFirstPort + 2); - EXPECT_EQ(*port_manager2.ReservePort(kTimeoutSec), kFirstPort + 3); + EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64), + kFirstPort + 0); + EXPECT_EQ(*port_manager2.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64), + kFirstPort + 1); + EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64), + kFirstPort + 2); + EXPECT_EQ(*port_manager2.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64), + kFirstPort + 3); } TEST_F(PortManagerTest, ReservePortReusesPortsInLRUOrder) { - process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); - process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); + process_factory_.SetProcessOutput(kWindowsNetstat, "", "", 0); + process_factory_.SetProcessOutput(kLinuxNetstat, "", "", 0); for (int n = 0; n < kNumPorts * 2; ++n) { - EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec), + EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64), kFirstPort + n % kNumPorts); system_clock_.Advance(1000); } } TEST_F(PortManagerTest, ReleasePort) { - process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); - process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); + process_factory_.SetProcessOutput(kWindowsNetstat, "", "", 0); + process_factory_.SetProcessOutput(kLinuxNetstat, "", "", 0); - absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); + absl::StatusOr port = + port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64); EXPECT_EQ(*port, kFirstPort); EXPECT_OK(port_manager_.ReleasePort(*port)); - port = port_manager_.ReservePort(kTimeoutSec); + port = port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64); EXPECT_EQ(*port, kFirstPort); } TEST_F(PortManagerTest, ReleasePortOnDestruction) { - process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); - process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); + process_factory_.SetProcessOutput(kWindowsNetstat, "", "", 0); + process_factory_.SetProcessOutput(kLinuxNetstat, "", "", 0); auto port_manager2 = std::make_unique( kGuid, kFirstPort, kLastPort, &process_factory_, &remote_util_); - EXPECT_EQ(*port_manager2->ReservePort(kTimeoutSec), kFirstPort + 0); - EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec), kFirstPort + 1); + EXPECT_EQ(*port_manager2->ReservePort(kTimeoutSec, ArchType::kLinux_x86_64), + kFirstPort + 0); + EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64), + kFirstPort + 1); port_manager2.reset(); - EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec), kFirstPort + 0); + EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec, ArchType::kLinux_x86_64), + kFirstPort + 0); } -TEST_F(PortManagerTest, FindAvailableLocalPortsSuccess) { - // First port is taken +TEST_F(PortManagerTest, FindAvailableLocalPortsSuccessWindows) { + // First port is in use. std::string local_netstat_out = - absl::StrFormat(kLocalNetstatOutFmt, kFirstPort); - process_factory_.SetProcessOutput(kLocalNetstat, local_netstat_out, "", 0); + absl::StrFormat(kWindowsNetstatOutFmt, kFirstPort); + process_factory_.SetProcessOutput(kWindowsNetstat, local_netstat_out, "", 0); absl::StatusOr> ports = - PortManager::FindAvailableLocalPorts(kFirstPort, kLastPort, "127.0.0.1", - &process_factory_); + PortManager::FindAvailableLocalPorts( + kFirstPort, kLastPort, ArchType::kWindows_x86_64, &process_factory_); + ASSERT_OK(ports); + EXPECT_EQ(ports->size(), kNumPorts - 1); + for (int port = kFirstPort + 1; port <= kLastPort; ++port) { + EXPECT_TRUE(ports->find(port) != ports->end()); + } +} + +TEST_F(PortManagerTest, FindAvailableLocalPortsSuccessLinux) { + // First port is in use. + std::string local_netstat_out = + absl::StrFormat(kLinuxNetstatOutFmt, kFirstPort); + process_factory_.SetProcessOutput(kLinuxNetstat, local_netstat_out, "", 0); + + absl::StatusOr> ports = + PortManager::FindAvailableLocalPorts( + kFirstPort, kLastPort, ArchType::kLinux_x86_64, &process_factory_); ASSERT_OK(ports); EXPECT_EQ(ports->size(), kNumPorts - 1); for (int port = kFirstPort + 1; port <= kLastPort; ++port) { @@ -203,31 +232,48 @@ TEST_F(PortManagerTest, FindAvailableLocalPortsSuccess) { } TEST_F(PortManagerTest, FindAvailableLocalPortsFailsNoPorts) { - // All ports taken + // All ports are in use. std::string local_netstat_out = ""; for (int port = kFirstPort; port <= kLastPort; ++port) { - local_netstat_out += absl::StrFormat(kLocalNetstatOutFmt, port); + local_netstat_out += absl::StrFormat(kWindowsNetstatOutFmt, port); } - process_factory_.SetProcessOutput(kLocalNetstat, local_netstat_out, "", 0); + process_factory_.SetProcessOutput(kWindowsNetstat, local_netstat_out, "", 0); absl::StatusOr> ports = - PortManager::FindAvailableLocalPorts(kFirstPort, kLastPort, "127.0.0.1", - &process_factory_); + PortManager::FindAvailableLocalPorts( + kFirstPort, kLastPort, ArchType::kWindows_x86_64, &process_factory_); EXPECT_TRUE(absl::IsResourceExhausted(ports.status())); EXPECT_TRUE(absl::StrContains(ports.status().message(), "No port available in range")); } -TEST_F(PortManagerTest, FindAvailableRemotePortsSuccess) { - // First port is taken +TEST_F(PortManagerTest, FindAvailableRemotePortsSuccessLinux) { + // First port is in use. std::string remote_netstat_out = - absl::StrFormat(kRemoteNetstatOutFmt, kFirstPort); - process_factory_.SetProcessOutput(kRemoteNetstat, remote_netstat_out, "", 0); + absl::StrFormat(kLinuxNetstatOutFmt, kFirstPort); + process_factory_.SetProcessOutput(kLinuxNetstat, remote_netstat_out, "", 0); absl::StatusOr> ports = - PortManager::FindAvailableRemotePorts(kFirstPort, kLastPort, "0.0.0.0", - &process_factory_, &remote_util_, - kTimeoutSec); + PortManager::FindAvailableRemotePorts( + kFirstPort, kLastPort, ArchType::kLinux_x86_64, &process_factory_, + &remote_util_, kTimeoutSec); + ASSERT_OK(ports); + EXPECT_EQ(ports->size(), kNumPorts - 1); + for (int port = kFirstPort + 1; port <= kLastPort; ++port) { + EXPECT_TRUE(ports->find(port) != ports->end()); + } +} + +TEST_F(PortManagerTest, FindAvailableRemotePortsSuccessWindows) { + // First port is in use. + std::string remote_netstat_out = + absl::StrFormat(kWindowsNetstatOutFmt, kFirstPort); + process_factory_.SetProcessOutput(kWindowsNetstat, remote_netstat_out, "", 0); + + absl::StatusOr> ports = + PortManager::FindAvailableRemotePorts( + kFirstPort, kLastPort, ArchType::kWindows_x86_64, &process_factory_, + &remote_util_, kTimeoutSec); ASSERT_OK(ports); EXPECT_EQ(ports->size(), kNumPorts - 1); for (int port = kFirstPort + 1; port <= kLastPort; ++port) { @@ -236,17 +282,17 @@ TEST_F(PortManagerTest, FindAvailableRemotePortsSuccess) { } TEST_F(PortManagerTest, FindAvailableRemotePortsFailsNoPorts) { - // All ports taken + // All ports are in use. std::string remote_netstat_out = ""; for (int port = kFirstPort; port <= kLastPort; ++port) { - remote_netstat_out += absl::StrFormat(kRemoteNetstatOutFmt, port); + remote_netstat_out += absl::StrFormat(kLinuxNetstatOutFmt, port); } - process_factory_.SetProcessOutput(kRemoteNetstat, remote_netstat_out, "", 0); + process_factory_.SetProcessOutput(kLinuxNetstat, remote_netstat_out, "", 0); absl::StatusOr> ports = - PortManager::FindAvailableRemotePorts(kFirstPort, kLastPort, "0.0.0.0", - &process_factory_, &remote_util_, - kTimeoutSec); + PortManager::FindAvailableRemotePorts( + kFirstPort, kLastPort, ArchType::kLinux_x86_64, &process_factory_, + &remote_util_, kTimeoutSec); EXPECT_TRUE(absl::IsResourceExhausted(ports.status())); EXPECT_TRUE(absl::StrContains(ports.status().message(), "No port available in range")); diff --git a/common/port_manager_win.cc b/common/port_manager_win.cc index cec3396..ec22354 100644 --- a/common/port_manager_win.cc +++ b/common/port_manager_win.cc @@ -20,6 +20,7 @@ #include #include "absl/strings/str_split.h" +#include "common/arch_type.h" #include "common/log.h" #include "common/process.h" #include "common/remote_util.h" @@ -30,6 +31,42 @@ namespace cdc_ft { +constexpr char kErrorArchTypeUnhandled[] = "arch_type_unhandled"; + +// Returns the arch-specific netstat command. +const char* GetNetstatCommand(ArchType arch_type) { + if (IsWindowsArchType(arch_type)) { + // -a to get the connection and ports the computer is listening on. + // -n to get numerical addresses to avoid the overhead of getting names. + // -p tcp to limit the output to TCPv4 connections. + return "netstat -a -n -p tcp"; + } + + if (IsLinuxArchType(arch_type)) { + // --numeric to get numerical addresses. + // --listening to get only listening sockets. + // --tcp to get only TCP connections. + return "netstat --numeric --listening --tcp"; + } + + assert(!kErrorArchTypeUnhandled); + return ""; +} + +// Returns the arch-specific IP address to filter netstat results by. +const char* GetNetstatFilterIp(ArchType arch_type) { + if (IsWindowsArchType(arch_type)) { + return "127.0.0.1"; + } + + if (IsLinuxArchType(arch_type)) { + return "0.0.0.0"; + } + + assert(!kErrorArchTypeUnhandled); + return ""; +} + class SharedMemory { public: // Creates a new shared memory instance with given |name| and |size| in bytes. @@ -121,22 +158,25 @@ PortManager::~PortManager() { } } -absl::StatusOr PortManager::ReservePort(int remote_timeout_sec) { +absl::StatusOr PortManager::ReservePort(int remote_timeout_sec, + ArchType remote_arch_type) { // Find available port on workstation. std::unordered_set local_ports; - ASSIGN_OR_RETURN(local_ports, - FindAvailableLocalPorts(first_port_, last_port_, "127.0.0.1", - process_factory_), - "Failed to find available ports on workstation"); + ASSIGN_OR_RETURN( + local_ports, + FindAvailableLocalPorts(first_port_, last_port_, + ArchType::kWindows_x86_64, process_factory_), + "Failed to find available ports on workstation"); // Find available port on remote instance. std::unordered_set remote_ports = local_ports; if (remote_util_ != nullptr) { - ASSIGN_OR_RETURN(remote_ports, - FindAvailableRemotePorts( - first_port_, last_port_, "0.0.0.0", process_factory_, - remote_util_, remote_timeout_sec, steady_clock_), - "Failed to find available ports on instance"); + ASSIGN_OR_RETURN( + remote_ports, + FindAvailableRemotePorts(first_port_, last_port_, remote_arch_type, + process_factory_, remote_util_, + remote_timeout_sec, steady_clock_), + "Failed to find available ports on instance"); } // Fetch shared memory. @@ -203,14 +243,11 @@ absl::Status PortManager::ReleasePort(int port) { // static absl::StatusOr> PortManager::FindAvailableLocalPorts( - int first_port, int last_port, const char* ip, + int first_port, int last_port, ArchType arch_type, ProcessFactory* process_factory) { - // -a to get the connection and ports the computer is listening on. - // -n to get numerical addresses to avoid the overhead of determining names. - // -p tcp to limit the output to TCPv4 connections. - // TODO: Use Windows API instead of netstat. + // TODO: Use local APIs instead of netstat. ProcessStartInfo start_info; - start_info.command = "netstat -a -n -p tcp"; + start_info.command = GetNetstatCommand(arch_type); start_info.name = "netstat"; start_info.flags = ProcessFlags::kNoWindow; @@ -231,18 +268,15 @@ absl::StatusOr> PortManager::FindAvailableLocalPorts( } LOG_DEBUG("netstat (workstation) output:\n%s", output); - return FindAvailablePorts(first_port, last_port, output, ip); + return FindAvailablePorts(first_port, last_port, output, arch_type); } // static absl::StatusOr> PortManager::FindAvailableRemotePorts( - int first_port, int last_port, const char* ip, + int first_port, int last_port, ArchType arch_type, ProcessFactory* process_factory, RemoteUtil* remote_util, int timeout_sec, SteadyClock* steady_clock) { - // --numeric to get numerical addresses. - // --listening to get only listening sockets. - // --tcp to get only TCP connections. - std::string remote_command = "netstat --numeric --listening --tcp"; + std::string remote_command = GetNetstatCommand(arch_type); ProcessStartInfo start_info = remote_util->BuildProcessStartInfoForSsh(remote_command); start_info.name = "netstat"; @@ -281,24 +315,25 @@ absl::StatusOr> PortManager::FindAvailableRemotePorts( } LOG_DEBUG("netstat (instance) output:\n%s", output); - return FindAvailablePorts(first_port, last_port, output, ip); + return FindAvailablePorts(first_port, last_port, output, arch_type); } // static absl::StatusOr> PortManager::FindAvailablePorts( int first_port, int last_port, const std::string& netstat_output, - const char* ip) { + ArchType arch_type) { std::unordered_set available_ports; std::vector lines; + const char* filter_ip = GetNetstatFilterIp(arch_type); for (const auto& line : absl::StrSplit(netstat_output, '\n')) { - if (absl::StrContains(line, ip)) { + if (absl::StrContains(line, filter_ip)) { lines.push_back(std::string(line)); } } for (int port = first_port; port <= last_port; ++port) { bool port_occupied = false; - std::string portToken = absl::StrFormat("%s:%i", ip, port); + std::string portToken = absl::StrFormat("%s:%i", filter_ip, port); for (const std::string& line : lines) { // Ports in the TIME_WAIT state can be reused. It is common that ports // stay in this state for O(minutes). diff --git a/common/remote_util.cc b/common/remote_util.cc index bfeb706..4db64d1 100644 --- a/common/remote_util.cc +++ b/common/remote_util.cc @@ -148,6 +148,31 @@ absl::Status RemoteUtil::Run(std::string remote_command, std::string name) { return process_factory_->Run(start_info); } +absl::Status RemoteUtil::RunWithCapture(std::string remote_command, + std::string name, std::string* std_out, + std::string* std_err) { + ProcessStartInfo start_info = + BuildProcessStartInfoForSsh(std::move(remote_command)); + start_info.name = std::move(name); + start_info.forward_output_to_log = forward_output_to_log_; + + if (std_out) { + start_info.stdout_handler = [std_out](const char* data, size_t size) { + std_out->append(data, size); + return absl::OkStatus(); + }; + } + + if (std_err) { + start_info.stderr_handler = [std_err](const char* data, size_t size) { + std_err->append(data, size); + return absl::OkStatus(); + }; + } + + return process_factory_->Run(start_info); +} + ProcessStartInfo RemoteUtil::BuildProcessStartInfoForSsh( std::string remote_command) { return BuildProcessStartInfoForSshInternal("", "-- " + remote_command); diff --git a/common/remote_util.h b/common/remote_util.h index cc4d0d0..94cb300 100644 --- a/common/remote_util.h +++ b/common/remote_util.h @@ -25,7 +25,7 @@ namespace cdc_ft { -// Utilities for executing remote commands on a gamelet through SSH. +// Utilities for executing remote commands on a remote device through SSH. // Windows-only. class RemoteUtil { public: @@ -63,8 +63,8 @@ class RemoteUtil { // Returns bad results for tricky strings like "C:\scp.path\scp.exe". static std::string ScpToSftpCommand(std::string scp_command); - // Copies |source_filepaths| to the remote folder |dest| on the gamelet using - // scp. If |compress| is true, compressed upload is used. + // Copies |source_filepaths| to the remote folder |dest| on the remove device + // using scp. If |compress| is true, compressed upload is used. absl::Status Scp(std::vector source_filepaths, const std::string& dest, bool compress); @@ -86,27 +86,32 @@ class RemoteUtil { absl::Status Sftp(const std::string& commands, const std::string& initial_local_dir, bool compress); - // Calls 'chmod |mode| |remote_path|' on the gamelet. + // Calls 'chmod |mode| |remote_path|' on the remote device. absl::Status Chmod(const std::string& mode, const std::string& remote_path, bool quiet = false); - // Runs |remote_command| on the gamelet. The command must be properly escaped. - // |name| is the name of the command displayed in the logs. + // Runs |remote_command| on the remote device. The command must be properly + // escaped. |name| is the name of the command displayed in the logs. absl::Status Run(std::string remote_command, std::string name); - // Builds an SSH command that executes |remote_command| on the gamelet. + // Same as Run(), but captures both stdout and stderr. + // If |std_out| or |std_err| are nullptr, the output is not captured. + absl::Status RunWithCapture(std::string remote_command, std::string name, + std::string* std_out, std::string* std_err); + + // Builds an SSH command that executes |remote_command| on the remote device. ProcessStartInfo BuildProcessStartInfoForSsh(std::string remote_command); - // Builds an SSH command that runs SSH port forwarding to the gamelet, using - // the given |local_port| and |remote_port|. - // If |reverse| is true, sets up reverse port forwarding. + // Builds an SSH command that runs SSH port forwarding to the remote device, + // using the given |local_port| and |remote_port|. If |reverse| is true, sets + // up reverse port forwarding. ProcessStartInfo BuildProcessStartInfoForSshPortForward(int local_port, int remote_port, bool reverse); - // Builds an SSH command that executes |remote_command| on the gamelet, using - // port forwarding with given |local_port| and |remote_port|. - // If |reverse| is true, sets up reverse port forwarding. + // Builds an SSH command that executes |remote_command| on the remote device, + // using port forwarding with given |local_port| and |remote_port|. If + // |reverse| is true, sets up reverse port forwarding. ProcessStartInfo BuildProcessStartInfoForSshPortForwardAndCommand( int local_port, int remote_port, bool reverse, std::string remote_command); diff --git a/tests_common/BUILD b/tests_common/BUILD index 54f03c9..30b6ec9 100644 --- a/tests_common/BUILD +++ b/tests_common/BUILD @@ -22,6 +22,7 @@ cc_binary( ], deps = [ "//common:ansi_filter", + "//common:arch_type", "//common:buffer", "//common:dir_iter", "//common:file_watcher",