Files
Lutz Justen 370023a944 [cdc_stream] Use ephemeral ports (#100)
Instead of running netstat/ss on local and remote systems, just bind
with port 0 to find an ephemeral port. This is much more robust,
simpler and a bit faster. Since the remote port is only known after
running cdc_fuse_fs, port forwarding has to be set up after running
cdc_fuse_fs.
2023-06-23 11:19:36 +02:00

267 lines
9.3 KiB
Python

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""cdc_stream test."""
import datetime
import logging
import os
import posixpath
import tempfile
import time
import subprocess
import unittest
from integration_tests.framework import utils
from integration_tests.framework import test_base
class CdcStreamTest(unittest.TestCase):
"""cdc_stream test class."""
# Grpc status codes.
NOT_FOUND = 5
SERVICE_UNAVAILABLE = 14
tmp_dir = None
local_base_dir = None
remote_base_dir = '/tmp/_cdc_stream_test/'
cache_dir = '~/.cache/cdc-file-transfer/chunks/'
service_port_arg = None
service_running = False
# Returns a list of files and directories with mtimes in the remote directory.
# For example, 2021-10-13 07:49:25.512766391 -0700 /tmp/_cdc_stream_test/2.txt
ls_cmd = ('find %s -exec stat --format \"%%y %%n\" \"{}\" \\;') % (
remote_base_dir)
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
logging.debug('CdcStreamTest -> setUpClass')
utils.initialize(None, test_base.Flags.binary_path,
test_base.Flags.user_host)
cls.service_port_arg = f'--service-port={test_base.Flags.service_port}'
cls._stop_service()
with tempfile.NamedTemporaryFile() as tf:
cls.config_path = tf.name
@classmethod
def tearDownClass(cls):
logging.debug('CdcStreamTest -> tearDownClass')
cls._stop_service()
if os.path.exists(cls.config_path):
os.remove(cls.config_path)
def setUp(self):
"""Stops the service, cleans up cache and streamed directory and initializes random."""
super(CdcStreamTest, self).setUp()
logging.debug('CdcStreamTest -> setUp')
now_str = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
self.tmp_dir = tempfile.TemporaryDirectory(
prefix=f'_cdc_stream_test_{now_str}')
self.local_base_dir = self.tmp_dir.name + '\\base\\'
utils.create_test_directory(self.local_base_dir)
logging.info('Local base dir: "%s"', self.local_base_dir)
logging.info('Remote base dir: "%s"', self.remote_base_dir)
utils.initialize_random()
self._stop(ignore_not_found=True)
self._clean_cache()
def tearDown(self):
super(CdcStreamTest, self).tearDown()
logging.debug('CdcStreamTest -> tearDown')
self.tmp_dir.cleanup()
@classmethod
def _start_service(cls, config_json=None):
"""Starts the asset streaming service.
Args:
config_json (string, optional): Config JSON string. Defaults to None.
"""
config_arg = None
if config_json:
with open(cls.config_path, 'wt') as file:
file.write(config_json)
config_arg = f'--config-file={cls.config_path}'
# Note: Service must be spawned in a background process.
args = ['start-service', config_arg, cls.service_port_arg]
command = [utils.CDC_STREAM_PATH, *filter(None, args)]
# Workaround issue with unicode logging.
logging.debug(
'Executing %s ',
' '.join(command).encode('utf-8').decode('ascii', 'backslashreplace'))
subprocess.Popen(command)
cls.service_running = True
@classmethod
def _stop_service(cls):
res = utils.run_stream('stop-service', cls.service_port_arg)
if res.returncode != 0:
logging.warn(f'Stopping service failed: {res}')
cls.service_running = False
def _start(self, local_dir=None):
"""Starts streaming the given directory
Args:
local_dir (string): Directory to stream. Defaults to local_base_dir.
"""
res = utils.run_stream('start', local_dir or self.local_base_dir,
utils.target(self.remote_base_dir),
self.service_port_arg)
self._assert_stream_success(res)
CdcStreamTest.service_running = True
def _stop(self, ignore_not_found=False):
"""Stops streaming to the target
Args:
ignore_not_found (bool): True to ignore if there's nothing to stop.
"""
if not CdcStreamTest.service_running:
return
res = utils.run_stream('stop', '*', self.service_port_arg)
if ignore_not_found and res.returncode == self.NOT_FOUND:
return
self._assert_stream_success(res)
def _assert_stream_success(self, res):
"""Asserts if the return code is 0 and outputs return message with args."""
self.assertEqual(res.returncode, 0, 'Return value is ' + str(res))
def _assert_remote_dir_matches(self, file_list):
"""Asserts that the remote directory matches the list of files and directories.
Args:
file_list (list of strings): List of relative paths to check.
"""
found = utils.get_sorted_files(self.remote_base_dir, '"*"')
expected = sorted(['./' + f for f in file_list])
self.assertListEqual(found, expected)
def _get_cache_size_in_bytes(self):
"""Returns the asset streaming cache size in bytes.
Returns:
bool: Cache size in bytes.
"""
result = utils.get_ssh_command_output('du -sb %s | awk \'{print $1}\'' %
self.cache_dir)
logging.info(f'Cache capacity is {int(result)}')
return int(result)
def _assert_cache(self):
"""Asserts that the asset streaming cache contains some data."""
cache_size = self._get_cache_size_in_bytes()
# On Linux, an empty directory occupies 4KB.
self.assertTrue(int(cache_size) >= 4096)
self.assertGreater(
int(utils.get_ssh_command_output('ls %s | wc -l' % self.cache_dir)), 0)
def _assert_cdc_fuse_mounted(self, success=True):
"""Asserts that CDC FUSE is appropriately mounted."""
logging.info(f'Asserting that FUSE is {"" if success else "not "}mounted')
result = utils.get_ssh_command_output('cat /etc/mtab | grep fuse')
if success:
self.assertIn(f'{self.remote_base_dir[:-1]} fuse.', result)
else:
self.assertNotIn(f'{self.remote_base_dir[:-1]} fuse.', result)
def _clean_cache(self):
"""Removes all data from the asset streaming caches."""
logging.info(f'Clearing cache')
utils.get_ssh_command_output('rm -rf %s' %
posixpath.join(self.cache_dir, '*'))
cache_dir = os.path.join(os.environ['APPDATA'], 'cdc-file-transfer',
'chunks')
utils.remove_test_directory(cache_dir)
def _create_test_data(self, files, dirs):
"""Create test data locally.
Args:
files (list of strings): List of relative file paths to create.
dirs (list of strings): List of relative dir paths to create.
"""
logging.info(
f'Creating test data with {len(files)} files and {len(dirs)} dirs')
for directory in dirs:
utils.create_test_directory(os.path.join(self.local_base_dir, directory))
for file in files:
utils.create_test_file(os.path.join(self.local_base_dir, file), 1024)
def _wait_until_remote_dir_changed(self, original, counter=20):
"""Wait until the directory content has changed.
Args:
original (string): The original file list of the remote directory.
counter (int): The number of retries.
Returns:
bool: Whether the content of the remote directory has changed.
"""
logging.info(f'Waiting until remote dir changes')
for _ in range(counter):
if utils.get_ssh_command_output(self.ls_cmd) != original:
return True
time.sleep(0.1)
logging.info(f'Still waiting...')
return False
def _test_dir_content(self, files, dirs, is_exe=False):
"""Check the streamed directory's content on remote instance.
Args:
files (list of strings): List of relative file paths to check.
dirs (list of strings): List of relative dir paths to check.
is_exe (bool): Flag which identifies whether files are executables.
"""
logging.info(
f'Testing dir content with {len(files)} files and {len(dirs)} dirs')
dirs = [directory.replace('\\', '/').rstrip('/') for directory in dirs]
files = [file.replace('\\', '/') for file in files]
# Read the content of the directory once to load some data.
utils.get_ssh_command_output('ls -al %s' % self.remote_base_dir)
self._assert_remote_dir_matches(files + dirs)
if not dirs and not files:
return
file_list = list()
mapping = dict()
for file in files:
full_name = posixpath.join(self.remote_base_dir, file)
self.assertTrue(
utils.sha1_matches(
os.path.join(self.local_base_dir, file), full_name))
file_list.append(full_name)
if is_exe:
mapping[full_name] = '-rwxr-xr-x'
else:
mapping[full_name] = '-rw-r--r--'
for directory in dirs:
full_name = posixpath.join(self.remote_base_dir, directory)
file_list.append(full_name)
mapping[full_name] = 'drwxr-xr-x'
ls_res = utils.get_ssh_command_output('ls -ld %s' % ' '.join(file_list))
for line in ls_res.splitlines():
self.assertIn(mapping[list(filter(None, line.split(' ')))[8]], line)