From b6811dfb61bee06dad08e90ef541667be7bbc950 Mon Sep 17 00:00:00 2001 From: Sam Doran Date: Thu, 11 Feb 2021 19:17:14 -0500 Subject: [PATCH] Add argument spec validator (#73335) Add argument spec validator class --- .../73335-argument-spec_validator.yml | 4 + lib/ansible/module_utils/basic.py | 343 +--------- lib/ansible/module_utils/common/arg_spec.py | 134 ++++ lib/ansible/module_utils/common/parameters.py | 606 +++++++++++++++++- lib/ansible/module_utils/common/validation.py | 67 +- .../module_utils/common/arg_spec/__init__.py | 0 .../common/arg_spec/test_sub_spec.py | 106 +++ .../common/arg_spec/test_validate_aliases.py | 61 ++ .../common/arg_spec/test_validate_basic.py | 100 +++ .../common/arg_spec/test_validate_failures.py | 29 + 10 files changed, 1083 insertions(+), 367 deletions(-) create mode 100644 changelogs/fragments/73335-argument-spec_validator.yml create mode 100644 lib/ansible/module_utils/common/arg_spec.py create mode 100644 test/units/module_utils/common/arg_spec/__init__.py create mode 100644 test/units/module_utils/common/arg_spec/test_sub_spec.py create mode 100644 test/units/module_utils/common/arg_spec/test_validate_aliases.py create mode 100644 test/units/module_utils/common/arg_spec/test_validate_basic.py create mode 100644 test/units/module_utils/common/arg_spec/test_validate_failures.py diff --git a/changelogs/fragments/73335-argument-spec_validator.yml b/changelogs/fragments/73335-argument-spec_validator.yml new file mode 100644 index 00000000000..b7669405c2a --- /dev/null +++ b/changelogs/fragments/73335-argument-spec_validator.yml @@ -0,0 +1,4 @@ +major_changes: + - >- + add ``ArgumentSpecValidator`` class for validating parameters against an + argument spec outside of ``AnsibleModule`` (https://github.com/ansible/ansible/pull/73335) diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index d34ea5d7dc1..ad608415d8c 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -55,7 +55,6 @@ import time import traceback import types -from collections import deque from itertools import chain, repeat try: @@ -156,11 +155,20 @@ from ansible.module_utils.common.sys_info import ( ) from ansible.module_utils.pycompat24 import get_exception, literal_eval from ansible.module_utils.common.parameters import ( + _remove_values_conditions, + _sanitize_keys_conditions, + sanitize_keys, + env_fallback, get_unsupported_parameters, get_type_validator, handle_aliases, list_deprecations, list_no_log_values, + remove_values, + set_defaults, + set_fallbacks, + validate_argument_types, + AnsibleFallbackNotFound, DEFAULT_TYPE_VALIDATORS, PASS_VARS, PASS_BOOLS, @@ -241,14 +249,6 @@ _literal_eval = literal_eval _ANSIBLE_ARGS = None -def env_fallback(*args, **kwargs): - ''' Load value from environment ''' - for arg in args: - if arg in os.environ: - return os.environ[arg] - raise AnsibleFallbackNotFound - - FILE_COMMON_ARGUMENTS = dict( # These are things we want. About setting metadata (mode, ownership, permissions in general) on # created files (these are used by set_fs_attributes_if_different and included in @@ -320,212 +320,6 @@ def get_all_subclasses(cls): # End compat shims -def _remove_values_conditions(value, no_log_strings, deferred_removals): - """ - Helper function for :meth:`remove_values`. - - :arg value: The value to check for strings that need to be stripped - :arg no_log_strings: set of strings which must be stripped out of any values - :arg deferred_removals: List which holds information about nested - containers that have to be iterated for removals. It is passed into - this function so that more entries can be added to it if value is - a container type. The format of each entry is a 2-tuple where the first - element is the ``value`` parameter and the second value is a new - container to copy the elements of ``value`` into once iterated. - :returns: if ``value`` is a scalar, returns ``value`` with two exceptions: - 1. :class:`~datetime.datetime` objects which are changed into a string representation. - 2. objects which are in no_log_strings are replaced with a placeholder - so that no sensitive data is leaked. - If ``value`` is a container type, returns a new empty container. - - ``deferred_removals`` is added to as a side-effect of this function. - - .. warning:: It is up to the caller to make sure the order in which value - is passed in is correct. For instance, higher level containers need - to be passed in before lower level containers. For example, given - ``{'level1': {'level2': 'level3': [True]} }`` first pass in the - dictionary for ``level1``, then the dict for ``level2``, and finally - the list for ``level3``. - """ - if isinstance(value, (text_type, binary_type)): - # Need native str type - native_str_value = value - if isinstance(value, text_type): - value_is_text = True - if PY2: - native_str_value = to_bytes(value, errors='surrogate_or_strict') - elif isinstance(value, binary_type): - value_is_text = False - if PY3: - native_str_value = to_text(value, errors='surrogate_or_strict') - - if native_str_value in no_log_strings: - return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' - for omit_me in no_log_strings: - native_str_value = native_str_value.replace(omit_me, '*' * 8) - - if value_is_text and isinstance(native_str_value, binary_type): - value = to_text(native_str_value, encoding='utf-8', errors='surrogate_then_replace') - elif not value_is_text and isinstance(native_str_value, text_type): - value = to_bytes(native_str_value, encoding='utf-8', errors='surrogate_then_replace') - else: - value = native_str_value - - elif isinstance(value, Sequence): - if isinstance(value, MutableSequence): - new_value = type(value)() - else: - new_value = [] # Need a mutable value - deferred_removals.append((value, new_value)) - value = new_value - - elif isinstance(value, Set): - if isinstance(value, MutableSet): - new_value = type(value)() - else: - new_value = set() # Need a mutable value - deferred_removals.append((value, new_value)) - value = new_value - - elif isinstance(value, Mapping): - if isinstance(value, MutableMapping): - new_value = type(value)() - else: - new_value = {} # Need a mutable value - deferred_removals.append((value, new_value)) - value = new_value - - elif isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))): - stringy_value = to_native(value, encoding='utf-8', errors='surrogate_or_strict') - if stringy_value in no_log_strings: - return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' - for omit_me in no_log_strings: - if omit_me in stringy_value: - return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' - - elif isinstance(value, (datetime.datetime, datetime.date)): - value = value.isoformat() - else: - raise TypeError('Value of unknown type: %s, %s' % (type(value), value)) - - return value - - -def remove_values(value, no_log_strings): - """ Remove strings in no_log_strings from value. If value is a container - type, then remove a lot more. - - Use of deferred_removals exists, rather than a pure recursive solution, - because of the potential to hit the maximum recursion depth when dealing with - large amounts of data (see issue #24560). - """ - - deferred_removals = deque() - - no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings] - new_value = _remove_values_conditions(value, no_log_strings, deferred_removals) - - while deferred_removals: - old_data, new_data = deferred_removals.popleft() - if isinstance(new_data, Mapping): - for old_key, old_elem in old_data.items(): - new_elem = _remove_values_conditions(old_elem, no_log_strings, deferred_removals) - new_data[old_key] = new_elem - else: - for elem in old_data: - new_elem = _remove_values_conditions(elem, no_log_strings, deferred_removals) - if isinstance(new_data, MutableSequence): - new_data.append(new_elem) - elif isinstance(new_data, MutableSet): - new_data.add(new_elem) - else: - raise TypeError('Unknown container type encountered when removing private values from output') - - return new_value - - -def _sanitize_keys_conditions(value, no_log_strings, ignore_keys, deferred_removals): - """ Helper method to sanitize_keys() to build deferred_removals and avoid deep recursion. """ - if isinstance(value, (text_type, binary_type)): - return value - - if isinstance(value, Sequence): - if isinstance(value, MutableSequence): - new_value = type(value)() - else: - new_value = [] # Need a mutable value - deferred_removals.append((value, new_value)) - return new_value - - if isinstance(value, Set): - if isinstance(value, MutableSet): - new_value = type(value)() - else: - new_value = set() # Need a mutable value - deferred_removals.append((value, new_value)) - return new_value - - if isinstance(value, Mapping): - if isinstance(value, MutableMapping): - new_value = type(value)() - else: - new_value = {} # Need a mutable value - deferred_removals.append((value, new_value)) - return new_value - - if isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))): - return value - - if isinstance(value, (datetime.datetime, datetime.date)): - return value - - raise TypeError('Value of unknown type: %s, %s' % (type(value), value)) - - -def sanitize_keys(obj, no_log_strings, ignore_keys=frozenset()): - """ Sanitize the keys in a container object by removing no_log values from key names. - - This is a companion function to the `remove_values()` function. Similar to that function, - we make use of deferred_removals to avoid hitting maximum recursion depth in cases of - large data structures. - - :param obj: The container object to sanitize. Non-container objects are returned unmodified. - :param no_log_strings: A set of string values we do not want logged. - :param ignore_keys: A set of string values of keys to not sanitize. - - :returns: An object with sanitized keys. - """ - - deferred_removals = deque() - - no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings] - new_value = _sanitize_keys_conditions(obj, no_log_strings, ignore_keys, deferred_removals) - - while deferred_removals: - old_data, new_data = deferred_removals.popleft() - - if isinstance(new_data, Mapping): - for old_key, old_elem in old_data.items(): - if old_key in ignore_keys or old_key.startswith('_ansible'): - new_data[old_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals) - else: - # Sanitize the old key. We take advantage of the sanitizing code in - # _remove_values_conditions() rather than recreating it here. - new_key = _remove_values_conditions(old_key, no_log_strings, None) - new_data[new_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals) - else: - for elem in old_data: - new_elem = _sanitize_keys_conditions(elem, no_log_strings, ignore_keys, deferred_removals) - if isinstance(new_data, MutableSequence): - new_data.append(new_elem) - elif isinstance(new_data, MutableSet): - new_data.add(new_elem) - else: - raise TypeError('Unknown container type encountered when removing private values from keys') - - return new_value - - def heuristic_log_sanitize(data, no_log_values=None): ''' Remove strings that look like passwords from log messages ''' # Currently filters: @@ -661,10 +455,6 @@ def missing_required_lib(library, reason=None, url=None): return msg -class AnsibleFallbackNotFound(Exception): - pass - - class AnsibleModule(object): def __init__(self, argument_spec, bypass_checks=False, no_log=False, mutually_exclusive=None, required_together=None, @@ -1492,21 +1282,16 @@ class AnsibleModule(object): # this uses exceptions as it happens before we can safely call fail_json alias_warnings = [] - alias_results, self._legal_inputs = handle_aliases(spec, param, alias_warnings=alias_warnings) + alias_deprecations = [] + alias_results, self._legal_inputs = handle_aliases(spec, param, alias_warnings, alias_deprecations) for option, alias in alias_warnings: warn('Both option %s and its alias %s are set.' % (option_prefix + option, option_prefix + alias)) - deprecated_aliases = [] - for i in spec.keys(): - if 'deprecated_aliases' in spec[i].keys(): - for alias in spec[i]['deprecated_aliases']: - deprecated_aliases.append(alias) + for deprecation in alias_deprecations: + deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'], + version=deprecation.get('version'), date=deprecation.get('date'), + collection_name=deprecation.get('collection_name')) - for deprecation in deprecated_aliases: - if deprecation['name'] in param.keys(): - deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'], - version=deprecation.get('version'), date=deprecation.get('date'), - collection_name=deprecation.get('collection_name')) return alias_results def _handle_no_log_values(self, spec=None, param=None): @@ -1818,7 +1603,6 @@ class AnsibleModule(object): options_legal_inputs = list(spec.keys()) + list(options_aliases.keys()) - self._set_internal_properties(spec, param) self._check_arguments(spec, param, options_legal_inputs) # check exclusive early @@ -1854,28 +1638,6 @@ class AnsibleModule(object): return type_checker, wanted - def _handle_elements(self, wanted, param, values): - type_checker, wanted_name = self._get_wanted_type(wanted, param) - validated_params = [] - # Get param name for strings so we can later display this value in a useful error message if needed - # Only pass 'kwargs' to our checkers and ignore custom callable checkers - kwargs = {} - if wanted_name == 'str' and isinstance(wanted, string_types): - if isinstance(param, string_types): - kwargs['param'] = param - elif isinstance(param, dict): - kwargs['param'] = list(param.keys())[0] - for value in values: - try: - validated_params.append(type_checker(value, **kwargs)) - except (TypeError, ValueError) as e: - msg = "Elements value for option %s" % param - if self._options_context: - msg += " found in '%s'" % " -> ".join(self._options_context) - msg += " is of type %s and we were unable to convert to %s: %s" % (type(value), wanted_name, to_native(e)) - self.fail_json(msg=msg) - return validated_params - def _check_argument_types(self, spec=None, param=None, prefix=''): ''' ensure all arguments have the requested type ''' @@ -1884,61 +1646,22 @@ class AnsibleModule(object): if param is None: param = self.params - for (k, v) in spec.items(): - wanted = v.get('type', None) - if k not in param: - continue + errors = [] + validate_argument_types(spec, param, errors=errors) - value = param[k] - if value is None: - continue - - type_checker, wanted_name = self._get_wanted_type(wanted, k) - # Get param name for strings so we can later display this value in a useful error message if needed - # Only pass 'kwargs' to our checkers and ignore custom callable checkers - kwargs = {} - if wanted_name == 'str' and isinstance(type_checker, string_types): - kwargs['param'] = list(param.keys())[0] - - # Get the name of the parent key if this is a nested option - if prefix: - kwargs['prefix'] = prefix - - try: - param[k] = type_checker(value, **kwargs) - wanted_elements = v.get('elements', None) - if wanted_elements: - if wanted != 'list' or not isinstance(param[k], list): - msg = "Invalid type %s for option '%s'" % (wanted_name, param) - if self._options_context: - msg += " found in '%s'." % " -> ".join(self._options_context) - msg += ", elements value check is supported only with 'list' type" - self.fail_json(msg=msg) - param[k] = self._handle_elements(wanted_elements, k, param[k]) - - except (TypeError, ValueError) as e: - msg = "argument %s is of type %s" % (k, type(value)) - if self._options_context: - msg += " found in '%s'." % " -> ".join(self._options_context) - msg += " and we were unable to convert to %s: %s" % (wanted_name, to_native(e)) - self.fail_json(msg=msg) + if errors: + self.fail_json(msg=errors[0]) def _set_defaults(self, pre=True, spec=None, param=None): if spec is None: spec = self.argument_spec if param is None: param = self.params - for (k, v) in spec.items(): - default = v.get('default', None) - # This prevents setting defaults on required items on the 1st run, - # otherwise will set things without a default to None on the 2nd. - if k not in param and (default is not None or not pre): - # Make sure any default value for no_log fields are masked. - if v.get('no_log', False) and default: - self.no_log_values.add(default) - - param[k] = default + # The interface for set_defaults is different than _set_defaults() + # The third parameter controls whether or not defaults are actually set. + set_default = not pre + self.no_log_values.update(set_defaults(spec, param, set_default)) def _set_fallbacks(self, spec=None, param=None): if spec is None: @@ -1946,25 +1669,7 @@ class AnsibleModule(object): if param is None: param = self.params - for (k, v) in spec.items(): - fallback = v.get('fallback', (None,)) - fallback_strategy = fallback[0] - fallback_args = [] - fallback_kwargs = {} - if k not in param and fallback_strategy is not None: - for item in fallback[1:]: - if isinstance(item, dict): - fallback_kwargs = item - else: - fallback_args = item - try: - fallback_value = fallback_strategy(*fallback_args, **fallback_kwargs) - except AnsibleFallbackNotFound: - continue - else: - if v.get('no_log', False) and fallback_value: - self.no_log_values.add(fallback_value) - param[k] = fallback_value + self.no_log_values.update(set_fallbacks(spec, param)) def _load_params(self): ''' read the input and set the params attribute. diff --git a/lib/ansible/module_utils/common/arg_spec.py b/lib/ansible/module_utils/common/arg_spec.py new file mode 100644 index 00000000000..54bf80a5871 --- /dev/null +++ b/lib/ansible/module_utils/common/arg_spec.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2021 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +from copy import deepcopy + +from ansible.module_utils.common._collections_compat import ( + Sequence, +) + +from ansible.module_utils.common.parameters import ( + get_unsupported_parameters, + handle_aliases, + list_no_log_values, + remove_values, + set_defaults, + set_fallbacks, + validate_argument_types, + validate_argument_values, + validate_sub_spec, +) + +from ansible.module_utils.common.text.converters import to_native +from ansible.module_utils.common.warnings import deprecate, warn +from ansible.module_utils.common.validation import ( + check_required_arguments, +) + +from ansible.module_utils.six import string_types + + +class ArgumentSpecValidator(): + """Argument spec validation class""" + + def __init__(self, argument_spec, parameters): + self._error_messages = [] + self._no_log_values = set() + self.argument_spec = argument_spec + # Make a copy of the original parameters to avoid changing them + self._validated_parameters = deepcopy(parameters) + self._unsupported_parameters = set() + + @property + def error_messages(self): + return self._error_messages + + @property + def validated_parameters(self): + return self._validated_parameters + + def _add_error(self, error): + if isinstance(error, string_types): + self._error_messages.append(error) + elif isinstance(error, Sequence): + self._error_messages.extend(error) + else: + raise ValueError('Error messages must be a string or sequence not a %s' % type(error)) + + def _sanitize_error_messages(self): + self._error_messages = remove_values(self._error_messages, self._no_log_values) + + def validate(self, *args, **kwargs): + """Validate module parameters against argument spec. + + :Example: + + validator = ArgumentSpecValidator(argument_spec, parameters) + passeded = validator.validate() + + :param argument_spec: Specification of parameters, type, and valid values + :type argument_spec: dict + + :param parameters: Parameters provided to the role + :type parameters: dict + + :returns: True if no errors were encountered, False if any errors were encountered. + :rtype: bool + """ + + self._no_log_values.update(set_fallbacks(self.argument_spec, self._validated_parameters)) + + alias_warnings = [] + alias_deprecations = [] + try: + alias_results, legal_inputs = handle_aliases(self.argument_spec, self._validated_parameters, alias_warnings, alias_deprecations) + except (TypeError, ValueError) as e: + alias_results = {} + legal_inputs = None + self._add_error(to_native(e)) + + for option, alias in alias_warnings: + warn('Both option %s and its alias %s are set.' % (option, alias)) + + for deprecation in alias_deprecations: + deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'], + version=deprecation.get('version'), date=deprecation.get('date'), + collection_name=deprecation.get('collection_name')) + + self._no_log_values.update(list_no_log_values(self.argument_spec, self._validated_parameters)) + + if legal_inputs is None: + legal_inputs = list(alias_results.keys()) + list(self.argument_spec.keys()) + self._unsupported_parameters.update(get_unsupported_parameters(self.argument_spec, self._validated_parameters, legal_inputs)) + + self._no_log_values.update(set_defaults(self.argument_spec, self._validated_parameters, False)) + + try: + check_required_arguments(self.argument_spec, self._validated_parameters) + except TypeError as e: + self._add_error(to_native(e)) + + validate_argument_types(self.argument_spec, self._validated_parameters, errors=self._error_messages) + validate_argument_values(self.argument_spec, self._validated_parameters, errors=self._error_messages) + + self._no_log_values.update(set_defaults(self.argument_spec, self._validated_parameters)) + + validate_sub_spec(self.argument_spec, self._validated_parameters, + errors=self._error_messages, + no_log_values=self._no_log_values, + unsupported_parameters=self._unsupported_parameters) + + if self._unsupported_parameters: + self._add_error('Unsupported parameters: %s' % ', '.join(sorted(list(self._unsupported_parameters)))) + + self._sanitize_error_messages() + + if self.error_messages: + return False + else: + return True diff --git a/lib/ansible/module_utils/common/parameters.py b/lib/ansible/module_utils/common/parameters.py index 24a82acb076..4fa5dab84c2 100644 --- a/lib/ansible/module_utils/common/parameters.py +++ b/lib/ansible/module_utils/common/parameters.py @@ -5,18 +5,44 @@ from __future__ import absolute_import, division, print_function __metaclass__ = type -from ansible.module_utils._text import to_native -from ansible.module_utils.common._collections_compat import Mapping +import datetime +import os + +from collections import deque +from itertools import chain + from ansible.module_utils.common.collections import is_iterable +from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text +from ansible.module_utils.common.text.formatters import lenient_lowercase +from ansible.module_utils.common.warnings import warn +from ansible.module_utils.parsing.convert_bool import BOOLEANS_FALSE, BOOLEANS_TRUE + +from ansible.module_utils.common._collections_compat import ( + KeysView, + Set, + Sequence, + Mapping, + MutableMapping, + MutableSet, + MutableSequence, +) from ansible.module_utils.six import ( binary_type, integer_types, string_types, text_type, + PY2, + PY3, ) from ansible.module_utils.common.validation import ( + check_mutually_exclusive, + check_required_arguments, + check_required_together, + check_required_one_of, + check_required_if, + check_required_by, check_type_bits, check_type_bool, check_type_bytes, @@ -71,6 +97,10 @@ DEFAULT_TYPE_VALIDATORS = { } +class AnsibleFallbackNotFound(Exception): + pass + + def _return_datastructure_name(obj): """ Return native stringified values from datastructures. @@ -96,11 +126,211 @@ def _return_datastructure_name(obj): raise TypeError('Unknown parameter type: %s' % (type(obj))) +def _remove_values_conditions(value, no_log_strings, deferred_removals): + """ + Helper function for :meth:`remove_values`. + + :arg value: The value to check for strings that need to be stripped + :arg no_log_strings: set of strings which must be stripped out of any values + :arg deferred_removals: List which holds information about nested + containers that have to be iterated for removals. It is passed into + this function so that more entries can be added to it if value is + a container type. The format of each entry is a 2-tuple where the first + element is the ``value`` parameter and the second value is a new + container to copy the elements of ``value`` into once iterated. + :returns: if ``value`` is a scalar, returns ``value`` with two exceptions: + 1. :class:`~datetime.datetime` objects which are changed into a string representation. + 2. objects which are in no_log_strings are replaced with a placeholder + so that no sensitive data is leaked. + If ``value`` is a container type, returns a new empty container. + + ``deferred_removals`` is added to as a side-effect of this function. + + .. warning:: It is up to the caller to make sure the order in which value + is passed in is correct. For instance, higher level containers need + to be passed in before lower level containers. For example, given + ``{'level1': {'level2': 'level3': [True]} }`` first pass in the + dictionary for ``level1``, then the dict for ``level2``, and finally + the list for ``level3``. + """ + if isinstance(value, (text_type, binary_type)): + # Need native str type + native_str_value = value + if isinstance(value, text_type): + value_is_text = True + if PY2: + native_str_value = to_bytes(value, errors='surrogate_or_strict') + elif isinstance(value, binary_type): + value_is_text = False + if PY3: + native_str_value = to_text(value, errors='surrogate_or_strict') + + if native_str_value in no_log_strings: + return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' + for omit_me in no_log_strings: + native_str_value = native_str_value.replace(omit_me, '*' * 8) + + if value_is_text and isinstance(native_str_value, binary_type): + value = to_text(native_str_value, encoding='utf-8', errors='surrogate_then_replace') + elif not value_is_text and isinstance(native_str_value, text_type): + value = to_bytes(native_str_value, encoding='utf-8', errors='surrogate_then_replace') + else: + value = native_str_value + + elif isinstance(value, Sequence): + if isinstance(value, MutableSequence): + new_value = type(value)() + else: + new_value = [] # Need a mutable value + deferred_removals.append((value, new_value)) + value = new_value + + elif isinstance(value, Set): + if isinstance(value, MutableSet): + new_value = type(value)() + else: + new_value = set() # Need a mutable value + deferred_removals.append((value, new_value)) + value = new_value + + elif isinstance(value, Mapping): + if isinstance(value, MutableMapping): + new_value = type(value)() + else: + new_value = {} # Need a mutable value + deferred_removals.append((value, new_value)) + value = new_value + + elif isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))): + stringy_value = to_native(value, encoding='utf-8', errors='surrogate_or_strict') + if stringy_value in no_log_strings: + return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' + for omit_me in no_log_strings: + if omit_me in stringy_value: + return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' + + elif isinstance(value, (datetime.datetime, datetime.date)): + value = value.isoformat() + else: + raise TypeError('Value of unknown type: %s, %s' % (type(value), value)) + + return value + + +def _sanitize_keys_conditions(value, no_log_strings, ignore_keys, deferred_removals): + """ Helper method to sanitize_keys() to build deferred_removals and avoid deep recursion. """ + if isinstance(value, (text_type, binary_type)): + return value + + if isinstance(value, Sequence): + if isinstance(value, MutableSequence): + new_value = type(value)() + else: + new_value = [] # Need a mutable value + deferred_removals.append((value, new_value)) + return new_value + + if isinstance(value, Set): + if isinstance(value, MutableSet): + new_value = type(value)() + else: + new_value = set() # Need a mutable value + deferred_removals.append((value, new_value)) + return new_value + + if isinstance(value, Mapping): + if isinstance(value, MutableMapping): + new_value = type(value)() + else: + new_value = {} # Need a mutable value + deferred_removals.append((value, new_value)) + return new_value + + if isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))): + return value + + if isinstance(value, (datetime.datetime, datetime.date)): + return value + + raise TypeError('Value of unknown type: %s, %s' % (type(value), value)) + + +def env_fallback(*args, **kwargs): + """Load value from environment variable""" + + for arg in args: + if arg in os.environ: + return os.environ[arg] + raise AnsibleFallbackNotFound + + +def set_fallbacks(argument_spec, parameters): + no_log_values = set() + for param, value in argument_spec.items(): + fallback = value.get('fallback', (None,)) + fallback_strategy = fallback[0] + fallback_args = [] + fallback_kwargs = {} + if param not in parameters and fallback_strategy is not None: + for item in fallback[1:]: + if isinstance(item, dict): + fallback_kwargs = item + else: + fallback_args = item + try: + fallback_value = fallback_strategy(*fallback_args, **fallback_kwargs) + except AnsibleFallbackNotFound: + continue + else: + if value.get('no_log', False) and fallback_value: + no_log_values.add(fallback_value) + parameters[param] = fallback_value + + return no_log_values + + +def set_defaults(argument_spec, parameters, set_default=True): + """Set default values for parameters when no value is supplied. + + Modifies parameters directly. + + :param argument_spec: Argument spec + :type argument_spec: dict + + :param parameters: Parameters to evaluate + :type parameters: dict + + :param set_default: Whether or not to set the default values + :type set_default: bool + + :returns: Set of strings that should not be logged. + :rtype: set + """ + + no_log_values = set() + for param, value in argument_spec.items(): + + # TODO: Change the default value from None to Sentinel to differentiate between + # user supplied None and a default value set by this function. + default = value.get('default', None) + + # This prevents setting defaults on required items on the 1st run, + # otherwise will set things without a default to None on the 2nd. + if param not in parameters and (default is not None or set_default): + # Make sure any default value for no_log fields are masked. + if value.get('no_log', False) and default: + no_log_values.add(default) + + parameters[param] = default + + return no_log_values + + def list_no_log_values(argument_spec, params): """Return set of no log values :arg argument_spec: An argument spec dictionary from a module - :arg params: Dictionary of all module parameters + :arg params: Dictionary of all parameters :returns: Set of strings that should be hidden from output:: @@ -146,11 +376,11 @@ def list_no_log_values(argument_spec, params): return no_log_values -def list_deprecations(argument_spec, params, prefix=''): +def list_deprecations(argument_spec, parameters, prefix=''): """Return a list of deprecations :arg argument_spec: An argument spec dictionary from a module - :arg params: Dictionary of all module parameters + :arg parameters: Dictionary of parameters :returns: List of dictionaries containing a message and version in which the deprecated parameter will be removed, or an empty list:: @@ -160,7 +390,7 @@ def list_deprecations(argument_spec, params, prefix=''): deprecations = [] for arg_name, arg_opts in argument_spec.items(): - if arg_name in params: + if arg_name in parameters: if prefix: sub_prefix = '%s["%s"]' % (prefix, arg_name) else: @@ -180,7 +410,7 @@ def list_deprecations(argument_spec, params, prefix=''): # Check sub-argument spec sub_argument_spec = arg_opts.get('options') if sub_argument_spec is not None: - sub_arguments = params[arg_name] + sub_arguments = parameters[arg_name] if isinstance(sub_arguments, Mapping): sub_arguments = [sub_arguments] if isinstance(sub_arguments, list): @@ -191,12 +421,94 @@ def list_deprecations(argument_spec, params, prefix=''): return deprecations -def handle_aliases(argument_spec, params, alias_warnings=None): +def sanitize_keys(obj, no_log_strings, ignore_keys=frozenset()): + """ Sanitize the keys in a container object by removing no_log values from key names. + + This is a companion function to the `remove_values()` function. Similar to that function, + we make use of deferred_removals to avoid hitting maximum recursion depth in cases of + large data structures. + + :param obj: The container object to sanitize. Non-container objects are returned unmodified. + :param no_log_strings: A set of string values we do not want logged. + :param ignore_keys: A set of string values of keys to not sanitize. + + :returns: An object with sanitized keys. + """ + + deferred_removals = deque() + + no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings] + new_value = _sanitize_keys_conditions(obj, no_log_strings, ignore_keys, deferred_removals) + + while deferred_removals: + old_data, new_data = deferred_removals.popleft() + + if isinstance(new_data, Mapping): + for old_key, old_elem in old_data.items(): + if old_key in ignore_keys or old_key.startswith('_ansible'): + new_data[old_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals) + else: + # Sanitize the old key. We take advantage of the sanitizing code in + # _remove_values_conditions() rather than recreating it here. + new_key = _remove_values_conditions(old_key, no_log_strings, None) + new_data[new_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals) + else: + for elem in old_data: + new_elem = _sanitize_keys_conditions(elem, no_log_strings, ignore_keys, deferred_removals) + if isinstance(new_data, MutableSequence): + new_data.append(new_elem) + elif isinstance(new_data, MutableSet): + new_data.add(new_elem) + else: + raise TypeError('Unknown container type encountered when removing private values from keys') + + return new_value + + +def remove_values(value, no_log_strings): + """ Remove strings in no_log_strings from value. If value is a container + type, then remove a lot more. + + Use of deferred_removals exists, rather than a pure recursive solution, + because of the potential to hit the maximum recursion depth when dealing with + large amounts of data (see issue #24560). + """ + + deferred_removals = deque() + + no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings] + new_value = _remove_values_conditions(value, no_log_strings, deferred_removals) + + while deferred_removals: + old_data, new_data = deferred_removals.popleft() + if isinstance(new_data, Mapping): + for old_key, old_elem in old_data.items(): + new_elem = _remove_values_conditions(old_elem, no_log_strings, deferred_removals) + new_data[old_key] = new_elem + else: + for elem in old_data: + new_elem = _remove_values_conditions(elem, no_log_strings, deferred_removals) + if isinstance(new_data, MutableSequence): + new_data.append(new_elem) + elif isinstance(new_data, MutableSet): + new_data.add(new_elem) + else: + raise TypeError('Unknown container type encountered when removing private values from output') + + return new_value + + +def handle_aliases(argument_spec, parameters, alias_warnings=None, alias_deprecations=None): """Return a two item tuple. The first is a dictionary of aliases, the second is a list of legal inputs. + Modify supplied parameters by adding a new key for each alias. + If a list is provided to the alias_warnings parameter, it will be filled with tuples (option, alias) in every case where both an option and its alias are specified. + + If a list is provided to alias_deprecations, it will be populated with dictionaries, + each containing deprecation information for each alias found in argument_spec. """ legal_inputs = ['_ansible_%s' % k for k in PASS_VARS] @@ -207,31 +519,40 @@ def handle_aliases(argument_spec, params, alias_warnings=None): aliases = v.get('aliases', None) default = v.get('default', None) required = v.get('required', False) + + if alias_deprecations is not None: + for alias in argument_spec[k].get('deprecated_aliases', []): + if alias.get('name') in parameters: + alias_deprecations.append(alias) + if default is not None and required: # not alias specific but this is a good place to check this raise ValueError("internal error: required and default are mutually exclusive for %s" % k) + if aliases is None: continue + if not is_iterable(aliases) or isinstance(aliases, (binary_type, text_type)): raise TypeError('internal error: aliases must be a list or tuple') + for alias in aliases: legal_inputs.append(alias) aliases_results[alias] = k - if alias in params: - if k in params and alias_warnings is not None: + if alias in parameters: + if k in parameters and alias_warnings is not None: alias_warnings.append((k, alias)) - params[k] = params[alias] + parameters[k] = parameters[alias] return aliases_results, legal_inputs -def get_unsupported_parameters(argument_spec, module_parameters, legal_inputs=None): - """Check keys in module_parameters against those provided in legal_inputs +def get_unsupported_parameters(argument_spec, parameters, legal_inputs=None): + """Check keys in parameters against those provided in legal_inputs to ensure they contain legal values. If legal_inputs are not supplied, they will be generated using the argument_spec. :arg argument_spec: Dictionary of parameters, their type, and valid values. - :arg module_parameters: Dictionary of module parameters. + :arg parameters: Dictionary of parameters. :arg legal_inputs: List of valid key names property names. Overrides values in argument_spec. @@ -240,10 +561,10 @@ def get_unsupported_parameters(argument_spec, module_parameters, legal_inputs=No """ if legal_inputs is None: - aliases, legal_inputs = handle_aliases(argument_spec, module_parameters) + aliases, legal_inputs = handle_aliases(argument_spec, parameters) unsupported_parameters = set() - for k in module_parameters.keys(): + for k in parameters.keys(): if k not in legal_inputs: unsupported_parameters.add(k) @@ -275,3 +596,256 @@ def get_type_validator(wanted): wanted = getattr(wanted, '__name__', to_native(type(wanted))) return type_checker, wanted + + +def validate_elements(wanted_type, parameter, values, options_context=None, errors=None): + + if errors is None: + errors = [] + + type_checker, wanted_element_type = get_type_validator(wanted_type) + validated_parameters = [] + # Get param name for strings so we can later display this value in a useful error message if needed + # Only pass 'kwargs' to our checkers and ignore custom callable checkers + kwargs = {} + if wanted_element_type == 'str' and isinstance(wanted_type, string_types): + if isinstance(parameter, string_types): + kwargs['param'] = parameter + elif isinstance(parameter, dict): + kwargs['param'] = list(parameter.keys())[0] + + for value in values: + try: + validated_parameters.append(type_checker(value, **kwargs)) + except (TypeError, ValueError) as e: + msg = "Elements value for option '%s'" % parameter + if options_context: + msg += " found in '%s'" % " -> ".join(options_context) + msg += " is of type %s and we were unable to convert to %s: %s" % (type(value), wanted_element_type, to_native(e)) + errors.append(msg) + return validated_parameters + + +def validate_argument_types(argument_spec, parameters, prefix='', options_context=None, errors=None): + """Validate that parameter types match the type in the argument spec. + + Determine the appropriate type checker function and run each + parameter value through that function. All error messages from type checker + functions are returned. If any parameter fails to validate, it will not + be in the returned parameters. + + :param argument_spec: Argument spec + :type argument_spec: dict + + :param parameters: Parameters passed to module + :type parameters: dict + + :param prefix: Name of the parent key that contains the spec. Used in the error message + :type prefix: str + + :param options_context: List of contexts? + :type options_context: list + + :returns: Two item tuple containing validated and coerced parameters + and a list of any errors that were encountered. + :rtype: tuple + + """ + + if errors is None: + errors = [] + + for param, spec in argument_spec.items(): + if param not in parameters: + continue + + value = parameters[param] + if value is None: + continue + + wanted_type = spec.get('type') + type_checker, wanted_name = get_type_validator(wanted_type) + # Get param name for strings so we can later display this value in a useful error message if needed + # Only pass 'kwargs' to our checkers and ignore custom callable checkers + kwargs = {} + if wanted_name == 'str' and isinstance(wanted_type, string_types): + kwargs['param'] = list(parameters.keys())[0] + + # Get the name of the parent key if this is a nested option + if prefix: + kwargs['prefix'] = prefix + + try: + parameters[param] = type_checker(value, **kwargs) + elements_wanted_type = spec.get('elements', None) + if elements_wanted_type: + elements = parameters[param] + if wanted_type != 'list' or not isinstance(elements, list): + msg = "Invalid type %s for option '%s'" % (wanted_name, elements) + if options_context: + msg += " found in '%s'." % " -> ".join(options_context) + msg += ", elements value check is supported only with 'list' type" + errors.append(msg) + parameters[param] = validate_elements(elements_wanted_type, param, elements, options_context, errors) + + except (TypeError, ValueError) as e: + msg = "argument '%s' is of type %s" % (param, type(value)) + if options_context: + msg += " found in '%s'." % " -> ".join(options_context) + msg += " and we were unable to convert to %s: %s" % (wanted_name, to_native(e)) + errors.append(msg) + + +def validate_argument_values(argument_spec, parameters, options_context=None, errors=None): + """Ensure all arguments have the requested values, and there are no stray arguments""" + + if errors is None: + errors = [] + + for param, spec in argument_spec.items(): + choices = spec.get('choices') + if choices is None: + continue + + if isinstance(choices, (frozenset, KeysView, Sequence)) and not isinstance(choices, (binary_type, text_type)): + if param in parameters: + # Allow one or more when type='list' param with choices + if isinstance(parameters[param], list): + diff_list = ", ".join([item for item in parameters[param] if item not in choices]) + if diff_list: + choices_str = ", ".join([to_native(c) for c in choices]) + msg = "value of %s must be one or more of: %s. Got no match for: %s" % (param, choices_str, diff_list) + if options_context: + msg += " found in %s" % " -> ".join(options_context) + errors.append(msg) + elif parameters[param] not in choices: + # PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking + # the value. If we can't figure this out, module author is responsible. + lowered_choices = None + if parameters[param] == 'False': + lowered_choices = lenient_lowercase(choices) + overlap = BOOLEANS_FALSE.intersection(choices) + if len(overlap) == 1: + # Extract from a set + (parameters[param],) = overlap + + if parameters[param] == 'True': + if lowered_choices is None: + lowered_choices = lenient_lowercase(choices) + overlap = BOOLEANS_TRUE.intersection(choices) + if len(overlap) == 1: + (parameters[param],) = overlap + + if parameters[param] not in choices: + choices_str = ", ".join([to_native(c) for c in choices]) + msg = "value of %s must be one of: %s, got: %s" % (param, choices_str, parameters[param]) + if options_context: + msg += " found in %s" % " -> ".join(options_context) + errors.append(msg) + else: + msg = "internal error: choices for argument %s are not iterable: %s" % (param, choices) + if options_context: + msg += " found in %s" % " -> ".join(options_context) + errors.append(msg) + + +def validate_sub_spec(argument_spec, parameters, prefix='', options_context=None, errors=None, no_log_values=None, unsupported_parameters=None): + """Validate sub argument spec. This function is recursive.""" + + if options_context is None: + options_context = [] + + if errors is None: + errors = [] + + if no_log_values is None: + no_log_values = set() + + if unsupported_parameters is None: + unsupported_parameters = set() + + for param, value in argument_spec.items(): + wanted = value.get('type') + if wanted == 'dict' or (wanted == 'list' and value.get('elements', '') == dict): + sub_spec = value.get('options') + if value.get('apply_defaults', False): + if sub_spec is not None: + if parameters.get(value) is None: + parameters[param] = {} + else: + continue + elif sub_spec is None or param not in parameters or parameters[param] is None: + continue + + # Keep track of context for warning messages + options_context.append(param) + + # Make sure we can iterate over the elements + if isinstance(parameters[param], dict): + elements = [parameters[param]] + else: + elements = parameters[param] + + for idx, sub_parameters in enumerate(elements): + if not isinstance(sub_parameters, dict): + errors.append("value of '%s' must be of type dict or list of dicts" % param) + + # Set prefix for warning messages + new_prefix = prefix + param + if wanted == 'list': + new_prefix += '[%d]' % idx + new_prefix += '.' + + no_log_values.update(set_fallbacks(sub_spec, sub_parameters)) + + alias_warnings = [] + try: + options_aliases, legal_inputs = handle_aliases(sub_spec, sub_parameters, alias_warnings) + except (TypeError, ValueError) as e: + options_aliases = {} + legal_inputs = None + errors.append(to_native(e)) + + for option, alias in alias_warnings: + warn('Both option %s and its alias %s are set.' % (option, alias)) + + no_log_values.update(list_no_log_values(sub_spec, sub_parameters)) + + if legal_inputs is None: + legal_inputs = list(options_aliases.keys()) + list(sub_spec.keys()) + unsupported_parameters.update(get_unsupported_parameters(sub_spec, sub_parameters, legal_inputs)) + + try: + check_mutually_exclusive(value.get('mutually_exclusive'), sub_parameters) + except TypeError as e: + errors.append(to_native(e)) + + no_log_values.update(set_defaults(sub_spec, sub_parameters, False)) + + try: + check_required_arguments(sub_spec, sub_parameters) + except TypeError as e: + errors.append(to_native(e)) + + validate_argument_types(sub_spec, sub_parameters, new_prefix, options_context, errors=errors) + validate_argument_values(sub_spec, sub_parameters, options_context, errors=errors) + + checks = [ + (check_required_together, 'required_together'), + (check_required_one_of, 'required_one_of'), + (check_required_if, 'required_if'), + (check_required_by, 'required_by'), + ] + + for check in checks: + try: + check[0](value.get(check[1]), parameters) + except TypeError as e: + errors.append(to_native(e)) + + no_log_values.update(set_defaults(sub_spec, sub_parameters)) + + # Handle nested specs + validate_sub_spec(sub_spec, sub_parameters, new_prefix, options_context, errors, no_log_values, unsupported_parameters) + + options_context.pop() diff --git a/lib/ansible/module_utils/common/validation.py b/lib/ansible/module_utils/common/validation.py index a53a39458c0..df409059873 100644 --- a/lib/ansible/module_utils/common/validation.py +++ b/lib/ansible/module_utils/common/validation.py @@ -23,11 +23,11 @@ from ansible.module_utils.six import ( ) -def count_terms(terms, module_parameters): +def count_terms(terms, parameters): """Count the number of occurrences of a key in a given dictionary :arg terms: String or iterable of values to check - :arg module_parameters: Dictionary of module parameters + :arg parameters: Dictionary of parameters :returns: An integer that is the number of occurrences of the terms values in the provided dictionary. @@ -36,17 +36,17 @@ def count_terms(terms, module_parameters): if not is_iterable(terms): terms = [terms] - return len(set(terms).intersection(module_parameters)) + return len(set(terms).intersection(parameters)) -def check_mutually_exclusive(terms, module_parameters): +def check_mutually_exclusive(terms, parameters): """Check mutually exclusive terms against argument parameters Accepts a single list or list of lists that are groups of terms that should be mutually exclusive with one another - :arg terms: List of mutually exclusive module parameters - :arg module_parameters: Dictionary of module parameters + :arg terms: List of mutually exclusive parameters + :arg parameters: Dictionary of parameters :returns: Empty list or raises TypeError if the check fails. """ @@ -56,7 +56,7 @@ def check_mutually_exclusive(terms, module_parameters): return results for check in terms: - count = count_terms(check, module_parameters) + count = count_terms(check, parameters) if count > 1: results.append(check) @@ -68,7 +68,7 @@ def check_mutually_exclusive(terms, module_parameters): return results -def check_required_one_of(terms, module_parameters): +def check_required_one_of(terms, parameters): """Check each list of terms to ensure at least one exists in the given module parameters @@ -76,7 +76,7 @@ def check_required_one_of(terms, module_parameters): :arg terms: List of lists of terms to check. For each list of terms, at least one is required. - :arg module_parameters: Dictionary of module parameters + :arg parameters: Dictionary of parameters :returns: Empty list or raises TypeError if the check fails. """ @@ -86,7 +86,7 @@ def check_required_one_of(terms, module_parameters): return results for term in terms: - count = count_terms(term, module_parameters) + count = count_terms(term, parameters) if count == 0: results.append(term) @@ -98,16 +98,16 @@ def check_required_one_of(terms, module_parameters): return results -def check_required_together(terms, module_parameters): +def check_required_together(terms, parameters): """Check each list of terms to ensure every parameter in each list exists - in the given module parameters + in the given parameters Accepts a list of lists or tuples :arg terms: List of lists of terms to check. Each list should include parameters that are all required when at least one is specified - in the module_parameters. - :arg module_parameters: Dictionary of module parameters + in the parameters. + :arg parameters: Dictionary of parameters :returns: Empty list or raises TypeError if the check fails. """ @@ -117,7 +117,7 @@ def check_required_together(terms, module_parameters): return results for term in terms: - counts = [count_terms(field, module_parameters) for field in term] + counts = [count_terms(field, parameters) for field in term] non_zero = [c for c in counts if c > 0] if len(non_zero) > 0: if 0 in counts: @@ -130,14 +130,14 @@ def check_required_together(terms, module_parameters): return results -def check_required_by(requirements, module_parameters): +def check_required_by(requirements, parameters): """For each key in requirements, check the corresponding list to see if they - exist in module_parameters + exist in parameters Accepts a single string or list of values for each key :arg requirements: Dictionary of requirements - :arg module_parameters: Dictionary of module parameters + :arg parameters: Dictionary of parameters :returns: Empty dictionary or raises TypeError if the """ @@ -147,14 +147,14 @@ def check_required_by(requirements, module_parameters): return result for (key, value) in requirements.items(): - if key not in module_parameters or module_parameters[key] is None: + if key not in parameters or parameters[key] is None: continue result[key] = [] # Support strings (single-item lists) if isinstance(value, string_types): value = [value] for required in value: - if required not in module_parameters or module_parameters[required] is None: + if required not in parameters or parameters[required] is None: result[key].append(required) if result: @@ -166,15 +166,15 @@ def check_required_by(requirements, module_parameters): return result -def check_required_arguments(argument_spec, module_parameters): +def check_required_arguments(argument_spec, parameters): """Check all paramaters in argument_spec and return a list of parameters - that are required but not present in module_parameters + that are required but not present in parameters Raises TypeError if the check fails :arg argument_spec: Argument spec dicitionary containing all parameters and their specification - :arg module_paramaters: Dictionary of module parameters + :arg module_paramaters: Dictionary of parameters :returns: Empty list or raises TypeError if the check fails. """ @@ -185,7 +185,7 @@ def check_required_arguments(argument_spec, module_parameters): for (k, v) in argument_spec.items(): required = v.get('required', False) - if required and k not in module_parameters: + if required and k not in parameters: missing.append(k) if missing: @@ -195,7 +195,7 @@ def check_required_arguments(argument_spec, module_parameters): return missing -def check_required_if(requirements, module_parameters): +def check_required_if(requirements, parameters): """Check parameters that are conditionally required Raises TypeError if the check fails @@ -210,7 +210,7 @@ def check_required_if(requirements, module_parameters): ['someint', 99, ('bool_param', 'string_param')], ] - :arg module_paramaters: Dictionary of module parameters + :arg module_paramaters: Dictionary of parameters :returns: Empty list or raises TypeError if the check fails. The results attribute of the exception contains a list of dictionaries. @@ -257,9 +257,9 @@ def check_required_if(requirements, module_parameters): else: missing['requires'] = 'all' - if key in module_parameters and module_parameters[key] == val: + if key in parameters and parameters[key] == val: for check in requirements: - count = count_terms(check, module_parameters) + count = count_terms(check, parameters) if count == 0: missing['missing'].append(check) if len(missing['missing']) and len(missing['missing']) >= max_missing_count: @@ -277,13 +277,13 @@ def check_required_if(requirements, module_parameters): return results -def check_missing_parameters(module_parameters, required_parameters=None): +def check_missing_parameters(parameters, required_parameters=None): """This is for checking for required params when we can not check via argspec because we need more information than is simply given in the argspec. Raises TypeError if any required parameters are missing - :arg module_paramaters: Dictionary of module parameters + :arg module_paramaters: Dictionary of parameters :arg required_parameters: List of parameters to look for in the given module parameters @@ -294,7 +294,7 @@ def check_missing_parameters(module_parameters, required_parameters=None): return missing_params for param in required_parameters: - if not module_parameters.get(param): + if not parameters.get(param): missing_params.append(param) if missing_params: @@ -332,7 +332,10 @@ def safe_eval(value, locals=None, include_exceptions=False): return value -def check_type_str(value, allow_conversion=True): +# FIXME: The param and prefix parameters here are coming from AnsibleModule._check_type_string() +# which is using those for the warning messaged based on string conversion warning settings. +# Not sure how to deal with that here since we don't have config state to query. +def check_type_str(value, allow_conversion=True, param=None, prefix=''): """Verify that the value is a string or convert to a string. Since unexpected changes can sometimes happen when converting to a string, diff --git a/test/units/module_utils/common/arg_spec/__init__.py b/test/units/module_utils/common/arg_spec/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/units/module_utils/common/arg_spec/test_sub_spec.py b/test/units/module_utils/common/arg_spec/test_sub_spec.py new file mode 100644 index 00000000000..eaa775fdf53 --- /dev/null +++ b/test/units/module_utils/common/arg_spec/test_sub_spec.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2021 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from ansible.module_utils.common.arg_spec import ArgumentSpecValidator + + +def test_sub_spec(): + arg_spec = { + 'state': {}, + 'user': { + 'type': 'dict', + 'options': { + 'first': {'no_log': True}, + 'last': {}, + 'age': {'type': 'int'}, + } + } + } + + parameters = { + 'state': 'present', + 'user': { + 'first': 'Rey', + 'last': 'Skywalker', + 'age': '19', + } + } + + expected = { + 'state': 'present', + 'user': { + 'first': 'Rey', + 'last': 'Skywalker', + 'age': 19, + } + } + + v = ArgumentSpecValidator(arg_spec, parameters) + passed = v.validate() + + assert passed is True + assert v.error_messages == [] + assert v.validated_parameters == expected + + +def test_nested_sub_spec(): + arg_spec = { + 'type': {}, + 'car': { + 'type': 'dict', + 'options': { + 'make': {}, + 'model': {}, + 'customizations': { + 'type': 'dict', + 'options': { + 'engine': {}, + 'transmission': {}, + 'color': {}, + 'max_rpm': {'type': 'int'}, + } + } + } + } + } + + parameters = { + 'type': 'endurance', + 'car': { + 'make': 'Ford', + 'model': 'GT-40', + 'customizations': { + 'engine': '7.0 L', + 'transmission': '5-speed', + 'color': 'Ford blue', + 'max_rpm': '6000', + } + + } + } + + expected = { + 'type': 'endurance', + 'car': { + 'make': 'Ford', + 'model': 'GT-40', + 'customizations': { + 'engine': '7.0 L', + 'transmission': '5-speed', + 'color': 'Ford blue', + 'max_rpm': 6000, + } + + } + } + + v = ArgumentSpecValidator(arg_spec, parameters) + passed = v.validate() + + assert passed is True + assert v.error_messages == [] + assert v.validated_parameters == expected diff --git a/test/units/module_utils/common/arg_spec/test_validate_aliases.py b/test/units/module_utils/common/arg_spec/test_validate_aliases.py new file mode 100644 index 00000000000..a2a36cf4c39 --- /dev/null +++ b/test/units/module_utils/common/arg_spec/test_validate_aliases.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2021 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from ansible.module_utils.common.arg_spec import ArgumentSpecValidator +from ansible.module_utils.common.warnings import get_deprecation_messages + + +def test_spec_with_aliases(): + arg_spec = { + 'path': {'aliases': ['dir', 'directory']} + } + + parameters = { + 'dir': '/tmp', + 'directory': '/tmp', + } + + expected = { + 'dir': '/tmp', + 'directory': '/tmp', + 'path': '/tmp', + } + + v = ArgumentSpecValidator(arg_spec, parameters) + passed = v.validate() + + assert passed is True + assert v.validated_parameters == expected + + +def test_alias_deprecation(): + arg_spec = { + 'path': { + 'aliases': ['not_yo_path'], + 'deprecated_aliases': [{ + 'name': 'not_yo_path', + 'version': '1.7', + }] + } + } + + parameters = { + 'not_yo_path': '/tmp', + } + + expected = { + 'path': '/tmp', + 'not_yo_path': '/tmp', + } + + v = ArgumentSpecValidator(arg_spec, parameters) + passed = v.validate() + + assert passed is True + assert v.validated_parameters == expected + assert v.error_messages == [] + assert "Alias 'not_yo_path' is deprecated." in get_deprecation_messages()[0]['msg'] diff --git a/test/units/module_utils/common/arg_spec/test_validate_basic.py b/test/units/module_utils/common/arg_spec/test_validate_basic.py new file mode 100644 index 00000000000..344fb2b38c7 --- /dev/null +++ b/test/units/module_utils/common/arg_spec/test_validate_basic.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2021 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from ansible.module_utils.common.arg_spec import ArgumentSpecValidator + + +def test_basic_spec(): + arg_spec = { + 'param_str': {'type': 'str'}, + 'param_list': {'type': 'list'}, + 'param_dict': {'type': 'dict'}, + 'param_bool': {'type': 'bool'}, + 'param_int': {'type': 'int'}, + 'param_float': {'type': 'float'}, + 'param_path': {'type': 'path'}, + 'param_raw': {'type': 'raw'}, + 'param_bytes': {'type': 'bytes'}, + 'param_bits': {'type': 'bits'}, + + } + + parameters = { + 'param_str': 22, + 'param_list': 'one,two,three', + 'param_dict': 'first=star,last=lord', + 'param_bool': True, + 'param_int': 22, + 'param_float': 1.5, + 'param_path': '/tmp', + 'param_raw': 'raw', + 'param_bytes': '2K', + 'param_bits': '1Mb', + } + + expected = { + 'param_str': '22', + 'param_list': ['one', 'two', 'three'], + 'param_dict': {'first': 'star', 'last': 'lord'}, + 'param_bool': True, + 'param_float': 1.5, + 'param_int': 22, + 'param_path': '/tmp', + 'param_raw': 'raw', + 'param_bits': 1048576, + 'param_bytes': 2048, + } + + v = ArgumentSpecValidator(arg_spec, parameters) + passed = v.validate() + + assert passed is True + assert v.validated_parameters == expected + assert v.error_messages == [] + + +def test_spec_with_defaults(): + arg_spec = { + 'param_str': {'type': 'str', 'default': 'DEFAULT'}, + } + + parameters = {} + + expected = { + 'param_str': 'DEFAULT', + } + + v = ArgumentSpecValidator(arg_spec, parameters) + passed = v.validate() + + assert passed is True + assert v.validated_parameters == expected + assert v.error_messages == [] + + +def test_spec_with_elements(): + arg_spec = { + 'param_list': { + 'type': 'list', + 'elements': 'int', + } + } + + parameters = { + 'param_list': [55, 33, 34, '22'], + } + + expected = { + 'param_list': [55, 33, 34, 22], + } + + v = ArgumentSpecValidator(arg_spec, parameters) + passed = v.validate() + + assert passed is True + assert v.error_messages == [] + assert v.validated_parameters == expected diff --git a/test/units/module_utils/common/arg_spec/test_validate_failures.py b/test/units/module_utils/common/arg_spec/test_validate_failures.py new file mode 100644 index 00000000000..e0af0159e35 --- /dev/null +++ b/test/units/module_utils/common/arg_spec/test_validate_failures.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2021 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from ansible.module_utils.common.arg_spec import ArgumentSpecValidator + + +def test_required_and_default(): + arg_spec = { + 'param_req': {'required': True, 'default': 'DEFAULT'}, + } + + v = ArgumentSpecValidator(arg_spec, {}) + passed = v.validate() + + expected = { + 'param_req': 'DEFAULT' + } + + expected_errors = [ + 'internal error: required and default are mutually exclusive for param_req', + ] + + assert passed is False + assert v.validated_parameters == expected + assert v.error_messages == expected_errors