mirror of
https://github.com/nestriness/cdc-file-transfer.git
synced 2026-05-02 13:03:07 +03:00
[cdc_rsync] Add integration tests (#42)
[cdc_rsync] Add integration tests This CL adds Python integration tests for cdc_rsync. To run the tests, you need to supply a Linux host and proper configuration for cdc_rsync to work: set CDC_SSH_COMMAND=C:\path\to\ssh.exe <args> set CDC_SCP_COMMAND=C:\path\to\scp.exe <args> C:\python38\python.exe -m integration_tests.cdc_rsync.all_tests --binary_path=C:\full\path\to\cdc_rsync.exe --user_host=user@host Ran the tests and made sure they worked.
This commit is contained in:
13
integration_tests/framework/__init__.py
Normal file
13
integration_tests/framework/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
63
integration_tests/framework/test_base.py
Normal file
63
integration_tests/framework/test_base.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Test main and flags."""
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import logging
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from integration_tests.framework import test_runner
|
||||
|
||||
|
||||
class Flags(object):
|
||||
binary_path = None
|
||||
user_host = None
|
||||
ssh_port = 22
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='End-to-end integration test.')
|
||||
parser.add_argument('--binary_path', help='Target [user@]host', required=True)
|
||||
parser.add_argument('--user_host', help='Target [user@]host', required=True)
|
||||
parser.add_argument(
|
||||
'--ssh_port',
|
||||
type=int,
|
||||
help='SSH port for connecting to the host',
|
||||
default=22)
|
||||
parser.add_argument('--log_file', help='Log file path')
|
||||
|
||||
# Capture all remaining arguments to pass to unittest.main().
|
||||
args, unittest_args = parser.parse_known_args()
|
||||
Flags.binary_path = args.binary_path
|
||||
Flags.user_host = args.user_host
|
||||
Flags.ssh_port = args.ssh_port
|
||||
|
||||
# Log to STDERR
|
||||
log_format = ('%(levelname)-8s%(asctime)s '
|
||||
'%(filename)s:%(lineno)-3d %(message)s')
|
||||
log_stream = sys.stderr
|
||||
|
||||
if args.log_file:
|
||||
log_stream = open(args.log_file, 'w')
|
||||
|
||||
with log_stream:
|
||||
logging.basicConfig(
|
||||
format=log_format, level=logging.DEBUG, stream=log_stream)
|
||||
|
||||
unittest.main(
|
||||
argv=sys.argv[:1] + unittest_args, testRunner=test_runner.TestRunner())
|
||||
66
integration_tests/framework/test_runner.py
Normal file
66
integration_tests/framework/test_runner.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Test runner, adds some sugar around logs to make them easier to read."""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
|
||||
class TestRunner(object):
|
||||
"""Runner producing test xml output."""
|
||||
|
||||
def run(self, test): # pylint: disable=invalid-name
|
||||
result = TestResult()
|
||||
logging.info('Running tests...')
|
||||
test(result)
|
||||
logging.info('\n\n******************* TESTS FINISHED *******************\n')
|
||||
logging.info('Ran %d tests with %d errors and %d failures', result.testsRun,
|
||||
len(result.errors), len(result.failures))
|
||||
for test_and_stack in result.failures:
|
||||
logging.info('\n\n[ TEST FAILED ] %s\n', test_and_stack[0])
|
||||
logging.info(
|
||||
'%s', test_and_stack[1].replace('\\\\r',
|
||||
'\r').replace('\\\\n', '\n').replace(
|
||||
'\\r', '\r').replace('\\n', '\n'))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TestResult(unittest.TestResult):
|
||||
|
||||
def startTest(self, test):
|
||||
"""Called when the given test is about to be run."""
|
||||
logging.info('\n\n===== BEGIN TEST CASE: %s =====\n', test)
|
||||
unittest.TestResult.startTest(self, test)
|
||||
|
||||
def stopTest(self, test):
|
||||
"""Called when the given test has been run."""
|
||||
unittest.TestResult.stopTest(self, test)
|
||||
logging.info('\n\n===== END TEST CASE: %s =====\n', test)
|
||||
|
||||
def addError(self, test, err):
|
||||
unittest.TestResult.addError(self, test, err)
|
||||
self._LogFailureInfo(err)
|
||||
|
||||
def addFailure(self, test, err):
|
||||
unittest.TestResult.addFailure(self, test, err)
|
||||
self._LogFailureInfo(err)
|
||||
|
||||
def _LogFailureInfo(self, err):
|
||||
exctype, exc, tb = err
|
||||
detail = ''.join(traceback.format_exception(exctype, exc, tb))
|
||||
logging.error('FAILURE: %s', detail)
|
||||
331
integration_tests/framework/utils.py
Normal file
331
integration_tests/framework/utils.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Utils for file transfer tests."""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import random
|
||||
import shutil
|
||||
import string
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
CDC_RSYNC_PATH = None
|
||||
USER_HOST = None
|
||||
|
||||
SHA1_LEN = 40
|
||||
SHA1_BUF_SIZE = 65536
|
||||
RANDOM = random.Random()
|
||||
|
||||
|
||||
def initialize(cdc_rsync_path, user_host):
|
||||
"""Sets global variables."""
|
||||
global CDC_RSYNC_PATH, USER_HOST
|
||||
|
||||
CDC_RSYNC_PATH = cdc_rsync_path
|
||||
USER_HOST = user_host
|
||||
|
||||
|
||||
def initialize_random():
|
||||
"""Sets random seed."""
|
||||
global RANDOM
|
||||
seed = int(time.time())
|
||||
logging.debug('Use random seed %i', seed)
|
||||
RANDOM.seed(seed)
|
||||
|
||||
|
||||
def _remove_carriage_return_lines(text):
|
||||
r"""Removes *\r, keeps only *\r\n lines.
|
||||
|
||||
Args:
|
||||
text (string): Text to remove lines from (usually cdc_rsync output).
|
||||
|
||||
Returns:
|
||||
string: Text with lines removed.
|
||||
"""
|
||||
|
||||
# Some lines have \r\r\n, treat them properly.
|
||||
ret = ''
|
||||
for line in text.replace('\r\r', '\r').split('\r\n'):
|
||||
ret += line.split('\r')[-1] + '\r\n'
|
||||
return ret
|
||||
|
||||
|
||||
def run_rsync(*args):
|
||||
"""Runs cdc_rsync with given args.
|
||||
|
||||
The last positional argument is assumed to be the destination. The user/host
|
||||
prefix [user@]host: is optional. If it does not have one, then it is prefixed
|
||||
by |USER_HOST|:.
|
||||
|
||||
Args:
|
||||
*args (string): cdc_rsync arguments.
|
||||
|
||||
Returns:
|
||||
CompletedProcess: cdc_rsync process info with exit code and stdout/stderr.
|
||||
"""
|
||||
|
||||
# Prefix last positional argument with [user@]host: if it doesn't have such
|
||||
# a prefix yet. Note that this won't work in all cases, e.g. if
|
||||
# '--exclude', 'file' is passed. Use '--exclude=file' instead.
|
||||
args_list = list(filter(None, args))
|
||||
for n in range(len(args_list) - 1, 0, -1):
|
||||
if args_list[n][0] != '-' and not ':' in args_list[n]:
|
||||
args_list[n] = USER_HOST + ":" + args_list[n]
|
||||
break
|
||||
|
||||
command = [CDC_RSYNC_PATH, *args_list]
|
||||
|
||||
# Workaround issue with unicode logging.
|
||||
logging.debug(
|
||||
'Executing %s ',
|
||||
' '.join(command).encode('utf-8').decode('ascii', 'backslashreplace'))
|
||||
res = subprocess.run(command, capture_output=True)
|
||||
# Remove lines ending with \r since those are temp display lines.
|
||||
res.stdout = _remove_carriage_return_lines(res.stdout.decode('ascii'))
|
||||
if res.stdout.strip():
|
||||
logging.debug('\r\n%s', res.stdout)
|
||||
return res
|
||||
|
||||
|
||||
def files_count_is(cdc_rsync_res,
|
||||
missing=0,
|
||||
missing_dir=0,
|
||||
changed=0,
|
||||
matching=0,
|
||||
matching_dir=0,
|
||||
extraneous=0,
|
||||
extraneous_dir=0):
|
||||
r"""Verifies that the output of cdc_rsync indicates the given file counts.
|
||||
|
||||
Args:
|
||||
cdc_rsync_res (CompletedProcess): Completed cdc_rsync process
|
||||
missing (int, optional): Number of missing files. Defaults to 0.
|
||||
missing_dir (int, optional): Number of missing folders. Defaults to 0.
|
||||
changed (int, optional): Number of changed files. Defaults to 0.
|
||||
matching (int, optional): Number of matching files. Defaults to 0.
|
||||
matching_dir (int, optional): Number of matching folders. Defaults to 0.
|
||||
extraneous (int, optional): Number of extraneous files. Defaults to 0.
|
||||
extraneous_dir (int, optional): Number of extraneous folders. \ Defaults
|
||||
to 0.
|
||||
|
||||
Returns:
|
||||
bool: True if all file counts match.
|
||||
"""
|
||||
missing_ok = '%i file(s) and %i folder(s) are not present' % (
|
||||
missing, missing_dir) in cdc_rsync_res.stdout
|
||||
changed_ok = '%i file(s) changed' % (changed) in cdc_rsync_res.stdout
|
||||
matching_ok = '%i file(s) and %i folder(s) match' % (
|
||||
matching, matching_dir) in cdc_rsync_res.stdout or """%i file(s) and %i \
|
||||
folder(s) have matching modified time and size""" % (
|
||||
matching, matching_dir) in cdc_rsync_res.stdout
|
||||
extraneous_ok = """%i file(s) and %i folder(s) on the instance do not exist \
|
||||
on this machine""" % (extraneous, extraneous_dir) in cdc_rsync_res.stdout
|
||||
return missing_ok and changed_ok and matching_ok and extraneous_ok
|
||||
|
||||
|
||||
def sha1sum_local(filepath):
|
||||
"""Computes the sha1 hash of a local file.
|
||||
|
||||
Args:
|
||||
filepath (string): Path of the local (Windows) file
|
||||
|
||||
Returns:
|
||||
string: sha1 hash
|
||||
"""
|
||||
sha1 = hashlib.sha1()
|
||||
with open(filepath, 'rb') as f:
|
||||
while True:
|
||||
data = f.read(SHA1_BUF_SIZE)
|
||||
if not data:
|
||||
break
|
||||
sha1.update(data)
|
||||
return sha1.hexdigest()
|
||||
|
||||
|
||||
def sha1sum_remote(filepath):
|
||||
"""Computes the sha1 hash of a remote file.
|
||||
|
||||
Args:
|
||||
filepath (string): Path of the remote (Linux) file
|
||||
|
||||
Returns:
|
||||
string: sha1 hash
|
||||
"""
|
||||
return get_ssh_command_output('sha1sum %s' % filepath)[0:SHA1_LEN]
|
||||
|
||||
|
||||
def sha1_matches(local_path, remote_path):
|
||||
"""Compares the sha1 hashes of a local and a remote file.
|
||||
|
||||
Args:
|
||||
local_path (string): Path of the local (Windows) file
|
||||
remote_path (string): Path of the remote (Linux) file
|
||||
|
||||
Returns:
|
||||
bool: True if the sha1 hashes match
|
||||
"""
|
||||
|
||||
sha1_local = sha1sum_local(local_path)
|
||||
sha1_remote = sha1sum_remote(remote_path)
|
||||
return sha1_local == sha1_remote
|
||||
|
||||
|
||||
def create_test_file(local_path, size, printable_data=True, append=False):
|
||||
"""Creates a test file with random text of given size.
|
||||
|
||||
Args:
|
||||
local_path (string): Local path of the file to create.
|
||||
size (integer): Size of the file to create (bytes).
|
||||
printable_data (bool, optional): If the data should be printable. Writing
|
||||
a file with printable data is slower, for 1GB of data this takes ~5
|
||||
minutes, in comparison to ~2 seconds for non printable data. Defaults
|
||||
to True.
|
||||
append (bool, optional): If append mode should be used. Defaults to False.
|
||||
"""
|
||||
pathlib.Path(os.path.dirname(local_path)).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
mode = None
|
||||
random_bytes = None
|
||||
if printable_data:
|
||||
mode = 'at' if append else 'wt'
|
||||
random_bytes = ''.join(
|
||||
RANDOM.choices(string.ascii_uppercase + string.digits, k=size))
|
||||
else:
|
||||
mode = 'ab' if append else 'wb'
|
||||
random_bytes = os.urandom(size)
|
||||
|
||||
with open(local_path, mode) as f:
|
||||
if size > 0:
|
||||
f.write(random_bytes)
|
||||
|
||||
|
||||
def remove_test_file(local_path):
|
||||
"""Deletes a test file.
|
||||
|
||||
Args:
|
||||
local_path (string): Local path of the file to delete.
|
||||
"""
|
||||
os.remove(local_path)
|
||||
|
||||
|
||||
def create_test_directory(local_path):
|
||||
"""Creates a directory.
|
||||
|
||||
Args:
|
||||
local_path (string): Local path of the directory to create.
|
||||
"""
|
||||
pathlib.Path(os.path.dirname(local_path)).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def remove_test_directory(local_path):
|
||||
"""Removes a directory with its content.
|
||||
|
||||
Args:
|
||||
local_path (string): Local path of the directory to remove.
|
||||
"""
|
||||
shutil.rmtree(pathlib.Path(os.path.dirname(local_path)), ignore_errors=True)
|
||||
|
||||
|
||||
def does_directory_exist_remotely(path):
|
||||
"""Checks if a directory exists on the remote instance.
|
||||
|
||||
Args:
|
||||
path (string): Path of the remote (Linux) directory
|
||||
|
||||
Returns:
|
||||
bool: True if a directory exists.
|
||||
"""
|
||||
return 'yes' in get_ssh_command_output('test -d %s && echo "yes"' % path)
|
||||
|
||||
|
||||
def does_file_exist_remotely(path):
|
||||
"""Checks if a file exists on the remote instance.
|
||||
|
||||
Args:
|
||||
path (string): Path of the remote (Linux) file
|
||||
|
||||
Returns:
|
||||
bool: True if a file exists.
|
||||
"""
|
||||
return 'yes' in get_ssh_command_output('test -f %s && echo "yes"' % path)
|
||||
|
||||
|
||||
def change_modified_time(path):
|
||||
"""Changes the modified time of the given file.
|
||||
|
||||
Args:
|
||||
path (string): Path of the local file
|
||||
"""
|
||||
stats = os.stat(path)
|
||||
os.utime(path, (stats.st_atime, stats.st_mtime + 1))
|
||||
|
||||
|
||||
def get_ssh_command_output(cmd):
|
||||
"""Runs an SSH command using the command from the CDC_SSH_COMMAND env var.
|
||||
|
||||
Args:
|
||||
cmd (string): Command that is being run remotely
|
||||
|
||||
Returns:
|
||||
string: The output of the ssh command.
|
||||
"""
|
||||
ssh_command = os.environ.get('CDC_SSH_COMMAND') or "ssh"
|
||||
full_ssh_cmd = '%s -tt "%s" -- %s' % (ssh_command, USER_HOST,
|
||||
quote_argument(cmd))
|
||||
res = subprocess.run(full_ssh_cmd, capture_output=True)
|
||||
if res.returncode != 0:
|
||||
logging.warning('SSH command %s failed with code %i, stderr: %s', cmd,
|
||||
res.returncode, res.stderr)
|
||||
return res.stdout.decode('ascii', errors='replace')
|
||||
|
||||
|
||||
def quote_argument(argument):
|
||||
# This isn't fully generic, but does the job... It doesn't handle when the
|
||||
# argument already escapes quotes, for instance.
|
||||
return '"' + argument.replace('"', '\\"') + '"'
|
||||
|
||||
|
||||
def get_sorted_files(remote_dir, pattern='"*.[t|d]*"'):
|
||||
"""Returns a sorted list of files in the remote_dir.
|
||||
|
||||
Args:
|
||||
remote_dir (string): Remote directory.
|
||||
pattern (string, optional): Pattern for matching file names.
|
||||
|
||||
Returns:
|
||||
string: Sorted list of files found in the remote directory.
|
||||
"""
|
||||
find_res = get_ssh_command_output('cd %s && find -name %s -print' %
|
||||
(remote_dir, pattern))
|
||||
|
||||
found = sorted(
|
||||
filter(lambda item: item and item != '.', find_res.split('\r\n')))
|
||||
return found
|
||||
|
||||
|
||||
def write_file(path, content):
|
||||
"""Writes a file and creates the parent directory if it does not exist yet.
|
||||
|
||||
Args:
|
||||
path (string): File path to create.
|
||||
content (string): File content.
|
||||
"""
|
||||
pathlib.Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True)
|
||||
with open(path, 'wt') as file:
|
||||
file.write(content)
|
||||
Reference in New Issue
Block a user