Files
netris-cdc-file-transfer/integration_tests/framework/utils.py
Lutz Justen 668c2ca8df [cdc_rsync] Add integration tests (#42)
[cdc_rsync] Add integration tests

This CL adds Python integration tests for cdc_rsync. To run the tests,
you need to supply a Linux host and proper configuration for cdc_rsync
to work:

  set CDC_SSH_COMMAND=C:\path\to\ssh.exe <args>
  set CDC_SCP_COMMAND=C:\path\to\scp.exe <args>
  C:\python38\python.exe -m integration_tests.cdc_rsync.all_tests --binary_path=C:\full\path\to\cdc_rsync.exe --user_host=user@host

Ran the tests and made sure they worked.
2022-12-08 08:39:43 +01:00

332 lines
9.6 KiB
Python

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Utils for file transfer tests."""
import hashlib
import logging
import os
import pathlib
import random
import re
import shutil
import string
import subprocess
import time
# Path of the cdc_rsync binary that run_rsync() executes; set by initialize().
CDC_RSYNC_PATH = None
# Remote [user@]host used for destination prefixes and SSH commands; set by
# initialize().
USER_HOST = None
# Number of leading characters of the remote sha1sum output kept as the hash.
SHA1_LEN = 40
# Chunk size in bytes used when reading local files for hashing.
SHA1_BUF_SIZE = 65536
# Shared random generator used for printable test file content; seeded by
# initialize_random().
RANDOM = random.Random()
def initialize(cdc_rsync_path, user_host):
    """Stores the cdc_rsync binary path and the remote target in module globals.

    Args:
      cdc_rsync_path (string): Value assigned to CDC_RSYNC_PATH.
      user_host (string): Value assigned to USER_HOST.
    """
    global CDC_RSYNC_PATH, USER_HOST
    CDC_RSYNC_PATH, USER_HOST = cdc_rsync_path, user_host
def initialize_random():
    """Seeds the shared RANDOM generator with the current time in seconds.

    The chosen seed is logged at debug level before it is applied.
    """
    # RANDOM is only mutated through seed(), never rebound, so no global
    # declaration is required here.
    current_seed = int(time.time())
    logging.debug('Use random seed %i', current_seed)
    RANDOM.seed(current_seed)
def _remove_carriage_return_lines(text):
r"""Removes *\r, keeps only *\r\n lines.
Args:
text (string): Text to remove lines from (usually cdc_rsync output).
Returns:
string: Text with lines removed.
"""
# Some lines have \r\r\n, treat them properly.
ret = ''
for line in text.replace('\r\r', '\r').split('\r\n'):
ret += line.split('\r')[-1] + '\r\n'
return ret
def run_rsync(*args):
    """Runs cdc_rsync with the given arguments and returns the finished process.

    The last positional argument is assumed to be the destination. The
    user/host prefix [user@]host: is optional. If the destination does not
    have one, it is prefixed by |USER_HOST|:.

    Args:
      *args (string): cdc_rsync arguments. Empty/None entries are dropped.

    Returns:
      CompletedProcess: cdc_rsync process info with exit code and
        stdout/stderr. stdout is decoded as ASCII and stripped of \\r-only
        progress lines; stderr is left as captured.
    """
    args_list = [arg for arg in args if arg]

    # Search from the end for the first argument that is neither an option
    # nor already contains ':' and treat it as the destination. Index 0 is
    # never considered. Note that this won't work in all cases, e.g. if
    # '--exclude', 'file' is passed. Use '--exclude=file' instead.
    for index in reversed(range(1, len(args_list))):
        candidate = args_list[index]
        if candidate[0] == '-' or ':' in candidate:
            continue
        args_list[index] = USER_HOST + ":" + candidate
        break

    command = [CDC_RSYNC_PATH] + args_list

    # Workaround issue with unicode logging: non-ASCII characters are logged
    # as backslash escapes.
    loggable = ' '.join(command).encode('utf-8').decode(
        'ascii', 'backslashreplace')
    logging.debug('Executing %s ', loggable)

    res = subprocess.run(command, capture_output=True)

    # Remove lines ending with \r since those are temp display lines.
    res.stdout = _remove_carriage_return_lines(res.stdout.decode('ascii'))
    if res.stdout.strip():
        logging.debug('\r\n%s', res.stdout)
    return res
def files_count_is(cdc_rsync_res,
                   missing=0,
                   missing_dir=0,
                   changed=0,
                   matching=0,
                   matching_dir=0,
                   extraneous=0,
                   extraneous_dir=0):
    """Verifies that the output of cdc_rsync indicates the given file counts.

    Each expected summary phrase must occur in the output and must not be
    directly preceded by another digit, so expecting "1 file(s) changed" does
    not accidentally accept "11 file(s) changed".

    Args:
      cdc_rsync_res (CompletedProcess): Completed cdc_rsync process. Only its
        stdout attribute (a string) is read.
      missing (int, optional): Number of missing files. Defaults to 0.
      missing_dir (int, optional): Number of missing folders. Defaults to 0.
      changed (int, optional): Number of changed files. Defaults to 0.
      matching (int, optional): Number of matching files. Defaults to 0.
      matching_dir (int, optional): Number of matching folders. Defaults to 0.
      extraneous (int, optional): Number of extraneous files. Defaults to 0.
      extraneous_dir (int, optional): Number of extraneous folders. Defaults
        to 0.

    Returns:
      bool: True if all file counts match.
    """
    stdout = cdc_rsync_res.stdout

    def _reported(phrase):
        # The lookbehind rejects hits where the expected count is merely the
        # tail of a larger number printed by cdc_rsync.
        return re.search(r'(?<!\d)' + re.escape(phrase), stdout) is not None

    missing_ok = _reported(
        '%i file(s) and %i folder(s) are not present' % (missing, missing_dir))
    changed_ok = _reported('%i file(s) changed' % changed)
    # Two different wordings are accepted for matching files.
    matching_ok = (
        _reported('%i file(s) and %i folder(s) match' %
                  (matching, matching_dir)) or
        _reported('%i file(s) and %i folder(s) have matching modified time '
                  'and size' % (matching, matching_dir)))
    extraneous_ok = _reported(
        '%i file(s) and %i folder(s) on the instance do not exist '
        'on this machine' % (extraneous, extraneous_dir))
    return missing_ok and changed_ok and matching_ok and extraneous_ok
def sha1sum_local(filepath):
    """Computes the sha1 hash of a local file.

    The file is read in binary mode in chunks of SHA1_BUF_SIZE bytes, so it
    is never loaded into memory as a whole.

    Args:
      filepath (string): Path of the local (Windows) file

    Returns:
      string: sha1 hash as a hex string
    """
    digest = hashlib.sha1()
    with open(filepath, 'rb') as stream:
        # iter() with a sentinel stops at the first empty read (end of file).
        for chunk in iter(lambda: stream.read(SHA1_BUF_SIZE), b''):
            digest.update(chunk)
    return digest.hexdigest()
def sha1sum_remote(filepath):
    """Computes the sha1 hash of a remote file.

    Runs `sha1sum <filepath>` over SSH and keeps the first SHA1_LEN characters
    of its output. The path is inserted into the command as-is.

    Args:
      filepath (string): Path of the remote (Linux) file

    Returns:
      string: sha1 hash
    """
    sha1sum_output = get_ssh_command_output('sha1sum %s' % filepath)
    return sha1sum_output[:SHA1_LEN]
def sha1_matches(local_path, remote_path):
    """Compares the sha1 hashes of a local and a remote file.

    The local hash is computed first, then the remote one.

    Args:
      local_path (string): Path of the local (Windows) file
      remote_path (string): Path of the remote (Linux) file

    Returns:
      bool: True if the sha1 hashes match
    """
    return sha1sum_local(local_path) == sha1sum_remote(remote_path)
def create_test_file(local_path, size, printable_data=True, append=False):
    """Creates a test file with random content of the given size.

    The parent directory of local_path is created if it does not exist yet.

    Args:
      local_path (string): Local path of the file to create.
      size (integer): Size of the data to write (bytes).
      printable_data (bool, optional): If the data should be printable
        (uppercase ASCII letters and digits drawn from RANDOM). Writing a file
        with printable data is slower, for 1GB of data this takes ~5 minutes,
        in comparison to ~2 seconds for non printable data. Defaults to True.
      append (bool, optional): If append mode should be used. Defaults to
        False.
    """
    parent_dir = pathlib.Path(os.path.dirname(local_path))
    parent_dir.mkdir(parents=True, exist_ok=True)

    if printable_data:
        # Text mode: content is a string of random letters and digits.
        alphabet = string.ascii_uppercase + string.digits
        payload = ''.join(RANDOM.choices(alphabet, k=size))
        mode = 'at' if append else 'wt'
    else:
        # Binary mode: content is raw bytes from the OS random source.
        payload = os.urandom(size)
        mode = 'ab' if append else 'wb'

    # The file is always opened, so size 0 still creates (or truncates) it.
    with open(local_path, mode) as out:
        if size > 0:
            out.write(payload)
def remove_test_file(local_path):
    """Deletes a single test file from the local file system.

    Args:
      local_path (string): Local path of the file to delete.
    """
    # os.unlink is the same operation as os.remove under its other name.
    os.unlink(local_path)
def create_test_directory(local_path):
    """Creates a directory, including missing parents.

    The directory actually created is os.path.dirname(local_path). A path
    with a trailing separator therefore creates that directory itself, while
    a path without one only creates its parent.

    Args:
      local_path (string): Local path of the directory to create.
    """
    directory = pathlib.Path(os.path.dirname(local_path))
    directory.mkdir(parents=True, exist_ok=True)
def remove_test_directory(local_path):
    """Removes a directory with its content, ignoring errors.

    The directory actually removed is os.path.dirname(local_path). A path
    with a trailing separator removes that directory itself; a path without
    one removes its parent. A bare name has an empty dirname, which pathlib
    resolves to the current directory.

    Args:
      local_path (string): Local path of the directory to remove.
    """
    directory = pathlib.Path(os.path.dirname(local_path))
    shutil.rmtree(directory, ignore_errors=True)
def does_directory_exist_remotely(path):
    """Checks if a directory exists on the remote instance.

    Runs `test -d` over SSH and echoes "yes" on success; the result is True
    when "yes" appears anywhere in the command output.

    Args:
      path (string): Path of the remote (Linux) directory

    Returns:
      bool: True if a directory exists.
    """
    probe_output = get_ssh_command_output('test -d %s && echo "yes"' % path)
    return 'yes' in probe_output
def does_file_exist_remotely(path):
    """Checks if a file exists on the remote instance.

    Runs `test -f` over SSH and echoes "yes" on success; the result is True
    when "yes" appears anywhere in the command output.

    Args:
      path (string): Path of the remote (Linux) file

    Returns:
      bool: True if a file exists.
    """
    probe_output = get_ssh_command_output('test -f %s && echo "yes"' % path)
    return 'yes' in probe_output
def change_modified_time(path):
    """Moves the modified time of the given file one second forward.

    The access time is written back unchanged.

    Args:
      path (string): Path of the local file
    """
    current = os.stat(path)
    access_time = current.st_atime
    bumped_mtime = current.st_mtime + 1
    os.utime(path, (access_time, bumped_mtime))
def get_ssh_command_output(cmd):
    """Runs an SSH command using the command from the CDC_SSH_COMMAND env var.

    Falls back to plain "ssh" if CDC_SSH_COMMAND is unset or empty. The
    remote host is USER_HOST. A non-zero exit code is logged as a warning but
    does not raise; the captured stdout is returned either way.

    Args:
      cmd (string): Command that is being run remotely

    Returns:
      string: The output of the ssh command, decoded as ASCII with
        undecodable bytes replaced.
    """
    ssh_command = os.environ.get('CDC_SSH_COMMAND') or "ssh"
    quoted_cmd = quote_argument(cmd)
    full_ssh_cmd = '%s -tt "%s" -- %s' % (ssh_command, USER_HOST, quoted_cmd)

    # NOTE(review): the command is passed as one string without shell=True;
    # on POSIX subprocess treats such a string as the program name, so this
    # presumably only works from a Windows client - confirm before reuse.
    res = subprocess.run(full_ssh_cmd, capture_output=True)
    exit_code = res.returncode
    if exit_code != 0:
        logging.warning('SSH command %s failed with code %i, stderr: %s', cmd,
                        exit_code, res.stderr)
    return res.stdout.decode('ascii', errors='replace')
def quote_argument(argument):
    """Wraps an argument in double quotes, escaping embedded double quotes.

    This isn't fully generic, but does the job... It doesn't handle when the
    argument already escapes quotes, for instance.

    Args:
      argument (string): Argument to quote.

    Returns:
      string: The argument surrounded by double quotes, with each embedded
        double quote preceded by a backslash.
    """
    escaped = argument.replace('"', '\\"')
    return ''.join(('"', escaped, '"'))
def get_sorted_files(remote_dir, pattern='"*.[t|d]*"'):
    """Returns a sorted list of files in the remote_dir.

    Runs `find -name <pattern> -print` inside remote_dir over SSH. The output
    is split at \\r\\n; empty entries and the '.' entry are discarded.

    Args:
      remote_dir (string): Remote directory.
      pattern (string, optional): Pattern for matching file names.

    Returns:
      list of string: Sorted paths as printed by find.
    """
    listing = get_ssh_command_output('cd %s && find -name %s -print' %
                                     (remote_dir, pattern))
    entries = [
        entry for entry in listing.split('\r\n') if entry and entry != '.'
    ]
    entries.sort()
    return entries
def write_file(path, content):
    """Writes text to a file, creating the parent directory if needed.

    An existing file at the given path is overwritten.

    Args:
      path (string): File path to create.
      content (string): File content.
    """
    parent_dir = pathlib.Path(os.path.dirname(path))
    parent_dir.mkdir(parents=True, exist_ok=True)
    with open(path, 'wt') as out:
        out.write(content)