Add argument spec validator (#73335)

Add argument spec validator class
This commit is contained in:
Sam Doran 2021-02-11 19:17:14 -05:00 committed by GitHub
parent 2b227203a2
commit b6811dfb61
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 1083 additions and 367 deletions

View file

@ -0,0 +1,4 @@
major_changes:
- >-
add ``ArgumentSpecValidator`` class for validating parameters against an
argument spec outside of ``AnsibleModule`` (https://github.com/ansible/ansible/pull/73335)

View file

@ -55,7 +55,6 @@ import time
import traceback import traceback
import types import types
from collections import deque
from itertools import chain, repeat from itertools import chain, repeat
try: try:
@ -156,11 +155,20 @@ from ansible.module_utils.common.sys_info import (
) )
from ansible.module_utils.pycompat24 import get_exception, literal_eval from ansible.module_utils.pycompat24 import get_exception, literal_eval
from ansible.module_utils.common.parameters import ( from ansible.module_utils.common.parameters import (
_remove_values_conditions,
_sanitize_keys_conditions,
sanitize_keys,
env_fallback,
get_unsupported_parameters, get_unsupported_parameters,
get_type_validator, get_type_validator,
handle_aliases, handle_aliases,
list_deprecations, list_deprecations,
list_no_log_values, list_no_log_values,
remove_values,
set_defaults,
set_fallbacks,
validate_argument_types,
AnsibleFallbackNotFound,
DEFAULT_TYPE_VALIDATORS, DEFAULT_TYPE_VALIDATORS,
PASS_VARS, PASS_VARS,
PASS_BOOLS, PASS_BOOLS,
@ -241,14 +249,6 @@ _literal_eval = literal_eval
_ANSIBLE_ARGS = None _ANSIBLE_ARGS = None
def env_fallback(*args, **kwargs):
''' Load value from environment '''
for arg in args:
if arg in os.environ:
return os.environ[arg]
raise AnsibleFallbackNotFound
FILE_COMMON_ARGUMENTS = dict( FILE_COMMON_ARGUMENTS = dict(
# These are things we want. About setting metadata (mode, ownership, permissions in general) on # These are things we want. About setting metadata (mode, ownership, permissions in general) on
# created files (these are used by set_fs_attributes_if_different and included in # created files (these are used by set_fs_attributes_if_different and included in
@ -320,212 +320,6 @@ def get_all_subclasses(cls):
# End compat shims # End compat shims
def _remove_values_conditions(value, no_log_strings, deferred_removals):
"""
Helper function for :meth:`remove_values`.
:arg value: The value to check for strings that need to be stripped
:arg no_log_strings: set of strings which must be stripped out of any values
:arg deferred_removals: List which holds information about nested
containers that have to be iterated for removals. It is passed into
this function so that more entries can be added to it if value is
a container type. The format of each entry is a 2-tuple where the first
element is the ``value`` parameter and the second value is a new
container to copy the elements of ``value`` into once iterated.
:returns: if ``value`` is a scalar, returns ``value`` with two exceptions:
1. :class:`~datetime.datetime` objects which are changed into a string representation.
2. objects which are in no_log_strings are replaced with a placeholder
so that no sensitive data is leaked.
If ``value`` is a container type, returns a new empty container.
``deferred_removals`` is added to as a side-effect of this function.
.. warning:: It is up to the caller to make sure the order in which value
is passed in is correct. For instance, higher level containers need
to be passed in before lower level containers. For example, given
``{'level1': {'level2': 'level3': [True]} }`` first pass in the
dictionary for ``level1``, then the dict for ``level2``, and finally
the list for ``level3``.
"""
if isinstance(value, (text_type, binary_type)):
# Need native str type
native_str_value = value
if isinstance(value, text_type):
value_is_text = True
if PY2:
native_str_value = to_bytes(value, errors='surrogate_or_strict')
elif isinstance(value, binary_type):
value_is_text = False
if PY3:
native_str_value = to_text(value, errors='surrogate_or_strict')
if native_str_value in no_log_strings:
return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER'
for omit_me in no_log_strings:
native_str_value = native_str_value.replace(omit_me, '*' * 8)
if value_is_text and isinstance(native_str_value, binary_type):
value = to_text(native_str_value, encoding='utf-8', errors='surrogate_then_replace')
elif not value_is_text and isinstance(native_str_value, text_type):
value = to_bytes(native_str_value, encoding='utf-8', errors='surrogate_then_replace')
else:
value = native_str_value
elif isinstance(value, Sequence):
if isinstance(value, MutableSequence):
new_value = type(value)()
else:
new_value = [] # Need a mutable value
deferred_removals.append((value, new_value))
value = new_value
elif isinstance(value, Set):
if isinstance(value, MutableSet):
new_value = type(value)()
else:
new_value = set() # Need a mutable value
deferred_removals.append((value, new_value))
value = new_value
elif isinstance(value, Mapping):
if isinstance(value, MutableMapping):
new_value = type(value)()
else:
new_value = {} # Need a mutable value
deferred_removals.append((value, new_value))
value = new_value
elif isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))):
stringy_value = to_native(value, encoding='utf-8', errors='surrogate_or_strict')
if stringy_value in no_log_strings:
return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER'
for omit_me in no_log_strings:
if omit_me in stringy_value:
return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER'
elif isinstance(value, (datetime.datetime, datetime.date)):
value = value.isoformat()
else:
raise TypeError('Value of unknown type: %s, %s' % (type(value), value))
return value
def remove_values(value, no_log_strings):
""" Remove strings in no_log_strings from value. If value is a container
type, then remove a lot more.
Use of deferred_removals exists, rather than a pure recursive solution,
because of the potential to hit the maximum recursion depth when dealing with
large amounts of data (see issue #24560).
"""
deferred_removals = deque()
no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings]
new_value = _remove_values_conditions(value, no_log_strings, deferred_removals)
while deferred_removals:
old_data, new_data = deferred_removals.popleft()
if isinstance(new_data, Mapping):
for old_key, old_elem in old_data.items():
new_elem = _remove_values_conditions(old_elem, no_log_strings, deferred_removals)
new_data[old_key] = new_elem
else:
for elem in old_data:
new_elem = _remove_values_conditions(elem, no_log_strings, deferred_removals)
if isinstance(new_data, MutableSequence):
new_data.append(new_elem)
elif isinstance(new_data, MutableSet):
new_data.add(new_elem)
else:
raise TypeError('Unknown container type encountered when removing private values from output')
return new_value
def _sanitize_keys_conditions(value, no_log_strings, ignore_keys, deferred_removals):
""" Helper method to sanitize_keys() to build deferred_removals and avoid deep recursion. """
if isinstance(value, (text_type, binary_type)):
return value
if isinstance(value, Sequence):
if isinstance(value, MutableSequence):
new_value = type(value)()
else:
new_value = [] # Need a mutable value
deferred_removals.append((value, new_value))
return new_value
if isinstance(value, Set):
if isinstance(value, MutableSet):
new_value = type(value)()
else:
new_value = set() # Need a mutable value
deferred_removals.append((value, new_value))
return new_value
if isinstance(value, Mapping):
if isinstance(value, MutableMapping):
new_value = type(value)()
else:
new_value = {} # Need a mutable value
deferred_removals.append((value, new_value))
return new_value
if isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))):
return value
if isinstance(value, (datetime.datetime, datetime.date)):
return value
raise TypeError('Value of unknown type: %s, %s' % (type(value), value))
def sanitize_keys(obj, no_log_strings, ignore_keys=frozenset()):
""" Sanitize the keys in a container object by removing no_log values from key names.
This is a companion function to the `remove_values()` function. Similar to that function,
we make use of deferred_removals to avoid hitting maximum recursion depth in cases of
large data structures.
:param obj: The container object to sanitize. Non-container objects are returned unmodified.
:param no_log_strings: A set of string values we do not want logged.
:param ignore_keys: A set of string values of keys to not sanitize.
:returns: An object with sanitized keys.
"""
deferred_removals = deque()
no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings]
new_value = _sanitize_keys_conditions(obj, no_log_strings, ignore_keys, deferred_removals)
while deferred_removals:
old_data, new_data = deferred_removals.popleft()
if isinstance(new_data, Mapping):
for old_key, old_elem in old_data.items():
if old_key in ignore_keys or old_key.startswith('_ansible'):
new_data[old_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals)
else:
# Sanitize the old key. We take advantage of the sanitizing code in
# _remove_values_conditions() rather than recreating it here.
new_key = _remove_values_conditions(old_key, no_log_strings, None)
new_data[new_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals)
else:
for elem in old_data:
new_elem = _sanitize_keys_conditions(elem, no_log_strings, ignore_keys, deferred_removals)
if isinstance(new_data, MutableSequence):
new_data.append(new_elem)
elif isinstance(new_data, MutableSet):
new_data.add(new_elem)
else:
raise TypeError('Unknown container type encountered when removing private values from keys')
return new_value
def heuristic_log_sanitize(data, no_log_values=None): def heuristic_log_sanitize(data, no_log_values=None):
''' Remove strings that look like passwords from log messages ''' ''' Remove strings that look like passwords from log messages '''
# Currently filters: # Currently filters:
@ -661,10 +455,6 @@ def missing_required_lib(library, reason=None, url=None):
return msg return msg
class AnsibleFallbackNotFound(Exception):
pass
class AnsibleModule(object): class AnsibleModule(object):
def __init__(self, argument_spec, bypass_checks=False, no_log=False, def __init__(self, argument_spec, bypass_checks=False, no_log=False,
mutually_exclusive=None, required_together=None, mutually_exclusive=None, required_together=None,
@ -1492,21 +1282,16 @@ class AnsibleModule(object):
# this uses exceptions as it happens before we can safely call fail_json # this uses exceptions as it happens before we can safely call fail_json
alias_warnings = [] alias_warnings = []
alias_results, self._legal_inputs = handle_aliases(spec, param, alias_warnings=alias_warnings) alias_deprecations = []
alias_results, self._legal_inputs = handle_aliases(spec, param, alias_warnings, alias_deprecations)
for option, alias in alias_warnings: for option, alias in alias_warnings:
warn('Both option %s and its alias %s are set.' % (option_prefix + option, option_prefix + alias)) warn('Both option %s and its alias %s are set.' % (option_prefix + option, option_prefix + alias))
deprecated_aliases = [] for deprecation in alias_deprecations:
for i in spec.keys(): deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'],
if 'deprecated_aliases' in spec[i].keys(): version=deprecation.get('version'), date=deprecation.get('date'),
for alias in spec[i]['deprecated_aliases']: collection_name=deprecation.get('collection_name'))
deprecated_aliases.append(alias)
for deprecation in deprecated_aliases:
if deprecation['name'] in param.keys():
deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'],
version=deprecation.get('version'), date=deprecation.get('date'),
collection_name=deprecation.get('collection_name'))
return alias_results return alias_results
def _handle_no_log_values(self, spec=None, param=None): def _handle_no_log_values(self, spec=None, param=None):
@ -1818,7 +1603,6 @@ class AnsibleModule(object):
options_legal_inputs = list(spec.keys()) + list(options_aliases.keys()) options_legal_inputs = list(spec.keys()) + list(options_aliases.keys())
self._set_internal_properties(spec, param)
self._check_arguments(spec, param, options_legal_inputs) self._check_arguments(spec, param, options_legal_inputs)
# check exclusive early # check exclusive early
@ -1854,28 +1638,6 @@ class AnsibleModule(object):
return type_checker, wanted return type_checker, wanted
def _handle_elements(self, wanted, param, values):
type_checker, wanted_name = self._get_wanted_type(wanted, param)
validated_params = []
# Get param name for strings so we can later display this value in a useful error message if needed
# Only pass 'kwargs' to our checkers and ignore custom callable checkers
kwargs = {}
if wanted_name == 'str' and isinstance(wanted, string_types):
if isinstance(param, string_types):
kwargs['param'] = param
elif isinstance(param, dict):
kwargs['param'] = list(param.keys())[0]
for value in values:
try:
validated_params.append(type_checker(value, **kwargs))
except (TypeError, ValueError) as e:
msg = "Elements value for option %s" % param
if self._options_context:
msg += " found in '%s'" % " -> ".join(self._options_context)
msg += " is of type %s and we were unable to convert to %s: %s" % (type(value), wanted_name, to_native(e))
self.fail_json(msg=msg)
return validated_params
def _check_argument_types(self, spec=None, param=None, prefix=''): def _check_argument_types(self, spec=None, param=None, prefix=''):
''' ensure all arguments have the requested type ''' ''' ensure all arguments have the requested type '''
@ -1884,61 +1646,22 @@ class AnsibleModule(object):
if param is None: if param is None:
param = self.params param = self.params
for (k, v) in spec.items(): errors = []
wanted = v.get('type', None) validate_argument_types(spec, param, errors=errors)
if k not in param:
continue
value = param[k] if errors:
if value is None: self.fail_json(msg=errors[0])
continue
type_checker, wanted_name = self._get_wanted_type(wanted, k)
# Get param name for strings so we can later display this value in a useful error message if needed
# Only pass 'kwargs' to our checkers and ignore custom callable checkers
kwargs = {}
if wanted_name == 'str' and isinstance(type_checker, string_types):
kwargs['param'] = list(param.keys())[0]
# Get the name of the parent key if this is a nested option
if prefix:
kwargs['prefix'] = prefix
try:
param[k] = type_checker(value, **kwargs)
wanted_elements = v.get('elements', None)
if wanted_elements:
if wanted != 'list' or not isinstance(param[k], list):
msg = "Invalid type %s for option '%s'" % (wanted_name, param)
if self._options_context:
msg += " found in '%s'." % " -> ".join(self._options_context)
msg += ", elements value check is supported only with 'list' type"
self.fail_json(msg=msg)
param[k] = self._handle_elements(wanted_elements, k, param[k])
except (TypeError, ValueError) as e:
msg = "argument %s is of type %s" % (k, type(value))
if self._options_context:
msg += " found in '%s'." % " -> ".join(self._options_context)
msg += " and we were unable to convert to %s: %s" % (wanted_name, to_native(e))
self.fail_json(msg=msg)
def _set_defaults(self, pre=True, spec=None, param=None): def _set_defaults(self, pre=True, spec=None, param=None):
if spec is None: if spec is None:
spec = self.argument_spec spec = self.argument_spec
if param is None: if param is None:
param = self.params param = self.params
for (k, v) in spec.items():
default = v.get('default', None)
# This prevents setting defaults on required items on the 1st run, # The interface for set_defaults is different than _set_defaults()
# otherwise will set things without a default to None on the 2nd. # The third parameter controls whether or not defaults are actually set.
if k not in param and (default is not None or not pre): set_default = not pre
# Make sure any default value for no_log fields are masked. self.no_log_values.update(set_defaults(spec, param, set_default))
if v.get('no_log', False) and default:
self.no_log_values.add(default)
param[k] = default
def _set_fallbacks(self, spec=None, param=None): def _set_fallbacks(self, spec=None, param=None):
if spec is None: if spec is None:
@ -1946,25 +1669,7 @@ class AnsibleModule(object):
if param is None: if param is None:
param = self.params param = self.params
for (k, v) in spec.items(): self.no_log_values.update(set_fallbacks(spec, param))
fallback = v.get('fallback', (None,))
fallback_strategy = fallback[0]
fallback_args = []
fallback_kwargs = {}
if k not in param and fallback_strategy is not None:
for item in fallback[1:]:
if isinstance(item, dict):
fallback_kwargs = item
else:
fallback_args = item
try:
fallback_value = fallback_strategy(*fallback_args, **fallback_kwargs)
except AnsibleFallbackNotFound:
continue
else:
if v.get('no_log', False) and fallback_value:
self.no_log_values.add(fallback_value)
param[k] = fallback_value
def _load_params(self): def _load_params(self):
''' read the input and set the params attribute. ''' read the input and set the params attribute.

View file

@ -0,0 +1,134 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Ansible Project
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from copy import deepcopy
from ansible.module_utils.common._collections_compat import (
Sequence,
)
from ansible.module_utils.common.parameters import (
get_unsupported_parameters,
handle_aliases,
list_no_log_values,
remove_values,
set_defaults,
set_fallbacks,
validate_argument_types,
validate_argument_values,
validate_sub_spec,
)
from ansible.module_utils.common.text.converters import to_native
from ansible.module_utils.common.warnings import deprecate, warn
from ansible.module_utils.common.validation import (
check_required_arguments,
)
from ansible.module_utils.six import string_types
class ArgumentSpecValidator():
"""Argument spec validation class"""
def __init__(self, argument_spec, parameters):
self._error_messages = []
self._no_log_values = set()
self.argument_spec = argument_spec
# Make a copy of the original parameters to avoid changing them
self._validated_parameters = deepcopy(parameters)
self._unsupported_parameters = set()
@property
def error_messages(self):
return self._error_messages
@property
def validated_parameters(self):
return self._validated_parameters
def _add_error(self, error):
if isinstance(error, string_types):
self._error_messages.append(error)
elif isinstance(error, Sequence):
self._error_messages.extend(error)
else:
raise ValueError('Error messages must be a string or sequence not a %s' % type(error))
def _sanitize_error_messages(self):
self._error_messages = remove_values(self._error_messages, self._no_log_values)
def validate(self, *args, **kwargs):
"""Validate module parameters against argument spec.
:Example:
validator = ArgumentSpecValidator(argument_spec, parameters)
passeded = validator.validate()
:param argument_spec: Specification of parameters, type, and valid values
:type argument_spec: dict
:param parameters: Parameters provided to the role
:type parameters: dict
:returns: True if no errors were encountered, False if any errors were encountered.
:rtype: bool
"""
self._no_log_values.update(set_fallbacks(self.argument_spec, self._validated_parameters))
alias_warnings = []
alias_deprecations = []
try:
alias_results, legal_inputs = handle_aliases(self.argument_spec, self._validated_parameters, alias_warnings, alias_deprecations)
except (TypeError, ValueError) as e:
alias_results = {}
legal_inputs = None
self._add_error(to_native(e))
for option, alias in alias_warnings:
warn('Both option %s and its alias %s are set.' % (option, alias))
for deprecation in alias_deprecations:
deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'],
version=deprecation.get('version'), date=deprecation.get('date'),
collection_name=deprecation.get('collection_name'))
self._no_log_values.update(list_no_log_values(self.argument_spec, self._validated_parameters))
if legal_inputs is None:
legal_inputs = list(alias_results.keys()) + list(self.argument_spec.keys())
self._unsupported_parameters.update(get_unsupported_parameters(self.argument_spec, self._validated_parameters, legal_inputs))
self._no_log_values.update(set_defaults(self.argument_spec, self._validated_parameters, False))
try:
check_required_arguments(self.argument_spec, self._validated_parameters)
except TypeError as e:
self._add_error(to_native(e))
validate_argument_types(self.argument_spec, self._validated_parameters, errors=self._error_messages)
validate_argument_values(self.argument_spec, self._validated_parameters, errors=self._error_messages)
self._no_log_values.update(set_defaults(self.argument_spec, self._validated_parameters))
validate_sub_spec(self.argument_spec, self._validated_parameters,
errors=self._error_messages,
no_log_values=self._no_log_values,
unsupported_parameters=self._unsupported_parameters)
if self._unsupported_parameters:
self._add_error('Unsupported parameters: %s' % ', '.join(sorted(list(self._unsupported_parameters))))
self._sanitize_error_messages()
if self.error_messages:
return False
else:
return True

View file

@ -5,18 +5,44 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
__metaclass__ = type __metaclass__ = type
from ansible.module_utils._text import to_native import datetime
from ansible.module_utils.common._collections_compat import Mapping import os
from collections import deque
from itertools import chain
from ansible.module_utils.common.collections import is_iterable from ansible.module_utils.common.collections import is_iterable
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible.module_utils.common.text.formatters import lenient_lowercase
from ansible.module_utils.common.warnings import warn
from ansible.module_utils.parsing.convert_bool import BOOLEANS_FALSE, BOOLEANS_TRUE
from ansible.module_utils.common._collections_compat import (
KeysView,
Set,
Sequence,
Mapping,
MutableMapping,
MutableSet,
MutableSequence,
)
from ansible.module_utils.six import ( from ansible.module_utils.six import (
binary_type, binary_type,
integer_types, integer_types,
string_types, string_types,
text_type, text_type,
PY2,
PY3,
) )
from ansible.module_utils.common.validation import ( from ansible.module_utils.common.validation import (
check_mutually_exclusive,
check_required_arguments,
check_required_together,
check_required_one_of,
check_required_if,
check_required_by,
check_type_bits, check_type_bits,
check_type_bool, check_type_bool,
check_type_bytes, check_type_bytes,
@ -71,6 +97,10 @@ DEFAULT_TYPE_VALIDATORS = {
} }
class AnsibleFallbackNotFound(Exception):
pass
def _return_datastructure_name(obj): def _return_datastructure_name(obj):
""" Return native stringified values from datastructures. """ Return native stringified values from datastructures.
@ -96,11 +126,211 @@ def _return_datastructure_name(obj):
raise TypeError('Unknown parameter type: %s' % (type(obj))) raise TypeError('Unknown parameter type: %s' % (type(obj)))
def _remove_values_conditions(value, no_log_strings, deferred_removals):
"""
Helper function for :meth:`remove_values`.
:arg value: The value to check for strings that need to be stripped
:arg no_log_strings: set of strings which must be stripped out of any values
:arg deferred_removals: List which holds information about nested
containers that have to be iterated for removals. It is passed into
this function so that more entries can be added to it if value is
a container type. The format of each entry is a 2-tuple where the first
element is the ``value`` parameter and the second value is a new
container to copy the elements of ``value`` into once iterated.
:returns: if ``value`` is a scalar, returns ``value`` with two exceptions:
1. :class:`~datetime.datetime` objects which are changed into a string representation.
2. objects which are in no_log_strings are replaced with a placeholder
so that no sensitive data is leaked.
If ``value`` is a container type, returns a new empty container.
``deferred_removals`` is added to as a side-effect of this function.
.. warning:: It is up to the caller to make sure the order in which value
is passed in is correct. For instance, higher level containers need
to be passed in before lower level containers. For example, given
``{'level1': {'level2': 'level3': [True]} }`` first pass in the
dictionary for ``level1``, then the dict for ``level2``, and finally
the list for ``level3``.
"""
if isinstance(value, (text_type, binary_type)):
# Need native str type
native_str_value = value
if isinstance(value, text_type):
value_is_text = True
if PY2:
native_str_value = to_bytes(value, errors='surrogate_or_strict')
elif isinstance(value, binary_type):
value_is_text = False
if PY3:
native_str_value = to_text(value, errors='surrogate_or_strict')
if native_str_value in no_log_strings:
return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER'
for omit_me in no_log_strings:
native_str_value = native_str_value.replace(omit_me, '*' * 8)
if value_is_text and isinstance(native_str_value, binary_type):
value = to_text(native_str_value, encoding='utf-8', errors='surrogate_then_replace')
elif not value_is_text and isinstance(native_str_value, text_type):
value = to_bytes(native_str_value, encoding='utf-8', errors='surrogate_then_replace')
else:
value = native_str_value
elif isinstance(value, Sequence):
if isinstance(value, MutableSequence):
new_value = type(value)()
else:
new_value = [] # Need a mutable value
deferred_removals.append((value, new_value))
value = new_value
elif isinstance(value, Set):
if isinstance(value, MutableSet):
new_value = type(value)()
else:
new_value = set() # Need a mutable value
deferred_removals.append((value, new_value))
value = new_value
elif isinstance(value, Mapping):
if isinstance(value, MutableMapping):
new_value = type(value)()
else:
new_value = {} # Need a mutable value
deferred_removals.append((value, new_value))
value = new_value
elif isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))):
stringy_value = to_native(value, encoding='utf-8', errors='surrogate_or_strict')
if stringy_value in no_log_strings:
return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER'
for omit_me in no_log_strings:
if omit_me in stringy_value:
return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER'
elif isinstance(value, (datetime.datetime, datetime.date)):
value = value.isoformat()
else:
raise TypeError('Value of unknown type: %s, %s' % (type(value), value))
return value
def _sanitize_keys_conditions(value, no_log_strings, ignore_keys, deferred_removals):
""" Helper method to sanitize_keys() to build deferred_removals and avoid deep recursion. """
if isinstance(value, (text_type, binary_type)):
return value
if isinstance(value, Sequence):
if isinstance(value, MutableSequence):
new_value = type(value)()
else:
new_value = [] # Need a mutable value
deferred_removals.append((value, new_value))
return new_value
if isinstance(value, Set):
if isinstance(value, MutableSet):
new_value = type(value)()
else:
new_value = set() # Need a mutable value
deferred_removals.append((value, new_value))
return new_value
if isinstance(value, Mapping):
if isinstance(value, MutableMapping):
new_value = type(value)()
else:
new_value = {} # Need a mutable value
deferred_removals.append((value, new_value))
return new_value
if isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))):
return value
if isinstance(value, (datetime.datetime, datetime.date)):
return value
raise TypeError('Value of unknown type: %s, %s' % (type(value), value))
def env_fallback(*args, **kwargs):
"""Load value from environment variable"""
for arg in args:
if arg in os.environ:
return os.environ[arg]
raise AnsibleFallbackNotFound
def set_fallbacks(argument_spec, parameters):
no_log_values = set()
for param, value in argument_spec.items():
fallback = value.get('fallback', (None,))
fallback_strategy = fallback[0]
fallback_args = []
fallback_kwargs = {}
if param not in parameters and fallback_strategy is not None:
for item in fallback[1:]:
if isinstance(item, dict):
fallback_kwargs = item
else:
fallback_args = item
try:
fallback_value = fallback_strategy(*fallback_args, **fallback_kwargs)
except AnsibleFallbackNotFound:
continue
else:
if value.get('no_log', False) and fallback_value:
no_log_values.add(fallback_value)
parameters[param] = fallback_value
return no_log_values
def set_defaults(argument_spec, parameters, set_default=True):
"""Set default values for parameters when no value is supplied.
Modifies parameters directly.
:param argument_spec: Argument spec
:type argument_spec: dict
:param parameters: Parameters to evaluate
:type parameters: dict
:param set_default: Whether or not to set the default values
:type set_default: bool
:returns: Set of strings that should not be logged.
:rtype: set
"""
no_log_values = set()
for param, value in argument_spec.items():
# TODO: Change the default value from None to Sentinel to differentiate between
# user supplied None and a default value set by this function.
default = value.get('default', None)
# This prevents setting defaults on required items on the 1st run,
# otherwise will set things without a default to None on the 2nd.
if param not in parameters and (default is not None or set_default):
# Make sure any default value for no_log fields are masked.
if value.get('no_log', False) and default:
no_log_values.add(default)
parameters[param] = default
return no_log_values
def list_no_log_values(argument_spec, params): def list_no_log_values(argument_spec, params):
"""Return set of no log values """Return set of no log values
:arg argument_spec: An argument spec dictionary from a module :arg argument_spec: An argument spec dictionary from a module
:arg params: Dictionary of all module parameters :arg params: Dictionary of all parameters
:returns: Set of strings that should be hidden from output:: :returns: Set of strings that should be hidden from output::
@ -146,11 +376,11 @@ def list_no_log_values(argument_spec, params):
return no_log_values return no_log_values
def list_deprecations(argument_spec, params, prefix=''): def list_deprecations(argument_spec, parameters, prefix=''):
"""Return a list of deprecations """Return a list of deprecations
:arg argument_spec: An argument spec dictionary from a module :arg argument_spec: An argument spec dictionary from a module
:arg params: Dictionary of all module parameters :arg parameters: Dictionary of parameters
:returns: List of dictionaries containing a message and version in which :returns: List of dictionaries containing a message and version in which
the deprecated parameter will be removed, or an empty list:: the deprecated parameter will be removed, or an empty list::
@ -160,7 +390,7 @@ def list_deprecations(argument_spec, params, prefix=''):
deprecations = [] deprecations = []
for arg_name, arg_opts in argument_spec.items(): for arg_name, arg_opts in argument_spec.items():
if arg_name in params: if arg_name in parameters:
if prefix: if prefix:
sub_prefix = '%s["%s"]' % (prefix, arg_name) sub_prefix = '%s["%s"]' % (prefix, arg_name)
else: else:
@ -180,7 +410,7 @@ def list_deprecations(argument_spec, params, prefix=''):
# Check sub-argument spec # Check sub-argument spec
sub_argument_spec = arg_opts.get('options') sub_argument_spec = arg_opts.get('options')
if sub_argument_spec is not None: if sub_argument_spec is not None:
sub_arguments = params[arg_name] sub_arguments = parameters[arg_name]
if isinstance(sub_arguments, Mapping): if isinstance(sub_arguments, Mapping):
sub_arguments = [sub_arguments] sub_arguments = [sub_arguments]
if isinstance(sub_arguments, list): if isinstance(sub_arguments, list):
@ -191,12 +421,94 @@ def list_deprecations(argument_spec, params, prefix=''):
return deprecations return deprecations
def handle_aliases(argument_spec, params, alias_warnings=None): def sanitize_keys(obj, no_log_strings, ignore_keys=frozenset()):
""" Sanitize the keys in a container object by removing no_log values from key names.
This is a companion function to the `remove_values()` function. Similar to that function,
we make use of deferred_removals to avoid hitting maximum recursion depth in cases of
large data structures.
:param obj: The container object to sanitize. Non-container objects are returned unmodified.
:param no_log_strings: A set of string values we do not want logged.
:param ignore_keys: A set of string values of keys to not sanitize.
:returns: An object with sanitized keys.
"""
deferred_removals = deque()
no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings]
new_value = _sanitize_keys_conditions(obj, no_log_strings, ignore_keys, deferred_removals)
while deferred_removals:
old_data, new_data = deferred_removals.popleft()
if isinstance(new_data, Mapping):
for old_key, old_elem in old_data.items():
if old_key in ignore_keys or old_key.startswith('_ansible'):
new_data[old_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals)
else:
# Sanitize the old key. We take advantage of the sanitizing code in
# _remove_values_conditions() rather than recreating it here.
new_key = _remove_values_conditions(old_key, no_log_strings, None)
new_data[new_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals)
else:
for elem in old_data:
new_elem = _sanitize_keys_conditions(elem, no_log_strings, ignore_keys, deferred_removals)
if isinstance(new_data, MutableSequence):
new_data.append(new_elem)
elif isinstance(new_data, MutableSet):
new_data.add(new_elem)
else:
raise TypeError('Unknown container type encountered when removing private values from keys')
return new_value
def remove_values(value, no_log_strings):
""" Remove strings in no_log_strings from value. If value is a container
type, then remove a lot more.
Use of deferred_removals exists, rather than a pure recursive solution,
because of the potential to hit the maximum recursion depth when dealing with
large amounts of data (see issue #24560).
"""
deferred_removals = deque()
no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings]
new_value = _remove_values_conditions(value, no_log_strings, deferred_removals)
while deferred_removals:
old_data, new_data = deferred_removals.popleft()
if isinstance(new_data, Mapping):
for old_key, old_elem in old_data.items():
new_elem = _remove_values_conditions(old_elem, no_log_strings, deferred_removals)
new_data[old_key] = new_elem
else:
for elem in old_data:
new_elem = _remove_values_conditions(elem, no_log_strings, deferred_removals)
if isinstance(new_data, MutableSequence):
new_data.append(new_elem)
elif isinstance(new_data, MutableSet):
new_data.add(new_elem)
else:
raise TypeError('Unknown container type encountered when removing private values from output')
return new_value
def handle_aliases(argument_spec, parameters, alias_warnings=None, alias_deprecations=None):
"""Return a two item tuple. The first is a dictionary of aliases, the second is """Return a two item tuple. The first is a dictionary of aliases, the second is
a list of legal inputs. a list of legal inputs.
Modify supplied parameters by adding a new key for each alias.
If a list is provided to the alias_warnings parameter, it will be filled with tuples If a list is provided to the alias_warnings parameter, it will be filled with tuples
(option, alias) in every case where both an option and its alias are specified. (option, alias) in every case where both an option and its alias are specified.
If a list is provided to alias_deprecations, it will be populated with dictionaries,
each containing deprecation information for each alias found in argument_spec.
""" """
legal_inputs = ['_ansible_%s' % k for k in PASS_VARS] legal_inputs = ['_ansible_%s' % k for k in PASS_VARS]
@ -207,31 +519,40 @@ def handle_aliases(argument_spec, params, alias_warnings=None):
aliases = v.get('aliases', None) aliases = v.get('aliases', None)
default = v.get('default', None) default = v.get('default', None)
required = v.get('required', False) required = v.get('required', False)
if alias_deprecations is not None:
for alias in argument_spec[k].get('deprecated_aliases', []):
if alias.get('name') in parameters:
alias_deprecations.append(alias)
if default is not None and required: if default is not None and required:
# not alias specific but this is a good place to check this # not alias specific but this is a good place to check this
raise ValueError("internal error: required and default are mutually exclusive for %s" % k) raise ValueError("internal error: required and default are mutually exclusive for %s" % k)
if aliases is None: if aliases is None:
continue continue
if not is_iterable(aliases) or isinstance(aliases, (binary_type, text_type)): if not is_iterable(aliases) or isinstance(aliases, (binary_type, text_type)):
raise TypeError('internal error: aliases must be a list or tuple') raise TypeError('internal error: aliases must be a list or tuple')
for alias in aliases: for alias in aliases:
legal_inputs.append(alias) legal_inputs.append(alias)
aliases_results[alias] = k aliases_results[alias] = k
if alias in params: if alias in parameters:
if k in params and alias_warnings is not None: if k in parameters and alias_warnings is not None:
alias_warnings.append((k, alias)) alias_warnings.append((k, alias))
params[k] = params[alias] parameters[k] = parameters[alias]
return aliases_results, legal_inputs return aliases_results, legal_inputs
def get_unsupported_parameters(argument_spec, module_parameters, legal_inputs=None): def get_unsupported_parameters(argument_spec, parameters, legal_inputs=None):
"""Check keys in module_parameters against those provided in legal_inputs """Check keys in parameters against those provided in legal_inputs
to ensure they contain legal values. If legal_inputs are not supplied, to ensure they contain legal values. If legal_inputs are not supplied,
they will be generated using the argument_spec. they will be generated using the argument_spec.
:arg argument_spec: Dictionary of parameters, their type, and valid values. :arg argument_spec: Dictionary of parameters, their type, and valid values.
:arg module_parameters: Dictionary of module parameters. :arg parameters: Dictionary of parameters.
:arg legal_inputs: List of valid key names property names. Overrides values :arg legal_inputs: List of valid key names property names. Overrides values
in argument_spec. in argument_spec.
@ -240,10 +561,10 @@ def get_unsupported_parameters(argument_spec, module_parameters, legal_inputs=No
""" """
if legal_inputs is None: if legal_inputs is None:
aliases, legal_inputs = handle_aliases(argument_spec, module_parameters) aliases, legal_inputs = handle_aliases(argument_spec, parameters)
unsupported_parameters = set() unsupported_parameters = set()
for k in module_parameters.keys(): for k in parameters.keys():
if k not in legal_inputs: if k not in legal_inputs:
unsupported_parameters.add(k) unsupported_parameters.add(k)
@ -275,3 +596,256 @@ def get_type_validator(wanted):
wanted = getattr(wanted, '__name__', to_native(type(wanted))) wanted = getattr(wanted, '__name__', to_native(type(wanted)))
return type_checker, wanted return type_checker, wanted
def validate_elements(wanted_type, parameter, values, options_context=None, errors=None):
if errors is None:
errors = []
type_checker, wanted_element_type = get_type_validator(wanted_type)
validated_parameters = []
# Get param name for strings so we can later display this value in a useful error message if needed
# Only pass 'kwargs' to our checkers and ignore custom callable checkers
kwargs = {}
if wanted_element_type == 'str' and isinstance(wanted_type, string_types):
if isinstance(parameter, string_types):
kwargs['param'] = parameter
elif isinstance(parameter, dict):
kwargs['param'] = list(parameter.keys())[0]
for value in values:
try:
validated_parameters.append(type_checker(value, **kwargs))
except (TypeError, ValueError) as e:
msg = "Elements value for option '%s'" % parameter
if options_context:
msg += " found in '%s'" % " -> ".join(options_context)
msg += " is of type %s and we were unable to convert to %s: %s" % (type(value), wanted_element_type, to_native(e))
errors.append(msg)
return validated_parameters
def validate_argument_types(argument_spec, parameters, prefix='', options_context=None, errors=None):
"""Validate that parameter types match the type in the argument spec.
Determine the appropriate type checker function and run each
parameter value through that function. All error messages from type checker
functions are returned. If any parameter fails to validate, it will not
be in the returned parameters.
:param argument_spec: Argument spec
:type argument_spec: dict
:param parameters: Parameters passed to module
:type parameters: dict
:param prefix: Name of the parent key that contains the spec. Used in the error message
:type prefix: str
:param options_context: List of contexts?
:type options_context: list
:returns: Two item tuple containing validated and coerced parameters
and a list of any errors that were encountered.
:rtype: tuple
"""
if errors is None:
errors = []
for param, spec in argument_spec.items():
if param not in parameters:
continue
value = parameters[param]
if value is None:
continue
wanted_type = spec.get('type')
type_checker, wanted_name = get_type_validator(wanted_type)
# Get param name for strings so we can later display this value in a useful error message if needed
# Only pass 'kwargs' to our checkers and ignore custom callable checkers
kwargs = {}
if wanted_name == 'str' and isinstance(wanted_type, string_types):
kwargs['param'] = list(parameters.keys())[0]
# Get the name of the parent key if this is a nested option
if prefix:
kwargs['prefix'] = prefix
try:
parameters[param] = type_checker(value, **kwargs)
elements_wanted_type = spec.get('elements', None)
if elements_wanted_type:
elements = parameters[param]
if wanted_type != 'list' or not isinstance(elements, list):
msg = "Invalid type %s for option '%s'" % (wanted_name, elements)
if options_context:
msg += " found in '%s'." % " -> ".join(options_context)
msg += ", elements value check is supported only with 'list' type"
errors.append(msg)
parameters[param] = validate_elements(elements_wanted_type, param, elements, options_context, errors)
except (TypeError, ValueError) as e:
msg = "argument '%s' is of type %s" % (param, type(value))
if options_context:
msg += " found in '%s'." % " -> ".join(options_context)
msg += " and we were unable to convert to %s: %s" % (wanted_name, to_native(e))
errors.append(msg)
def validate_argument_values(argument_spec, parameters, options_context=None, errors=None):
"""Ensure all arguments have the requested values, and there are no stray arguments"""
if errors is None:
errors = []
for param, spec in argument_spec.items():
choices = spec.get('choices')
if choices is None:
continue
if isinstance(choices, (frozenset, KeysView, Sequence)) and not isinstance(choices, (binary_type, text_type)):
if param in parameters:
# Allow one or more when type='list' param with choices
if isinstance(parameters[param], list):
diff_list = ", ".join([item for item in parameters[param] if item not in choices])
if diff_list:
choices_str = ", ".join([to_native(c) for c in choices])
msg = "value of %s must be one or more of: %s. Got no match for: %s" % (param, choices_str, diff_list)
if options_context:
msg += " found in %s" % " -> ".join(options_context)
errors.append(msg)
elif parameters[param] not in choices:
# PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking
# the value. If we can't figure this out, module author is responsible.
lowered_choices = None
if parameters[param] == 'False':
lowered_choices = lenient_lowercase(choices)
overlap = BOOLEANS_FALSE.intersection(choices)
if len(overlap) == 1:
# Extract from a set
(parameters[param],) = overlap
if parameters[param] == 'True':
if lowered_choices is None:
lowered_choices = lenient_lowercase(choices)
overlap = BOOLEANS_TRUE.intersection(choices)
if len(overlap) == 1:
(parameters[param],) = overlap
if parameters[param] not in choices:
choices_str = ", ".join([to_native(c) for c in choices])
msg = "value of %s must be one of: %s, got: %s" % (param, choices_str, parameters[param])
if options_context:
msg += " found in %s" % " -> ".join(options_context)
errors.append(msg)
else:
msg = "internal error: choices for argument %s are not iterable: %s" % (param, choices)
if options_context:
msg += " found in %s" % " -> ".join(options_context)
errors.append(msg)
def validate_sub_spec(argument_spec, parameters, prefix='', options_context=None, errors=None, no_log_values=None, unsupported_parameters=None):
"""Validate sub argument spec. This function is recursive."""
if options_context is None:
options_context = []
if errors is None:
errors = []
if no_log_values is None:
no_log_values = set()
if unsupported_parameters is None:
unsupported_parameters = set()
for param, value in argument_spec.items():
wanted = value.get('type')
if wanted == 'dict' or (wanted == 'list' and value.get('elements', '') == dict):
sub_spec = value.get('options')
if value.get('apply_defaults', False):
if sub_spec is not None:
if parameters.get(value) is None:
parameters[param] = {}
else:
continue
elif sub_spec is None or param not in parameters or parameters[param] is None:
continue
# Keep track of context for warning messages
options_context.append(param)
# Make sure we can iterate over the elements
if isinstance(parameters[param], dict):
elements = [parameters[param]]
else:
elements = parameters[param]
for idx, sub_parameters in enumerate(elements):
if not isinstance(sub_parameters, dict):
errors.append("value of '%s' must be of type dict or list of dicts" % param)
# Set prefix for warning messages
new_prefix = prefix + param
if wanted == 'list':
new_prefix += '[%d]' % idx
new_prefix += '.'
no_log_values.update(set_fallbacks(sub_spec, sub_parameters))
alias_warnings = []
try:
options_aliases, legal_inputs = handle_aliases(sub_spec, sub_parameters, alias_warnings)
except (TypeError, ValueError) as e:
options_aliases = {}
legal_inputs = None
errors.append(to_native(e))
for option, alias in alias_warnings:
warn('Both option %s and its alias %s are set.' % (option, alias))
no_log_values.update(list_no_log_values(sub_spec, sub_parameters))
if legal_inputs is None:
legal_inputs = list(options_aliases.keys()) + list(sub_spec.keys())
unsupported_parameters.update(get_unsupported_parameters(sub_spec, sub_parameters, legal_inputs))
try:
check_mutually_exclusive(value.get('mutually_exclusive'), sub_parameters)
except TypeError as e:
errors.append(to_native(e))
no_log_values.update(set_defaults(sub_spec, sub_parameters, False))
try:
check_required_arguments(sub_spec, sub_parameters)
except TypeError as e:
errors.append(to_native(e))
validate_argument_types(sub_spec, sub_parameters, new_prefix, options_context, errors=errors)
validate_argument_values(sub_spec, sub_parameters, options_context, errors=errors)
checks = [
(check_required_together, 'required_together'),
(check_required_one_of, 'required_one_of'),
(check_required_if, 'required_if'),
(check_required_by, 'required_by'),
]
for check in checks:
try:
check[0](value.get(check[1]), parameters)
except TypeError as e:
errors.append(to_native(e))
no_log_values.update(set_defaults(sub_spec, sub_parameters))
# Handle nested specs
validate_sub_spec(sub_spec, sub_parameters, new_prefix, options_context, errors, no_log_values, unsupported_parameters)
options_context.pop()

View file

@ -23,11 +23,11 @@ from ansible.module_utils.six import (
) )
def count_terms(terms, module_parameters): def count_terms(terms, parameters):
"""Count the number of occurrences of a key in a given dictionary """Count the number of occurrences of a key in a given dictionary
:arg terms: String or iterable of values to check :arg terms: String or iterable of values to check
:arg module_parameters: Dictionary of module parameters :arg parameters: Dictionary of parameters
:returns: An integer that is the number of occurrences of the terms values :returns: An integer that is the number of occurrences of the terms values
in the provided dictionary. in the provided dictionary.
@ -36,17 +36,17 @@ def count_terms(terms, module_parameters):
if not is_iterable(terms): if not is_iterable(terms):
terms = [terms] terms = [terms]
return len(set(terms).intersection(module_parameters)) return len(set(terms).intersection(parameters))
def check_mutually_exclusive(terms, module_parameters): def check_mutually_exclusive(terms, parameters):
"""Check mutually exclusive terms against argument parameters """Check mutually exclusive terms against argument parameters
Accepts a single list or list of lists that are groups of terms that should be Accepts a single list or list of lists that are groups of terms that should be
mutually exclusive with one another mutually exclusive with one another
:arg terms: List of mutually exclusive module parameters :arg terms: List of mutually exclusive parameters
:arg module_parameters: Dictionary of module parameters :arg parameters: Dictionary of parameters
:returns: Empty list or raises TypeError if the check fails. :returns: Empty list or raises TypeError if the check fails.
""" """
@ -56,7 +56,7 @@ def check_mutually_exclusive(terms, module_parameters):
return results return results
for check in terms: for check in terms:
count = count_terms(check, module_parameters) count = count_terms(check, parameters)
if count > 1: if count > 1:
results.append(check) results.append(check)
@ -68,7 +68,7 @@ def check_mutually_exclusive(terms, module_parameters):
return results return results
def check_required_one_of(terms, module_parameters): def check_required_one_of(terms, parameters):
"""Check each list of terms to ensure at least one exists in the given module """Check each list of terms to ensure at least one exists in the given module
parameters parameters
@ -76,7 +76,7 @@ def check_required_one_of(terms, module_parameters):
:arg terms: List of lists of terms to check. For each list of terms, at :arg terms: List of lists of terms to check. For each list of terms, at
least one is required. least one is required.
:arg module_parameters: Dictionary of module parameters :arg parameters: Dictionary of parameters
:returns: Empty list or raises TypeError if the check fails. :returns: Empty list or raises TypeError if the check fails.
""" """
@ -86,7 +86,7 @@ def check_required_one_of(terms, module_parameters):
return results return results
for term in terms: for term in terms:
count = count_terms(term, module_parameters) count = count_terms(term, parameters)
if count == 0: if count == 0:
results.append(term) results.append(term)
@ -98,16 +98,16 @@ def check_required_one_of(terms, module_parameters):
return results return results
def check_required_together(terms, module_parameters): def check_required_together(terms, parameters):
"""Check each list of terms to ensure every parameter in each list exists """Check each list of terms to ensure every parameter in each list exists
in the given module parameters in the given parameters
Accepts a list of lists or tuples Accepts a list of lists or tuples
:arg terms: List of lists of terms to check. Each list should include :arg terms: List of lists of terms to check. Each list should include
parameters that are all required when at least one is specified parameters that are all required when at least one is specified
in the module_parameters. in the parameters.
:arg module_parameters: Dictionary of module parameters :arg parameters: Dictionary of parameters
:returns: Empty list or raises TypeError if the check fails. :returns: Empty list or raises TypeError if the check fails.
""" """
@ -117,7 +117,7 @@ def check_required_together(terms, module_parameters):
return results return results
for term in terms: for term in terms:
counts = [count_terms(field, module_parameters) for field in term] counts = [count_terms(field, parameters) for field in term]
non_zero = [c for c in counts if c > 0] non_zero = [c for c in counts if c > 0]
if len(non_zero) > 0: if len(non_zero) > 0:
if 0 in counts: if 0 in counts:
@ -130,14 +130,14 @@ def check_required_together(terms, module_parameters):
return results return results
def check_required_by(requirements, module_parameters): def check_required_by(requirements, parameters):
"""For each key in requirements, check the corresponding list to see if they """For each key in requirements, check the corresponding list to see if they
exist in module_parameters exist in parameters
Accepts a single string or list of values for each key Accepts a single string or list of values for each key
:arg requirements: Dictionary of requirements :arg requirements: Dictionary of requirements
:arg module_parameters: Dictionary of module parameters :arg parameters: Dictionary of parameters
:returns: Empty dictionary or raises TypeError if the :returns: Empty dictionary or raises TypeError if the
""" """
@ -147,14 +147,14 @@ def check_required_by(requirements, module_parameters):
return result return result
for (key, value) in requirements.items(): for (key, value) in requirements.items():
if key not in module_parameters or module_parameters[key] is None: if key not in parameters or parameters[key] is None:
continue continue
result[key] = [] result[key] = []
# Support strings (single-item lists) # Support strings (single-item lists)
if isinstance(value, string_types): if isinstance(value, string_types):
value = [value] value = [value]
for required in value: for required in value:
if required not in module_parameters or module_parameters[required] is None: if required not in parameters or parameters[required] is None:
result[key].append(required) result[key].append(required)
if result: if result:
@ -166,15 +166,15 @@ def check_required_by(requirements, module_parameters):
return result return result
def check_required_arguments(argument_spec, module_parameters): def check_required_arguments(argument_spec, parameters):
"""Check all paramaters in argument_spec and return a list of parameters """Check all paramaters in argument_spec and return a list of parameters
that are required but not present in module_parameters that are required but not present in parameters
Raises TypeError if the check fails Raises TypeError if the check fails
:arg argument_spec: Argument spec dicitionary containing all parameters :arg argument_spec: Argument spec dicitionary containing all parameters
and their specification and their specification
:arg module_paramaters: Dictionary of module parameters :arg module_paramaters: Dictionary of parameters
:returns: Empty list or raises TypeError if the check fails. :returns: Empty list or raises TypeError if the check fails.
""" """
@ -185,7 +185,7 @@ def check_required_arguments(argument_spec, module_parameters):
for (k, v) in argument_spec.items(): for (k, v) in argument_spec.items():
required = v.get('required', False) required = v.get('required', False)
if required and k not in module_parameters: if required and k not in parameters:
missing.append(k) missing.append(k)
if missing: if missing:
@ -195,7 +195,7 @@ def check_required_arguments(argument_spec, module_parameters):
return missing return missing
def check_required_if(requirements, module_parameters): def check_required_if(requirements, parameters):
"""Check parameters that are conditionally required """Check parameters that are conditionally required
Raises TypeError if the check fails Raises TypeError if the check fails
@ -210,7 +210,7 @@ def check_required_if(requirements, module_parameters):
['someint', 99, ('bool_param', 'string_param')], ['someint', 99, ('bool_param', 'string_param')],
] ]
:arg module_paramaters: Dictionary of module parameters :arg module_paramaters: Dictionary of parameters
:returns: Empty list or raises TypeError if the check fails. :returns: Empty list or raises TypeError if the check fails.
The results attribute of the exception contains a list of dictionaries. The results attribute of the exception contains a list of dictionaries.
@ -257,9 +257,9 @@ def check_required_if(requirements, module_parameters):
else: else:
missing['requires'] = 'all' missing['requires'] = 'all'
if key in module_parameters and module_parameters[key] == val: if key in parameters and parameters[key] == val:
for check in requirements: for check in requirements:
count = count_terms(check, module_parameters) count = count_terms(check, parameters)
if count == 0: if count == 0:
missing['missing'].append(check) missing['missing'].append(check)
if len(missing['missing']) and len(missing['missing']) >= max_missing_count: if len(missing['missing']) and len(missing['missing']) >= max_missing_count:
@ -277,13 +277,13 @@ def check_required_if(requirements, module_parameters):
return results return results
def check_missing_parameters(module_parameters, required_parameters=None): def check_missing_parameters(parameters, required_parameters=None):
"""This is for checking for required params when we can not check via """This is for checking for required params when we can not check via
argspec because we need more information than is simply given in the argspec. argspec because we need more information than is simply given in the argspec.
Raises TypeError if any required parameters are missing Raises TypeError if any required parameters are missing
:arg module_paramaters: Dictionary of module parameters :arg module_paramaters: Dictionary of parameters
:arg required_parameters: List of parameters to look for in the given module :arg required_parameters: List of parameters to look for in the given module
parameters parameters
@ -294,7 +294,7 @@ def check_missing_parameters(module_parameters, required_parameters=None):
return missing_params return missing_params
for param in required_parameters: for param in required_parameters:
if not module_parameters.get(param): if not parameters.get(param):
missing_params.append(param) missing_params.append(param)
if missing_params: if missing_params:
@ -332,7 +332,10 @@ def safe_eval(value, locals=None, include_exceptions=False):
return value return value
def check_type_str(value, allow_conversion=True): # FIXME: The param and prefix parameters here are coming from AnsibleModule._check_type_string()
# which is using those for the warning messaged based on string conversion warning settings.
# Not sure how to deal with that here since we don't have config state to query.
def check_type_str(value, allow_conversion=True, param=None, prefix=''):
"""Verify that the value is a string or convert to a string. """Verify that the value is a string or convert to a string.
Since unexpected changes can sometimes happen when converting to a string, Since unexpected changes can sometimes happen when converting to a string,

View file

@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
def test_sub_spec():
arg_spec = {
'state': {},
'user': {
'type': 'dict',
'options': {
'first': {'no_log': True},
'last': {},
'age': {'type': 'int'},
}
}
}
parameters = {
'state': 'present',
'user': {
'first': 'Rey',
'last': 'Skywalker',
'age': '19',
}
}
expected = {
'state': 'present',
'user': {
'first': 'Rey',
'last': 'Skywalker',
'age': 19,
}
}
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
assert passed is True
assert v.error_messages == []
assert v.validated_parameters == expected
def test_nested_sub_spec():
arg_spec = {
'type': {},
'car': {
'type': 'dict',
'options': {
'make': {},
'model': {},
'customizations': {
'type': 'dict',
'options': {
'engine': {},
'transmission': {},
'color': {},
'max_rpm': {'type': 'int'},
}
}
}
}
}
parameters = {
'type': 'endurance',
'car': {
'make': 'Ford',
'model': 'GT-40',
'customizations': {
'engine': '7.0 L',
'transmission': '5-speed',
'color': 'Ford blue',
'max_rpm': '6000',
}
}
}
expected = {
'type': 'endurance',
'car': {
'make': 'Ford',
'model': 'GT-40',
'customizations': {
'engine': '7.0 L',
'transmission': '5-speed',
'color': 'Ford blue',
'max_rpm': 6000,
}
}
}
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
assert passed is True
assert v.error_messages == []
assert v.validated_parameters == expected

View file

@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
from ansible.module_utils.common.warnings import get_deprecation_messages
def test_spec_with_aliases():
arg_spec = {
'path': {'aliases': ['dir', 'directory']}
}
parameters = {
'dir': '/tmp',
'directory': '/tmp',
}
expected = {
'dir': '/tmp',
'directory': '/tmp',
'path': '/tmp',
}
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
assert passed is True
assert v.validated_parameters == expected
def test_alias_deprecation():
arg_spec = {
'path': {
'aliases': ['not_yo_path'],
'deprecated_aliases': [{
'name': 'not_yo_path',
'version': '1.7',
}]
}
}
parameters = {
'not_yo_path': '/tmp',
}
expected = {
'path': '/tmp',
'not_yo_path': '/tmp',
}
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
assert passed is True
assert v.validated_parameters == expected
assert v.error_messages == []
assert "Alias 'not_yo_path' is deprecated." in get_deprecation_messages()[0]['msg']

View file

@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
def test_basic_spec():
arg_spec = {
'param_str': {'type': 'str'},
'param_list': {'type': 'list'},
'param_dict': {'type': 'dict'},
'param_bool': {'type': 'bool'},
'param_int': {'type': 'int'},
'param_float': {'type': 'float'},
'param_path': {'type': 'path'},
'param_raw': {'type': 'raw'},
'param_bytes': {'type': 'bytes'},
'param_bits': {'type': 'bits'},
}
parameters = {
'param_str': 22,
'param_list': 'one,two,three',
'param_dict': 'first=star,last=lord',
'param_bool': True,
'param_int': 22,
'param_float': 1.5,
'param_path': '/tmp',
'param_raw': 'raw',
'param_bytes': '2K',
'param_bits': '1Mb',
}
expected = {
'param_str': '22',
'param_list': ['one', 'two', 'three'],
'param_dict': {'first': 'star', 'last': 'lord'},
'param_bool': True,
'param_float': 1.5,
'param_int': 22,
'param_path': '/tmp',
'param_raw': 'raw',
'param_bits': 1048576,
'param_bytes': 2048,
}
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
assert passed is True
assert v.validated_parameters == expected
assert v.error_messages == []
def test_spec_with_defaults():
arg_spec = {
'param_str': {'type': 'str', 'default': 'DEFAULT'},
}
parameters = {}
expected = {
'param_str': 'DEFAULT',
}
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
assert passed is True
assert v.validated_parameters == expected
assert v.error_messages == []
def test_spec_with_elements():
arg_spec = {
'param_list': {
'type': 'list',
'elements': 'int',
}
}
parameters = {
'param_list': [55, 33, 34, '22'],
}
expected = {
'param_list': [55, 33, 34, 22],
}
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
assert passed is True
assert v.error_messages == []
assert v.validated_parameters == expected

View file

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
def test_required_and_default():
arg_spec = {
'param_req': {'required': True, 'default': 'DEFAULT'},
}
v = ArgumentSpecValidator(arg_spec, {})
passed = v.validate()
expected = {
'param_req': 'DEFAULT'
}
expected_errors = [
'internal error: required and default are mutually exclusive for param_req',
]
assert passed is False
assert v.validated_parameters == expected
assert v.error_messages == expected_errors