diff --git a/all_files.vcxitems b/all_files.vcxitems index f7f0f20..240de11 100644 --- a/all_files.vcxitems +++ b/all_files.vcxitems @@ -34,6 +34,8 @@ + + @@ -144,6 +146,7 @@ + diff --git a/cdc_rsync/BUILD b/cdc_rsync/BUILD index b249410..5734cec 100644 --- a/cdc_rsync/BUILD +++ b/cdc_rsync/BUILD @@ -130,6 +130,7 @@ cc_library( hdrs = ["params.h"], deps = [ ":cdc_rsync_client", + "//common:port_range_parser", "@com_github_zstd//:zstd", "@com_google_absl//absl/status", ], diff --git a/cdc_rsync/cdc_rsync_client.cc b/cdc_rsync/cdc_rsync_client.cc index 0efacf9..1dccd6c 100644 --- a/cdc_rsync/cdc_rsync_client.cc +++ b/cdc_rsync/cdc_rsync_client.cc @@ -44,8 +44,6 @@ constexpr int kExitCodeCouldNotExecute = 126; // Bash exit code if binary was not found. constexpr int kExitCodeNotFound = 127; -constexpr int kForwardPortFirst = 44450; -constexpr int kForwardPortLast = 44459; constexpr char kCdcServerFilename[] = "cdc_rsync_server"; constexpr char kRemoteToolsBinDir[] = "~/.cache/cdc-file-transfer/bin/"; @@ -104,8 +102,8 @@ CdcRsyncClient::CdcRsyncClient(const Options& options, &process_factory_, /*forward_output_to_log=*/false), port_manager_("cdc_rsync_ports_f77bcdfe-368c-4c45-9f01-230c5e7e2132", - kForwardPortFirst, kForwardPortLast, &process_factory_, - &remote_util_), + options.forward_port_first, options.forward_port_last, + &process_factory_, &remote_util_), printer_(options.quiet, Util::IsTTY() && !options.json), progress_(&printer_, options.verbosity, options.json) { if (!options_.ssh_command.empty()) { @@ -184,19 +182,24 @@ absl::Status CdcRsyncClient::StartServer() { std::string component_args = GameletComponent::ToCommandLineArgs(components); // Find available local and remote ports for port forwarding. - absl::StatusOr port_res = port_manager_.ReservePort( - /*check_remote=*/false, /*remote_timeout_sec unused*/ 0); - constexpr char kErrorMsg[] = "Failed to find available port"; - if (absl::IsDeadlineExceeded(port_res.status())) { - // Server didn't respond in time. - return SetTag(WrapStatus(port_res.status(), kErrorMsg), - Tag::kConnectionTimeout); + // If only one port is in the given range, try that without checking. + int port = options_.forward_port_first; + if (options_.forward_port_first < options_.forward_port_last) { + absl::StatusOr port_res = + port_manager_.ReservePort(options_.connection_timeout_sec); + constexpr char kErrorMsg[] = "Failed to find available port"; + if (absl::IsDeadlineExceeded(port_res.status())) { + // Server didn't respond in time. + return SetTag(WrapStatus(port_res.status(), kErrorMsg), + Tag::kConnectionTimeout); + } + if (absl::IsResourceExhausted(port_res.status())) + return SetTag(WrapStatus(port_res.status(), kErrorMsg), + Tag::kAddressInUse); + if (!port_res.ok()) + return WrapStatus(port_res.status(), "Failed to find available port"); + port = *port_res; } - if (absl::IsResourceExhausted(port_res.status())) - return SetTag(WrapStatus(port_res.status(), kErrorMsg), Tag::kAddressInUse); - if (!port_res.ok()) - return WrapStatus(port_res.status(), "Failed to find available port"); - int port = *port_res; std::string remote_server_path = std::string(kRemoteToolsBinDir) + kCdcServerFilename; diff --git a/cdc_rsync/cdc_rsync_client.h b/cdc_rsync/cdc_rsync_client.h index ecc73d1..9e24b78 100644 --- a/cdc_rsync/cdc_rsync_client.h +++ b/cdc_rsync/cdc_rsync_client.h @@ -50,6 +50,8 @@ class CdcRsyncClient { std::string copy_dest; int compress_level = 6; int connection_timeout_sec = 10; + int forward_port_first = 44450; + int forward_port_last = 44459; std::string ssh_command; std::string scp_command; std::string sources_dir; // Base dir for files loaded for --files-from. diff --git a/cdc_rsync/params.cc b/cdc_rsync/params.cc index 7a33620..241f6a7 100644 --- a/cdc_rsync/params.cc +++ b/cdc_rsync/params.cc @@ -20,6 +20,7 @@ #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "common/path.h" +#include "common/port_range_parser.h" #include "lib/zstd.h" namespace cdc_ft { @@ -45,39 +46,41 @@ Usage: cdc_rsync [options] source [source]... [user@]host:destination Parameters: - source Local file or directory to be copied - user Remote SSH user name - host Remote host or IP address - destination Remote destination directory + source Local file or directory to be copied + user Remote SSH user name + host Remote host or IP address + destination Remote destination directory Options: - --contimeout sec Gamelet connection timeout in seconds (default: 10) --q, --quiet Quiet mode, only print errors --v, --verbose Increase output verbosity - --json Print JSON progress --n, --dry-run Perform a trial run with no changes made --r, --recursive Recurse into directories - --delete Delete extraneous files from destination directory --z, --compress Compress file data during the transfer - --compress-level num Explicitly set compression level (default: 6) --c, --checksum Skip files based on checksum, not mod-time & size --W, --whole-file Always copy files whole, - do not apply delta-transfer algorithm - --exclude pattern Exclude files matching pattern - --exclude-from file Read exclude patterns from file - --include pattern Don't exclude files matching pattern - --include-from file Read include patterns from file - --files-from file Read list of source files from file --R, --relative Use relative path names - --existing Skip creating new files on instance - --copy-dest dir Use files from dir as sync base if files are missing - --ssh-command Path and arguments of ssh command to use, e.g. - "C:\path\to\ssh.exe -p 12345 -i id_rsa -oUserKnownHostsFile=known_hosts" - Can also be specified by the CDC_SSH_COMMAND environment variable. - --scp-command Path and arguments of scp command to use, e.g. - "C:\path\to\scp.exe -P 12345 -i id_rsa -oUserKnownHostsFile=known_hosts" - Can also be specified by the CDC_SCP_COMMAND environment variable. --h --help Help for cdc_rsync + --contimeout sec Gamelet connection timeout in seconds (default: 10) +-q, --quiet Quiet mode, only print errors +-v, --verbose Increase output verbosity + --json Print JSON progress +-n, --dry-run Perform a trial run with no changes made +-r, --recursive Recurse into directories + --delete Delete extraneous files from destination directory +-z, --compress Compress file data during the transfer + --compress-level Explicitly set compression level (default: 6) +-c, --checksum Skip files based on checksum, not mod-time & size +-W, --whole-file Always copy files whole, + do not apply delta-transfer algorithm + --exclude pattern Exclude files matching pattern + --exclude-from Read exclude patterns from file + --include pattern Don't exclude files matching pattern + --include-from Read include patterns from file + --files-from Read list of source files from file +-R, --relative Use relative path names + --existing Skip creating new files on instance + --copy-dest Use files from dir as sync base if files are missing + --ssh-command Path and arguments of ssh command to use, e.g. + "C:\path\to\ssh.exe -p 12345 -i id_rsa -oUserKnownHostsFile=known_hosts" + Can also be specified by the CDC_SSH_COMMAND environment variable. + --scp-command Path and arguments of scp command to use, e.g. + "C:\path\to\scp.exe -P 12345 -i id_rsa -oUserKnownHostsFile=known_hosts" + Can also be specified by the CDC_SCP_COMMAND environment variable. + --forward-port TCP port or range used for SSH port forwarding (default: 44450-44459). + If a range is specified, searches for available ports (slower). +-h --help Help for cdc_rsync )"; constexpr char kSshCommandEnvVar[] = "CDC_SSH_COMMAND"; @@ -91,15 +94,20 @@ void PopulateFromEnvVars(Parameters* parameters) { .IgnoreError(); } +// Returns false and prints an error if |value| is null or empty. +bool ValidateValue(const std::string& option_name, const char* value) { + if (!value) { + PrintError("Option '%s' needs a value", option_name); + return false; + } + return true; +} + // Handles the --exclude-from and --include-from options. OptionResult HandleFilterRuleFile(const std::string& option_name, const char* path, PathFilter::Rule::Type type, Parameters* params) { - if (!path) { - PrintError("Option '%s' needs a value", option_name); - return OptionResult::kError; - } - + assert(path); std::vector patterns; absl::Status status = path::ReadAllLines( path, &patterns, @@ -188,29 +196,34 @@ OptionResult HandleParameter(const std::string& key, const char* value, } if (key == "include") { + if (!ValidateValue(key, value)) return OptionResult::kError; params->options.filter.AddRule(PathFilter::Rule::Type::kInclude, value); return OptionResult::kConsumedKeyValue; } if (key == "include-from") { + if (!ValidateValue(key, value)) return OptionResult::kError; return HandleFilterRuleFile(key, value, PathFilter::Rule::Type::kInclude, params); } if (key == "exclude") { + if (!ValidateValue(key, value)) return OptionResult::kError; params->options.filter.AddRule(PathFilter::Rule::Type::kExclude, value); return OptionResult::kConsumedKeyValue; } if (key == "exclude-from") { + if (!ValidateValue(key, value)) return OptionResult::kError; return HandleFilterRuleFile(key, value, PathFilter::Rule::Type::kExclude, params); } if (key == "files-from") { // Implies -R. + if (!ValidateValue(key, value)) return OptionResult::kError; params->options.relative = true; - params->files_from = value ? value : std::string(); + params->files_from = value; return OptionResult::kConsumedKeyValue; } @@ -225,16 +238,14 @@ OptionResult HandleParameter(const std::string& key, const char* value, } if (key == "compress-level") { - if (value) { - params->options.compress_level = atoi(value); - } + if (!ValidateValue(key, value)) return OptionResult::kError; + params->options.compress_level = atoi(value); return OptionResult::kConsumedKeyValue; } if (key == "contimeout") { - if (value) { - params->options.connection_timeout_sec = atoi(value); - } + if (!ValidateValue(key, value)) return OptionResult::kError; + params->options.connection_timeout_sec = atoi(value); return OptionResult::kConsumedKeyValue; } @@ -259,7 +270,8 @@ OptionResult HandleParameter(const std::string& key, const char* value, } if (key == "copy-dest") { - params->options.copy_dest = value ? value : std::string(); + if (!ValidateValue(key, value)) return OptionResult::kError; + params->options.copy_dest = value; return OptionResult::kConsumedKeyValue; } @@ -269,12 +281,27 @@ OptionResult HandleParameter(const std::string& key, const char* value, } if (key == "ssh-command") { - params->options.ssh_command = value ? value : std::string(); + if (!ValidateValue(key, value)) return OptionResult::kError; + params->options.ssh_command = value; return OptionResult::kConsumedKeyValue; } if (key == "scp-command") { - params->options.scp_command = value ? value : std::string(); + if (!ValidateValue(key, value)) return OptionResult::kError; + params->options.scp_command = value; + return OptionResult::kConsumedKeyValue; + } + + if (key == "forward-port") { + if (!ValidateValue(key, value)) return OptionResult::kError; + uint16_t first, last; + if (!port_range::Parse(value, &first, &last)) { + PrintError("Failed to parse %s=%s, expected or -", + key, value); + return OptionResult::kError; + } + params->options.forward_port_first = first; + params->options.forward_port_last = last; return OptionResult::kConsumedKeyValue; } @@ -355,11 +382,7 @@ bool CheckOptionResult(OptionResult result, const std::string& name, return true; case OptionResult::kConsumedKeyValue: - if (!value) { - PrintError("Option '%s' needs a value", name); - return false; - } - return true; + return ValidateValue(name, value); case OptionResult::kError: // Error message was already printed. diff --git a/cdc_rsync/params_test.cc b/cdc_rsync/params_test.cc index e038df0..24ae2fa 100644 --- a/cdc_rsync/params_test.cc +++ b/cdc_rsync/params_test.cc @@ -536,6 +536,40 @@ TEST_F(ParamsTest, IncludeExcludeMixed_ProperOrder) { ExpectNoError(); } +TEST_F(ParamsTest, ForwardPort_Single) { + const char* argv[] = {"cdc_rsync.exe", "--forward-port=65535", kSrc, + kUserHostDst, NULL}; + EXPECT_TRUE(Parse(static_cast(std::size(argv)) - 1, argv, ¶meters_)); + EXPECT_EQ(parameters_.options.forward_port_first, 65535); + EXPECT_EQ(parameters_.options.forward_port_last, 65535); + ExpectNoError(); +} + +TEST_F(ParamsTest, ForwardPort_Range) { + const char* argv[] = { + "cdc_rsync.exe", "--forward-port", "1-2", kSrc, kUserHostDst, NULL}; + EXPECT_TRUE(Parse(static_cast(std::size(argv)) - 1, argv, ¶meters_)); + EXPECT_EQ(parameters_.options.forward_port_first, 1); + EXPECT_EQ(parameters_.options.forward_port_last, 2); + ExpectNoError(); +} + +TEST_F(ParamsTest, ForwardPort_NoValue) { + const char* argv[] = {"cdc_rsync.exe", "--forward-port=", kSrc, kUserHostDst, + NULL}; + EXPECT_FALSE( + Parse(static_cast(std::size(argv)) - 1, argv, ¶meters_)); + ExpectError(NeedsValueError("forward-port")); +} + +TEST_F(ParamsTest, ForwardPort_BadValueTooSmall) { + const char* argv[] = {"cdc_rsync.exe", "--forward-port=0", kSrc, kUserHostDst, + NULL}; + EXPECT_FALSE( + Parse(static_cast(std::size(argv)) - 1, argv, ¶meters_)); + ExpectError("Failed to parse"); +} + } // namespace } // namespace params } // namespace cdc_ft diff --git a/cdc_stream/BUILD b/cdc_stream/BUILD index a12c381..c1e7fb6 100644 --- a/cdc_stream/BUILD +++ b/cdc_stream/BUILD @@ -24,6 +24,7 @@ cc_library( hdrs = ["base_command.h"], deps = [ "//absl_helper:jedec_size_flag", + "//common:port_range_parser", "@com_github_lyra//:lyra", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", diff --git a/cdc_stream/asset_stream_config.cc b/cdc_stream/asset_stream_config.cc index af38e95..41995df 100644 --- a/cdc_stream/asset_stream_config.cc +++ b/cdc_stream/asset_stream_config.cc @@ -20,6 +20,7 @@ #include "absl/strings/str_join.h" #include "absl_helper/jedec_size_flag.h" #include "cdc_stream/base_command.h" +#include "cdc_stream/multi_session.h" #include "cdc_stream/session_management_server.h" #include "common/buffer.h" #include "common/path.h" @@ -49,6 +50,20 @@ void AssetStreamConfig::RegisterCommandLineFlags(lyra::command& cmd, "asset stream service, default: " + std::to_string(service_port_))); + session_cfg_.forward_port_first = MultiSession::kDefaultForwardPortFirst; + session_cfg_.forward_port_last = MultiSession::kDefaultForwardPortLast; + cmd.add_argument( + lyra::opt(base_command.PortRangeParser("--forward-port", + &session_cfg_.forward_port_first, + &session_cfg_.forward_port_last), + "port") + .name("--forward-port") + .help("TCP port or range used for SSH port forwarding, default: " + + std::to_string(MultiSession::kDefaultForwardPortFirst) + "-" + + std::to_string(MultiSession::kDefaultForwardPortLast) + + ". If a range is specified, searches for available ports " + "(slower).")); + session_cfg_.verbosity = kDefaultVerbosity; cmd.add_argument(lyra::opt(session_cfg_.verbosity, "num") .name("--verbosity") @@ -175,6 +190,8 @@ absl::Status AssetStreamConfig::LoadFromFile(const std::string& path) { } while (0) ASSIGN_VAR(service_port_, "service-port", Int); + ASSIGN_VAR(session_cfg_.forward_port_first, "forward-port-first", Int); + ASSIGN_VAR(session_cfg_.forward_port_last, "forward-port-last", Int); ASSIGN_VAR(session_cfg_.verbosity, "verbosity", Int); ASSIGN_VAR(session_cfg_.fuse_debug, "debug", Bool); ASSIGN_VAR(session_cfg_.fuse_singlethreaded, "singlethreaded", Bool); @@ -214,6 +231,8 @@ absl::Status AssetStreamConfig::LoadFromFile(const std::string& path) { std::string AssetStreamConfig::ToString() { std::ostringstream ss; ss << "service-port = " << service_port_ << std::endl; + ss << "forward-port = " << session_cfg_.forward_port_first + << "-" << session_cfg_.forward_port_last << std::endl; ss << "verbosity = " << session_cfg_.verbosity << std::endl; ss << "debug = " << session_cfg_.fuse_debug diff --git a/cdc_stream/asset_stream_config.h b/cdc_stream/asset_stream_config.h index 3e24e07..059cd00 100644 --- a/cdc_stream/asset_stream_config.h +++ b/cdc_stream/asset_stream_config.h @@ -48,18 +48,21 @@ class AssetStreamConfig { // Loads a configuration from the JSON file at |path| and overrides any config // values that are set in this file. Sample json file: // { + // "service-port":44432 + // "forward-port-first":"44433" + // "forward-port-last":"44442" // "verbosity":3, // "debug":0, // "singlethreaded":0, // "stats":0, // "quiet":0, // "check":0, - // "log_to_stdout":0, - // "cache_capacity":"150G", - // "cleanup_timeout":300, - // "access_idle_timeout":5, - // "manifest_updater_threads":4, - // "file_change_wait_duration_ms":500 + // "log-to-stdout":0, + // "cache-capacity":"150G", + // "cleanup-timeout":300, + // "access-idle-timeout":5, + // "manifest-updater-threads":4, + // "file-change-wait-duration-ms":500 // } // Returns NotFoundError if the file does not exist. // Returns InvalidArgumentError if the file is not valid JSON. diff --git a/cdc_stream/base_command.cc b/cdc_stream/base_command.cc index 31808c8..6f0ff33 100644 --- a/cdc_stream/base_command.cc +++ b/cdc_stream/base_command.cc @@ -15,7 +15,9 @@ #include "cdc_stream/base_command.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" #include "absl_helper/jedec_size_flag.h" +#include "common/port_range_parser.h" #include "lyra/lyra.hpp" namespace cdc_ft { @@ -44,8 +46,7 @@ void BaseCommand::Register(lyra::cli& cli) { std::function BaseCommand::JedecParser( const char* flag_name, uint64_t* bytes) { - return [flag_name, bytes, - error = &jedec_parse_error_](const std::string& value) { + return [flag_name, bytes, error = &parse_error_](const std::string& value) { JedecSize size; if (AbslParseFlag(value, &size, error)) { *bytes = size.Size(); @@ -56,6 +57,18 @@ std::function BaseCommand::JedecParser( }; } +std::function BaseCommand::PortRangeParser( + const char* flag_name, uint16_t* first, uint16_t* last) { + return [flag_name, first, last, + error = &parse_error_](const std::string& value) { + if (!port_range::Parse(value.c_str(), first, last)) { + *error = absl::StrFormat( + "Failed to parse %s=%s, expected or -", + flag_name, value); + } + }; +} + std::function BaseCommand::PosArgValidator( std::string* str) { return [str, invalid_arg = &invalid_arg_](const std::string& value) { @@ -83,8 +96,8 @@ void BaseCommand::CommandHandler(const lyra::group& g) { return; } - if (!jedec_parse_error_.empty()) { - std::cerr << "Error: " << jedec_parse_error_ << std::endl; + if (!parse_error_.empty()) { + std::cerr << "Error: " << parse_error_ << std::endl; *exit_code_ = 1; return; } diff --git a/cdc_stream/base_command.h b/cdc_stream/base_command.h index 65118cd..f82d824 100644 --- a/cdc_stream/base_command.h +++ b/cdc_stream/base_command.h @@ -48,6 +48,13 @@ class BaseCommand { std::function JedecParser(const char* flag_name, uint64_t* bytes); + // Parser for single ports "123" or port ranges "123-234". Usage: + // lyra::opt(PortRangeParser("port-flag", &first, &last), "port")) + // Automatically reports a parse failure on error. + std::function PortRangeParser(const char* flag_name, + uint16_t* first, + uint16_t* last); + // Validator that should be used for all positional arguments. Lyra interprets // -u, --unknown_flag as positional argument. This validator makes sure that // a positional argument starting with - is reported as an error. Otherwise, @@ -82,9 +89,9 @@ class BaseCommand { // Extraneous positional args. Gets reported as error if present. std::string extra_positional_arg_; - // Errors from parsing JEDEC sizes. + // Errors from custom flag parsers, e.g. JEDEC sizes or port ranges. // Works around Lyra not accepting errors from parsers. - std::string jedec_parse_error_; + std::string parse_error_; }; } // namespace cdc_ft diff --git a/cdc_stream/cdc_stream.vcxproj b/cdc_stream/cdc_stream.vcxproj index 4da2203..fd63d41 100644 --- a/cdc_stream/cdc_stream.vcxproj +++ b/cdc_stream/cdc_stream.vcxproj @@ -47,7 +47,7 @@ /std:c++17 - $(SolutionDir)bazel-out\x64_windows-opt\bin\asset_stcdc_streamream_manager\ + $(SolutionDir)bazel-out\x64_windows-opt\bin\cdc_stream\ UNICODE /std:c++17 diff --git a/cdc_stream/multi_session.cc b/cdc_stream/multi_session.cc index 95d1b7a..bf74c0b 100644 --- a/cdc_stream/multi_session.cc +++ b/cdc_stream/multi_session.cc @@ -34,11 +34,6 @@ namespace cdc_ft { namespace { -// Ports used by the asset streaming service for local port forwarding on -// workstation and gamelet. -constexpr int kAssetStreamPortFirst = 44433; -constexpr int kAssetStreamPortLast = 44442; - // Stats output period (if enabled). constexpr double kStatsPrintDelaySec = 0.1f; @@ -441,16 +436,19 @@ absl::Status MultiSession::Initialize() { } // Find an available local port. - std::unordered_set ports; - ASSIGN_OR_RETURN( - ports, - PortManager::FindAvailableLocalPorts(kAssetStreamPortFirst, - kAssetStreamPortLast, "127.0.0.1", - process_factory_), - "Failed to find an available local port in the range [%d, %d]", - kAssetStreamPortFirst, kAssetStreamPortLast); - assert(!ports.empty()); - local_asset_stream_port_ = *ports.begin(); + local_asset_stream_port_ = cfg_.forward_port_first; + if (cfg_.forward_port_first < cfg_.forward_port_last) { + std::unordered_set ports; + ASSIGN_OR_RETURN( + ports, + PortManager::FindAvailableLocalPorts(cfg_.forward_port_first, + cfg_.forward_port_last, + "127.0.0.1", process_factory_), + "Failed to find an available local port in the range [%d, %d]", + cfg_.forward_port_first, cfg_.forward_port_last); + assert(!ports.empty()); + local_asset_stream_port_ = *ports.begin(); + } assert(!runner_); runner_ = std::make_unique( @@ -526,7 +524,8 @@ absl::Status MultiSession::StartSession(const std::string& instance_id, auto session = std::make_unique( instance_id, target, cfg_, process_factory_, std::move(metrics_recorder)); RETURN_IF_ERROR(session->Start(local_asset_stream_port_, - kAssetStreamPortFirst, kAssetStreamPortLast)); + cfg_.forward_port_first, + cfg_.forward_port_last)); // Wait for the FUSE to receive the first intermediate manifest. RETURN_IF_ERROR(runner_->WaitForManifestAck(instance_id, absl::Seconds(5))); diff --git a/cdc_stream/multi_session.h b/cdc_stream/multi_session.h index 189bc26..df1429e 100644 --- a/cdc_stream/multi_session.h +++ b/cdc_stream/multi_session.h @@ -134,6 +134,11 @@ class MultiSessionRunner { // to an arbitrary number of gamelets. class MultiSession { public: + // Ports used by the asset streaming service for local port forwarding on + // workstation and gamelet. + static constexpr int kDefaultForwardPortFirst = 44433; + static constexpr int kDefaultForwardPortLast = 44442; + // Maximum length of cache path. We must be able to write content hashes into // this path: // \01234567890123456789 = 260 characters. diff --git a/cdc_stream/session.cc b/cdc_stream/session.cc index 835ab55..e6c0cd9 100644 --- a/cdc_stream/session.cc +++ b/cdc_stream/session.cc @@ -71,16 +71,19 @@ Session::~Session() { absl::Status Session::Start(int local_port, int first_remote_port, int last_remote_port) { // Find an available remote port. - std::unordered_set ports; - ASSIGN_OR_RETURN( - ports, - PortManager::FindAvailableRemotePorts( - first_remote_port, last_remote_port, "127.0.0.1", process_factory_, - &remote_util_, kInstanceConnectionTimeoutSec), - "Failed to find an available remote port in the range [%d, %d]", - first_remote_port, last_remote_port); - assert(!ports.empty()); - int remote_port = *ports.begin(); + int remote_port = first_remote_port; + if (first_remote_port < last_remote_port) { + std::unordered_set ports; + ASSIGN_OR_RETURN( + ports, + PortManager::FindAvailableRemotePorts( + first_remote_port, last_remote_port, "127.0.0.1", process_factory_, + &remote_util_, kInstanceConnectionTimeoutSec), + "Failed to find an available remote port in the range [%d, %d]", + first_remote_port, last_remote_port); + assert(!ports.empty()); + remote_port = *ports.begin(); + } assert(!fuse_); fuse_ = std::make_unique(instance_id_, process_factory_, diff --git a/cdc_stream/session_config.h b/cdc_stream/session_config.h index 89e36c3..36252ed 100644 --- a/cdc_stream/session_config.h +++ b/cdc_stream/session_config.h @@ -56,6 +56,10 @@ struct SessionConfig { // Time to wait until running a manifest update after detecting a file change. uint32_t file_change_wait_duration_ms = 0; + + // Ports used for local port forwarding. + uint16_t forward_port_first = 0; + uint16_t forward_port_last = 0; }; } // namespace cdc_ft diff --git a/cdc_stream/start_command.cc b/cdc_stream/start_command.cc index 393ccd3..0fdd19c 100644 --- a/cdc_stream/start_command.cc +++ b/cdc_stream/start_command.cc @@ -22,7 +22,6 @@ #include "common/log.h" #include "common/path.h" #include "common/process.h" -#include "common/remote_util.h" #include "common/status_macros.h" #include "common/stopwatch.h" #include "common/util.h" diff --git a/common/BUILD b/common/BUILD index d5c45c1..db40b20 100644 --- a/common/BUILD +++ b/common/BUILD @@ -254,6 +254,25 @@ cc_test( ], ) +cc_library( + name = "port_range_parser", + srcs = ["port_range_parser.cc"], + hdrs = ["port_range_parser.h"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "port_range_parser_test", + srcs = ["port_range_parser_test.cc"], + deps = [ + ":port_range_parser", + ":test_main", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "process", srcs = ["process_win.cc"], diff --git a/common/port_manager.h b/common/port_manager.h index 49e46d8..59ce4f4 100644 --- a/common/port_manager.h +++ b/common/port_manager.h @@ -51,14 +51,11 @@ class PortManager { // Reserves a port in the range passed to the constructor. The port is // released automatically upon destruction if ReleasePort() is not called // explicitly. - // |check_remote| determines whether the remote port should be checked as - // well. If false, the check is skipped and a port might be returned that is - // still in use remotely. // |remote_timeout_sec| is the timeout for finding available ports on the - // remote instance. Not used if |check_remote| is false. + // remote instance. // Returns a DeadlineExceeded error if the timeout is exceeded. // Returns a ResourceExhausted error if no ports are available. - absl::StatusOr ReservePort(bool check_remote, int remote_timeout_sec); + absl::StatusOr ReservePort(int remote_timeout_sec); // Releases a reserved port. absl::Status ReleasePort(int port); diff --git a/common/port_manager_test.cc b/common/port_manager_test.cc index 72451c9..be2c9f5 100644 --- a/common/port_manager_test.cc +++ b/common/port_manager_test.cc @@ -38,9 +38,6 @@ constexpr int kTimeoutSec = 1; constexpr char kLocalNetstat[] = "netstat -a -n -p tcp"; constexpr char kRemoteNetstat[] = "netstat --numeric --listening --tcp"; -constexpr bool kCheckRemote = true; -constexpr bool kNoCheckRemote = false; - constexpr char kLocalNetstatOutFmt[] = "TCP 127.0.0.1:50000 127.0.0.1:%i ESTABLISHED"; constexpr char kRemoteNetstatOutFmt[] = @@ -73,16 +70,7 @@ TEST_F(PortManagerTest, ReservePortSuccess) { process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); - absl::StatusOr port = - port_manager_.ReservePort(kCheckRemote, kTimeoutSec); - ASSERT_OK(port); - EXPECT_EQ(*port, kFirstPort); -} - -TEST_F(PortManagerTest, ReservePortNoRemoteSuccess) { - process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); - - absl::StatusOr port = port_manager_.ReservePort(kNoCheckRemote, 0); + absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); ASSERT_OK(port); EXPECT_EQ(*port, kFirstPort); } @@ -95,8 +83,7 @@ TEST_F(PortManagerTest, ReservePortAllLocalPortsTaken) { process_factory_.SetProcessOutput(kLocalNetstat, local_netstat_out, "", 0); process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); - absl::StatusOr port = - port_manager_.ReservePort(kCheckRemote, kTimeoutSec); + absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); EXPECT_TRUE(absl::IsResourceExhausted(port.status())); EXPECT_TRUE( absl::StrContains(port.status().message(), "No port available in range")); @@ -110,8 +97,7 @@ TEST_F(PortManagerTest, ReservePortAllRemotePortsTaken) { process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); process_factory_.SetProcessOutput(kRemoteNetstat, remote_netstat_out, "", 0); - absl::StatusOr port = - port_manager_.ReservePort(kCheckRemote, kTimeoutSec); + absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); EXPECT_TRUE(absl::IsResourceExhausted(port.status())); EXPECT_TRUE( absl::StrContains(port.status().message(), "No port available in range")); @@ -121,8 +107,7 @@ TEST_F(PortManagerTest, ReservePortLocalNetstatFails) { process_factory_.SetProcessOutput(kLocalNetstat, "", "", 1); process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); - absl::StatusOr port = - port_manager_.ReservePort(kCheckRemote, kTimeoutSec); + absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); EXPECT_NOT_OK(port); EXPECT_TRUE( absl::StrContains(port.status().message(), @@ -133,8 +118,7 @@ TEST_F(PortManagerTest, ReservePortRemoteNetstatFails) { process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 1); - absl::StatusOr port = - port_manager_.ReservePort(kCheckRemote, kTimeoutSec); + absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); EXPECT_NOT_OK(port); EXPECT_TRUE(absl::StrContains(port.status().message(), "Failed to find available ports on instance")); @@ -145,8 +129,7 @@ TEST_F(PortManagerTest, ReservePortRemoteNetstatTimesOut) { process_factory_.SetProcessNeverExits(kRemoteNetstat); steady_clock_.AutoAdvance(kTimeoutSec * 2 * 1000); - absl::StatusOr port = - port_manager_.ReservePort(kCheckRemote, kTimeoutSec); + absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); EXPECT_NOT_OK(port); EXPECT_TRUE(absl::IsDeadlineExceeded(port.status())); EXPECT_TRUE(absl::StrContains(port.status().message(), @@ -163,14 +146,10 @@ TEST_F(PortManagerTest, ReservePortMultipleInstances) { // Port managers use shared memory, so different instances know about each // other. This would even work if |port_manager_| and |port_manager2| belonged // to different processes, but we don't test that here. - EXPECT_EQ(*port_manager_.ReservePort(kCheckRemote, kTimeoutSec), - kFirstPort + 0); - EXPECT_EQ(*port_manager2.ReservePort(kCheckRemote, kTimeoutSec), - kFirstPort + 1); - EXPECT_EQ(*port_manager_.ReservePort(kCheckRemote, kTimeoutSec), - kFirstPort + 2); - EXPECT_EQ(*port_manager2.ReservePort(kCheckRemote, kTimeoutSec), - kFirstPort + 3); + EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec), kFirstPort + 0); + EXPECT_EQ(*port_manager2.ReservePort(kTimeoutSec), kFirstPort + 1); + EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec), kFirstPort + 2); + EXPECT_EQ(*port_manager2.ReservePort(kTimeoutSec), kFirstPort + 3); } TEST_F(PortManagerTest, ReservePortReusesPortsInLRUOrder) { @@ -178,7 +157,7 @@ TEST_F(PortManagerTest, ReservePortReusesPortsInLRUOrder) { process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); for (int n = 0; n < kNumPorts * 2; ++n) { - EXPECT_EQ(*port_manager_.ReservePort(kCheckRemote, kTimeoutSec), + EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec), kFirstPort + n % kNumPorts); system_clock_.Advance(1000); } @@ -188,11 +167,10 @@ TEST_F(PortManagerTest, ReleasePort) { process_factory_.SetProcessOutput(kLocalNetstat, "", "", 0); process_factory_.SetProcessOutput(kRemoteNetstat, "", "", 0); - absl::StatusOr port = - port_manager_.ReservePort(kCheckRemote, kTimeoutSec); + absl::StatusOr port = port_manager_.ReservePort(kTimeoutSec); EXPECT_EQ(*port, kFirstPort); EXPECT_OK(port_manager_.ReleasePort(*port)); - port = port_manager_.ReservePort(kCheckRemote, kTimeoutSec); + port = port_manager_.ReservePort(kTimeoutSec); EXPECT_EQ(*port, kFirstPort); } @@ -202,13 +180,10 @@ TEST_F(PortManagerTest, ReleasePortOnDestruction) { auto port_manager2 = std::make_unique( kGuid, kFirstPort, kLastPort, &process_factory_, &remote_util_); - EXPECT_EQ(*port_manager2->ReservePort(kCheckRemote, kTimeoutSec), - kFirstPort + 0); - EXPECT_EQ(*port_manager_.ReservePort(kCheckRemote, kTimeoutSec), - kFirstPort + 1); + EXPECT_EQ(*port_manager2->ReservePort(kTimeoutSec), kFirstPort + 0); + EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec), kFirstPort + 1); port_manager2.reset(); - EXPECT_EQ(*port_manager_.ReservePort(kCheckRemote, kTimeoutSec), - kFirstPort + 0); + EXPECT_EQ(*port_manager_.ReservePort(kTimeoutSec), kFirstPort + 0); } TEST_F(PortManagerTest, FindAvailableLocalPortsSuccess) { diff --git a/common/port_manager_win.cc b/common/port_manager_win.cc index 2c6c586..4e47755 100644 --- a/common/port_manager_win.cc +++ b/common/port_manager_win.cc @@ -121,8 +121,7 @@ PortManager::~PortManager() { } } -absl::StatusOr PortManager::ReservePort(bool check_remote, - int remote_timeout_sec) { +absl::StatusOr PortManager::ReservePort(int remote_timeout_sec) { // Find available port on workstation. std::unordered_set local_ports; ASSIGN_OR_RETURN(local_ports, @@ -132,13 +131,11 @@ absl::StatusOr PortManager::ReservePort(bool check_remote, // Find available port on remote instance. std::unordered_set remote_ports = local_ports; - if (check_remote) { - ASSIGN_OR_RETURN(remote_ports, - FindAvailableRemotePorts( - first_port_, last_port_, "0.0.0.0", process_factory_, - remote_util_, remote_timeout_sec, steady_clock_), - "Failed to find available ports on instance"); - } + ASSIGN_OR_RETURN(remote_ports, + FindAvailableRemotePorts(first_port_, last_port_, "0.0.0.0", + process_factory_, remote_util_, + remote_timeout_sec, steady_clock_), + "Failed to find available ports on instance"); // Fetch shared memory. void* mem; @@ -290,9 +287,14 @@ absl::StatusOr> PortManager::FindAvailablePorts( int first_port, int last_port, const std::string& netstat_output, const char* ip) { std::unordered_set available_ports; - for (int port = first_port; port <= last_port; ++port) { - std::vector lines = absl::StrSplit(netstat_output, '\n'); + std::vector lines; + for (const auto& line : absl::StrSplit(netstat_output, '\n')) { + if (absl::StrContains(line, ip)) { + lines.push_back(std::string(line)); + } + } + for (int port = first_port; port <= last_port; ++port) { bool port_occupied = false; std::string portToken = absl::StrFormat("%s:%i", ip, port); for (const std::string& line : lines) { diff --git a/common/port_range_parser.cc b/common/port_range_parser.cc new file mode 100644 index 0000000..2e5c34e --- /dev/null +++ b/common/port_range_parser.cc @@ -0,0 +1,40 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/port_range_parser.h" + +#include + +#include "absl/strings/str_split.h" + +namespace cdc_ft { +namespace port_range { + +bool Parse(const char* value, uint16_t* first, uint16_t* last) { + assert(value); + *first = 0; + *last = 0; + std::vector parts = absl::StrSplit(value, '-'); + if (parts.empty() || parts.size() > 2) return false; + const int ifirst = atoi(parts[0].c_str()); + const int ilast = parts.size() > 1 ? atoi(parts[1].c_str()) : ifirst; + if (ifirst <= 0 || ifirst > UINT16_MAX) return false; + if (ilast <= 0 || ilast > UINT16_MAX || ifirst > ilast) return false; + *first = static_cast(ifirst); + *last = static_cast(ilast); + return true; +} + +} // namespace port_range +} // namespace cdc_ft diff --git a/common/port_range_parser.h b/common/port_range_parser.h new file mode 100644 index 0000000..3e59fb1 --- /dev/null +++ b/common/port_range_parser.h @@ -0,0 +1,33 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_PORT_RANGE_PARSER_H_ +#define COMMON_PORT_RANGE_PARSER_H_ + +#include + +namespace cdc_ft { +namespace port_range { + +// Parses |value| into a port range |first|-|last|. +// If |value| is a single number a, assigns |first|=|last|=a. +// If |value| is a range a-b, assigns |first|=a, |last|=b. +bool Parse(const char* value, uint16_t* first, uint16_t* last); + +} // namespace port_range +} // namespace cdc_ft + +#endif // COMMON_PORT_RANGE_PARSER_H_ diff --git a/common/port_range_parser_test.cc b/common/port_range_parser_test.cc new file mode 100644 index 0000000..0a8c247 --- /dev/null +++ b/common/port_range_parser_test.cc @@ -0,0 +1,69 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/port_range_parser.h" + +#include "gtest/gtest.h" + +namespace cdc_ft { +namespace { + +TEST(PortRangeParserTest, SingleSuccess) { + uint16_t first, last; + EXPECT_TRUE(port_range::Parse("65535", &first, &last)); + EXPECT_EQ(first, 65535); + EXPECT_EQ(last, 65535); +} + +TEST(PortRangeParserTest, RangeSuccess) { + uint16_t first, last; + EXPECT_TRUE(port_range::Parse("1-2", &first, &last)); + EXPECT_EQ(first, 1); + EXPECT_EQ(last, 2); +} + +TEST(ParamsTest, NoValueFail) { + uint16_t first = 1, last = 1; + EXPECT_FALSE(port_range::Parse("", &first, &last)); + EXPECT_EQ(first, 0); + EXPECT_EQ(last, 0); +} + +TEST(ParamsTest, BadValueTooSmallFail) { + uint16_t first, last; + EXPECT_FALSE(port_range::Parse("0", &first, &last)); +} + +TEST(ParamsTest, BadValueNotIntegerFail) { + uint16_t first, last; + EXPECT_FALSE(port_range::Parse("port", &first, &last)); +} + +TEST(ParamsTest, ForwardPort_BadRangeTooBig) { + uint16_t first, last; + EXPECT_FALSE(port_range::Parse("50000-65536", &first, &last)); +} + +TEST(ParamsTest, ForwardPort_BadRangeFirstGtLast) { + uint16_t first, last; + EXPECT_FALSE(port_range::Parse("50001-50000", &first, &last)); +} + +TEST(ParamsTest, ForwardPort_BadRangeTwoMinus) { + uint16_t first, last; + EXPECT_FALSE(port_range::Parse("1-2-3", &first, &last)); +} + +} // namespace +} // namespace cdc_ft diff --git a/tests_common/BUILD b/tests_common/BUILD index 90d97a4..3905508 100644 --- a/tests_common/BUILD +++ b/tests_common/BUILD @@ -30,6 +30,7 @@ cc_binary( "//common:path_filter", "//common:platform", "//common:port_manager", + "//common:port_range_parser", "//common:process", "//common:remote_util", "//common:sdk_util",