Clean up code in ansible-test. (#73379)

* Relocate code to fix type dependencies.

* Fix missing and unused imports.

* Fix type hints.

* Suppress PyCharm false positives.

* Avoid shadowing `file` built-in.

* Use json.JSONEncoder directly instead of super().

This matches the recommended usage and avoids a PyCharm warning.

* Remove redundant regex escape.

* Remove redundant find_python call.

* Use tarfile.open directly.

* Add changelog fragment.
This commit is contained in:
Matt Clay 2021-01-26 14:02:08 -08:00 committed by GitHub
parent 76604397cb
commit 73fadc5e97
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 66 additions and 51 deletions

View file

@ -0,0 +1,2 @@
minor_changes:
- ansible-test - Cleaned up code to resolve warnings and errors reported by PyCharm.

View file

@ -112,6 +112,7 @@ def get_collection_version():
sys.modules['collection_detail'] = collection_detail sys.modules['collection_detail'] = collection_detail
collection_detail_spec.loader.exec_module(collection_detail) collection_detail_spec.loader.exec_module(collection_detail)
# noinspection PyBroadException
try: try:
result = collection_detail.read_manifest_json('.') or collection_detail.read_galaxy_yml('.') result = collection_detail.read_manifest_json('.') or collection_detail.read_galaxy_yml('.')
return SemanticVersion(result['version']) return SemanticVersion(result['version'])

View file

@ -81,7 +81,7 @@ class YamlChecker:
def check(self, paths): def check(self, paths):
""" """
:type paths: str :type paths: t.List[str]
""" """
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config') config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config')

View file

@ -4,7 +4,6 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import json import json
import os
try: try:
from sys import real_prefix from sys import real_prefix

View file

@ -181,12 +181,14 @@ class CryptographyAuthHelper(AuthHelper, ABC): # pylint: disable=abstract-metho
private_key = ec.generate_private_key(ec.SECP384R1(), default_backend()) private_key = ec.generate_private_key(ec.SECP384R1(), default_backend())
public_key = private_key.public_key() public_key = private_key.public_key()
# noinspection PyUnresolvedReferences
private_key_pem = to_text(private_key.private_bytes( private_key_pem = to_text(private_key.private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8, format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(), encryption_algorithm=serialization.NoEncryption(),
)) ))
# noinspection PyTypeChecker
public_key_pem = to_text(public_key.public_bytes( public_key_pem = to_text(public_key.public_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo, format=serialization.PublicFormat.SubjectPublicKeyInfo,

View file

@ -678,7 +678,7 @@ def key_value(argparse, value): # type: (argparse_module, str) -> t.Tuple[str,
return parts[0], parts[1] return parts[0], parts[1]
# noinspection PyProtectedMember # noinspection PyProtectedMember,PyUnresolvedReferences
def add_coverage_analyze(coverage_subparsers, coverage_common): # type: (argparse_module._SubParsersAction, argparse_module.ArgumentParser) -> None def add_coverage_analyze(coverage_subparsers, coverage_common): # type: (argparse_module._SubParsersAction, argparse_module.ArgumentParser) -> None
"""Add the `coverage analyze` subcommand.""" """Add the `coverage analyze` subcommand."""
analyze = coverage_subparsers.add_parser( analyze = coverage_subparsers.add_parser(

View file

@ -97,6 +97,7 @@ class AwsCloudEnvironment(CloudEnvironment):
resource_prefix=self.resource_prefix, resource_prefix=self.resource_prefix,
) )
# noinspection PyTypeChecker
ansible_vars.update(dict(parser.items('default'))) ansible_vars.update(dict(parser.items('default')))
display.sensitive.add(ansible_vars.get('aws_secret_key')) display.sensitive.add(ansible_vars.get('aws_secret_key'))

View file

@ -30,6 +30,7 @@ from .data import (
) )
try: try:
# noinspection PyTypeChecker
TIntegrationConfig = t.TypeVar('TIntegrationConfig', bound='IntegrationConfig') TIntegrationConfig = t.TypeVar('TIntegrationConfig', bound='IntegrationConfig')
except AttributeError: except AttributeError:
TIntegrationConfig = None # pylint: disable=invalid-name TIntegrationConfig = None # pylint: disable=invalid-name

View file

@ -163,8 +163,8 @@ def enumerate_python_arcs(
try: try:
original.read_file(path) original.read_file(path)
except Exception as ex: # pylint: disable=locally-disabled, broad-except except Exception as ex: # pylint: disable=locally-disabled, broad-except
with open_binary_file(path) as file: with open_binary_file(path) as file_obj:
header = file.read(6) header = file_obj.read(6)
if header == b'SQLite': if header == b'SQLite':
display.error('File created by "coverage" 5.0+: %s' % os.path.relpath(path)) display.error('File created by "coverage" 5.0+: %s' % os.path.relpath(path))

View file

@ -6,6 +6,8 @@ import json
import os import os
import time import time
from . import types as t
from .io import ( from .io import (
open_binary_file, open_binary_file,
read_text_file, read_text_file,

View file

@ -30,6 +30,7 @@ from .core_ci import (
from .manage_ci import ( from .manage_ci import (
ManageWindowsCI, ManageWindowsCI,
ManageNetworkCI, ManageNetworkCI,
get_network_settings,
) )
from .cloud import ( from .cloud import (
@ -73,7 +74,6 @@ from .util import (
from .util_common import ( from .util_common import (
get_docker_completion, get_docker_completion,
get_network_settings,
get_remote_completion, get_remote_completion,
get_python_path, get_python_path,
intercept_command, intercept_command,
@ -259,7 +259,6 @@ def get_cryptography_requirement(args, python, python_version): # type: (Enviro
Return the correct cryptography requirement for the given python version. Return the correct cryptography requirement for the given python version.
The version of cryptography installed depends on the python version, setuptools version and openssl version. The version of cryptography installed depends on the python version, setuptools version and openssl version.
""" """
python = find_python(python_version)
setuptools_version = get_setuptools_version(args, python) setuptools_version = get_setuptools_version(args, python)
openssl_version = get_openssl_version(args, python, python_version) openssl_version = get_openssl_version(args, python, python_version)
@ -624,7 +623,7 @@ def command_network_integration(args):
time.sleep(1) time.sleep(1)
remotes = [instance.wait_for_result() for instance in instances] remotes = [instance.wait_for_result() for instance in instances]
inventory = network_inventory(remotes) inventory = network_inventory(args, remotes)
display.info('>>> Inventory: %s\n%s' % (inventory_path, inventory.strip()), verbosity=3) display.info('>>> Inventory: %s\n%s' % (inventory_path, inventory.strip()), verbosity=3)
@ -702,14 +701,15 @@ def network_run(args, platform, version, config):
core_ci.load(config) core_ci.load(config)
core_ci.wait() core_ci.wait()
manage = ManageNetworkCI(core_ci) manage = ManageNetworkCI(args, core_ci)
manage.wait() manage.wait()
return core_ci return core_ci
def network_inventory(remotes): def network_inventory(args, remotes):
""" """
:type args: NetworkIntegrationConfig
:type remotes: list[AnsibleCoreCI] :type remotes: list[AnsibleCoreCI]
:rtype: str :rtype: str
""" """
@ -723,7 +723,7 @@ def network_inventory(remotes):
ansible_ssh_private_key_file=os.path.abspath(remote.ssh_key.key), ansible_ssh_private_key_file=os.path.abspath(remote.ssh_key.key),
) )
settings = get_network_settings(remote.args, remote.platform, remote.version) settings = get_network_settings(args, remote.platform, remote.version)
options.update(settings.inventory_vars) options.update(settings.inventory_vars)

View file

@ -28,8 +28,8 @@ def read_text_file(path): # type: (t.AnyStr) -> t.Text
def read_binary_file(path): # type: (t.AnyStr) -> bytes def read_binary_file(path): # type: (t.AnyStr) -> bytes
"""Return the contents of the specified path as bytes.""" """Return the contents of the specified path as bytes."""
with open_binary_file(path) as file: with open_binary_file(path) as file_obj:
return file.read() return file_obj.read()
def make_dirs(path): # type: (str) -> None def make_dirs(path): # type: (str) -> None
@ -63,8 +63,8 @@ def write_text_file(path, content, create_directories=False): # type: (str, str
if create_directories: if create_directories:
make_dirs(os.path.dirname(path)) make_dirs(os.path.dirname(path))
with open_binary_file(path, 'wb') as file: with open_binary_file(path, 'wb') as file_obj:
file.write(to_bytes(content)) file_obj.write(to_bytes(content))
def open_text_file(path, mode='r'): # type: (str, str) -> t.TextIO def open_text_file(path, mode='r'): # type: (str, str) -> t.TextIO
@ -91,4 +91,4 @@ class SortedSetEncoder(json.JSONEncoder):
if isinstance(obj, set): if isinstance(obj, set):
return sorted(obj) return sorted(obj)
return super(SortedSetEncoder).default(self, obj) return json.JSONEncoder.default(self, obj)

View file

@ -6,6 +6,8 @@ import os
import tempfile import tempfile
import time import time
from . import types as t
from .util import ( from .util import (
SubprocessError, SubprocessError,
ApplicationError, ApplicationError,
@ -16,7 +18,7 @@ from .util import (
from .util_common import ( from .util_common import (
intercept_command, intercept_command,
get_network_settings, get_network_completion,
run_command, run_command,
) )
@ -29,6 +31,7 @@ from .ansible_util import (
) )
from .config import ( from .config import (
NetworkIntegrationConfig,
ShellConfig, ShellConfig,
) )
@ -142,15 +145,17 @@ class ManageWindowsCI:
class ManageNetworkCI: class ManageNetworkCI:
"""Manage access to a network instance provided by Ansible Core CI.""" """Manage access to a network instance provided by Ansible Core CI."""
def __init__(self, core_ci): def __init__(self, args, core_ci):
""" """
:type args: NetworkIntegrationConfig
:type core_ci: AnsibleCoreCI :type core_ci: AnsibleCoreCI
""" """
self.args = args
self.core_ci = core_ci self.core_ci = core_ci
def wait(self): def wait(self):
"""Wait for instance to respond to ansible ping.""" """Wait for instance to respond to ansible ping."""
settings = get_network_settings(self.core_ci.args, self.core_ci.platform, self.core_ci.version) settings = get_network_settings(self.args, self.core_ci.platform, self.core_ci.version)
extra_vars = [ extra_vars = [
'ansible_host=%s' % self.core_ci.connection.hostname, 'ansible_host=%s' % self.core_ci.connection.hostname,
@ -333,3 +338,27 @@ class ManagePosixCI:
time.sleep(10) time.sleep(10)
raise ApplicationError('Failed transfer: %s -> %s' % (src, dst)) raise ApplicationError('Failed transfer: %s -> %s' % (src, dst))
def get_network_settings(args, platform, version): # type: (NetworkIntegrationConfig, str, str) -> NetworkPlatformSettings
"""Returns settings for the given network platform and version."""
platform_version = '%s/%s' % (platform, version)
completion = get_network_completion().get(platform_version, {})
collection = args.platform_collection.get(platform, completion.get('collection'))
settings = NetworkPlatformSettings(
collection,
dict(
ansible_connection=args.platform_connection.get(platform, completion.get('connection')),
ansible_network_os='%s.%s' % (collection, platform) if collection else platform,
)
)
return settings
class NetworkPlatformSettings:
"""Settings required for provisioning a network platform."""
def __init__(self, collection, inventory_vars): # type: (str, t.Type[str, str]) -> None
self.collection = collection
self.inventory_vars = inventory_vars

View file

@ -120,7 +120,7 @@ def create_payload(args, dst_path): # type: (CommonConfig, str) -> None
start = time.time() start = time.time()
with tarfile.TarFile.open(dst_path, mode='w:gz', compresslevel=4, format=tarfile.GNU_FORMAT) as tar: with tarfile.open(dst_path, mode='w:gz', compresslevel=4, format=tarfile.GNU_FORMAT) as tar:
for src, dst in files: for src, dst in files:
display.info('%s -> %s' % (src, dst), verbosity=4) display.info('%s -> %s' % (src, dst), verbosity=4)
tar.add(src, dst, filter=filters.get(dst)) tar.add(src, dst, filter=filters.get(dst))

View file

@ -15,6 +15,7 @@ from ..util import (
try: try:
# noinspection PyTypeChecker
TPathProvider = t.TypeVar('TPathProvider', bound='PathProvider') TPathProvider = t.TypeVar('TPathProvider', bound='PathProvider')
except AttributeError: except AttributeError:
TPathProvider = None # pylint: disable=invalid-name TPathProvider = None # pylint: disable=invalid-name

View file

@ -200,7 +200,7 @@ class LayoutProvider(PathProvider):
"""Create a layout using the given root and paths.""" """Create a layout using the given root and paths."""
def paths_to_tree(paths): # type: (t.List[str]) -> t.Tuple(t.Dict[str, t.Any], t.List[str]) def paths_to_tree(paths): # type: (t.List[str]) -> t.Tuple[t.Dict[str, t.Any], t.List[str]]
"""Return a filesystem tree from the given list of paths.""" """Return a filesystem tree from the given list of paths."""
tree = {}, [] tree = {}, []
@ -219,7 +219,7 @@ def paths_to_tree(paths): # type: (t.List[str]) -> t.Tuple(t.Dict[str, t.Any],
return tree return tree
def get_tree_item(tree, parts): # type: (t.Tuple(t.Dict[str, t.Any], t.List[str]), t.List[str]) -> t.Optional[t.Tuple(t.Dict[str, t.Any], t.List[str])] def get_tree_item(tree, parts): # type: (t.Tuple[t.Dict[str, t.Any], t.List[str]], t.List[str]) -> t.Optional[t.Tuple[t.Dict[str, t.Any], t.List[str]]]
"""Return the portion of the tree found under the path given by parts, or None if it does not exist.""" """Return the portion of the tree found under the path given by parts, or None if it does not exist."""
root = tree root = tree

View file

@ -127,7 +127,7 @@ class AnsibleDocTest(SanitySingleVersion):
if stderr: if stderr:
# ignore removed module/plugin warnings # ignore removed module/plugin warnings
stderr = re.sub(r'\[WARNING\]: [^ ]+ [^ ]+ has been removed\n', '', stderr).strip() stderr = re.sub(r'\[WARNING]: [^ ]+ [^ ]+ has been removed\n', '', stderr).strip()
if stderr: if stderr:
summary = u'Output on stderr from ansible-doc is considered an error.\n\n%s' % SubprocessError(cmd, stderr=stderr) summary = u'Output on stderr from ansible-doc is considered an error.\n\n%s' % SubprocessError(cmd, stderr=stderr)

View file

@ -33,11 +33,13 @@ from .data import (
MODULE_EXTENSIONS = '.py', '.ps1' MODULE_EXTENSIONS = '.py', '.ps1'
try: try:
# noinspection PyTypeChecker
TCompletionTarget = t.TypeVar('TCompletionTarget', bound='CompletionTarget') TCompletionTarget = t.TypeVar('TCompletionTarget', bound='CompletionTarget')
except AttributeError: except AttributeError:
TCompletionTarget = None # pylint: disable=invalid-name TCompletionTarget = None # pylint: disable=invalid-name
try: try:
# noinspection PyTypeChecker
TIntegrationTarget = t.TypeVar('TIntegrationTarget', bound='IntegrationTarget') TIntegrationTarget = t.TypeVar('TIntegrationTarget', bound='IntegrationTarget')
except AttributeError: except AttributeError:
TIntegrationTarget = None # pylint: disable=invalid-name TIntegrationTarget = None # pylint: disable=invalid-name

View file

@ -157,6 +157,7 @@ class TestResult:
try: try:
to_xml_string = self.junit.to_xml_report_string to_xml_string = self.junit.to_xml_report_string
except AttributeError: except AttributeError:
# noinspection PyDeprecation
to_xml_string = self.junit.TestSuite.to_xml_string to_xml_string = self.junit.TestSuite.to_xml_string
report = to_xml_string(test_suites=test_suites, prettyprint=True, encoding='utf-8') report = to_xml_string(test_suites=test_suites, prettyprint=True, encoding='utf-8')

View file

@ -615,7 +615,7 @@ class Display:
""" """
:type message: str :type message: str
:type color: str | None :type color: str | None
:type fd: file :type fd: t.IO[str]
:type truncate: bool :type truncate: bool
""" """
if self.redact and self.sensitive: if self.redact and self.sensitive:
@ -815,13 +815,11 @@ def load_module(path, name): # type: (str, str) -> None
return return
if sys.version_info >= (3, 4): if sys.version_info >= (3, 4):
# noinspection PyUnresolvedReferences
import importlib.util import importlib.util
# noinspection PyUnresolvedReferences
spec = importlib.util.spec_from_file_location(name, path) spec = importlib.util.spec_from_file_location(name, path)
# noinspection PyUnresolvedReferences
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
# noinspection PyUnresolvedReferences
spec.loader.exec_module(module) spec.loader.exec_module(module)
sys.modules[name] = module sys.modules[name] = module

View file

@ -115,13 +115,6 @@ class CommonConfig:
return os.path.join(ANSIBLE_TEST_DATA_ROOT, 'ansible.cfg') return os.path.join(ANSIBLE_TEST_DATA_ROOT, 'ansible.cfg')
class NetworkPlatformSettings:
"""Settings required for provisioning a network platform."""
def __init__(self, collection, inventory_vars): # type: (str, t.Type[str, str]) -> None
self.collection = collection
self.inventory_vars = inventory_vars
def get_docker_completion(): def get_docker_completion():
""" """
:rtype: dict[str, dict[str, str]] :rtype: dict[str, dict[str, str]]
@ -185,23 +178,6 @@ def docker_qualify_image(name):
return config.get('name', name) return config.get('name', name)
def get_network_settings(args, platform, version): # type: (NetworkIntegrationConfig, str, str) -> NetworkPlatformSettings
"""Returns settings for the given network platform and version."""
platform_version = '%s/%s' % (platform, version)
completion = get_network_completion().get(platform_version, {})
collection = args.platform_collection.get(platform, completion.get('collection'))
settings = NetworkPlatformSettings(
collection,
dict(
ansible_connection=args.platform_connection.get(platform, completion.get('connection')),
ansible_network_os='%s.%s' % (collection, platform) if collection else platform,
)
)
return settings
def handle_layout_messages(messages): # type: (t.Optional[LayoutMessages]) -> None def handle_layout_messages(messages): # type: (t.Optional[LayoutMessages]) -> None
"""Display the given layout messages.""" """Display the given layout messages."""
if not messages: if not messages: