Use ArgumentSpecValidator in AnsibleModule (#73703)

* Begin using ArgumentSpecValidator in AnsibleModule

* Add check parameters to ArgumentSpecValidator

Add additional parameters for specifying required and mutually exclusive parameters.
Add code to the .validate() method that runs these additional checks.

* Make errors related to unsupported parameters match existing behavior

Update the punctuation in the message slightly to make it more readable.
Add a property to ArgumentSpecValidator to hold valid parameter names.

* Set default values after performining checks

* FIx sanity test failure

* Use correct parameters when checking sub options

* Use a dict when iterating over check functions

Referencing by key names makes things a bit more readable IMO.

* Fix bug in comparison for sub options evaluation

* Add options_context to check functions

This allows the parent parameter to be added the the error message if a validation
error occurs in a sub option.

* Fix bug in apply_defaults behavior of sub spec validation

* Accept options_conext in get_unsupported_parameters()

If options_context is supplied, a tuple of parent key names of unsupported parameter will be
created. This allows the full "path" to the unsupported parameter to be reported.

* Build path to the unsupported parameter for error messages.

* Remove unused import

* Update recursive finder test

* Skip if running in check mode

This was done in the _check_arguments() method. That was moved to a function that has no
way of calling fail_json(), so it must be done outside of validation.

This is a silght change in behavior, but I believe the correct one.

Previously, only unsupported parameters would cause a failure. All other checks would not be executed
if the modlue did not support check mode. This would hide validation failures in check mode.

* The great purge

Remove all methods related to argument spec validation from AnsibleModule

* Keep _name and kind in the caller and out of the validator

This seems a bit awkward since this means the caller could end up with {name} and {kind} in
the error message if they don't run the messages through the .format() method
with name and kind parameters.

* Double moustaches work

I wasn't sure if they get stripped or not. Looks like they do. Neat trick.

* Add changelog

* Update unsupported parameter test

The error message changed to include name and kind.

* Remove unused import

* Add better documentation for ArgumentSpecValidator class

* Fix example

* Few more docs fixes

* Mark required and mutually exclusive attributes as private

* Mark validate functions as private

* Reorganize functions in validation.py

* Remove unused imports in basic.py related to argument spec validation

* Create errors is module_utils

We have errors in lib/ansible/errors/ but those cannot be used by modules.

* Update recursive finder test

* Move errors to file rather than __init__.py

* Change ArgumentSpecValidator.validate() interface

Raise AnsibleValidationErrorMultiple on validation error which contains all AnsibleValidationError
exceptions for validation failures.

Return the validated parameters if validation is successful rather than True/False.

Update docs and tests.

* Get attribute in loop so that the attribute name can also be used as a parameter

* Shorten line

* Update calling code in AnsibleModule for new validator interface

* Update calling code in validate_argument_spec based in new validation interface

* Base custom exception class off of Exception

* Call the __init__ method of the base Exception class to populate args

* Ensure no_log values are always updated

* Make custom exceptions more hierarchical

This redefines AnsibleError from lib/ansible/errors with a different signature since that cannot
be used by modules. This may be a bad idea. Maybe lib/ansible/errors should be moved to
module_utils, or AnsibleError defined in this commit should use the same signature as the original.

* Just go back to basing off Exception

* Return ValidationResult object on successful validation

Create a ValidationResult class.
Return a ValidationResult from ArgumentSpecValidator.validate() when validation is successful.
Update class and method docs.
Update unit tests based on interface change.

* Make it easier to get error objects from AnsibleValidationResultMultiple

This makes the interface cleaner when getting individual error objects contained in a single
AnsibleValidationResultMultiple instance.

* Define custom exception for each type of validation failure

These errors indicate where a validation error occured. Currently they are empty but could
contain specific data for each exception type in the future.

* Update tests based on (yet another) interface change

* Mark several more functions as private

These are all doing rather "internal" things. The ArgumentSpecValidator class is the preferred
public interface.

* Move warnings and deprecations to result object

Rather than calling deprecate() and warn() directly, store them on the result object so the
caller can decide what to do with them.

* Use subclass for module arg spec validation

The subclass uses global warning and deprecations feature

* Fix up docs

* Remove legal_inputs munging from _handle_aliases()

This is done in AnsibleModule by the _set_internal_properties() method. It only makes sense
to do that for an AnsibleModule instance (it should update the parameters before performing
validation) and shouldn't be done by the validator.

Create a private function just for getting legal inputs since that is done in a couple of places.

It may make sense store that on the ValidationResult object.

* Increase test coverage

* Remove unnecessary conditional

ci_complete

* Mark warnings and deprecations as private in the ValidationResult

They can be made public once we come up with a way to make them more generally useful,
probably by creating cusom objects to store the data in more structure way.

* Mark valid_parameter_names as private and populate it during initialization

* Use a global for storing the list of additonal checks to perform

This list is used by the main validate method as well as the sub spec validation.
This commit is contained in:
Sam Doran 2021-03-19 15:09:18 -04:00 committed by GitHub
parent 089d0a0508
commit abacf6a108
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 1138 additions and 1222 deletions

View file

@ -0,0 +1,5 @@
major_changes:
- >-
AnsibleModule - use ``ArgumentSpecValidator`` class for validating argument spec and remove
private methods related to argument spec validation. Any modules using private methods
should now use the ``ArgumentSpecValidator`` class or the appropriate validation function.

View file

@ -90,6 +90,8 @@ from ansible.module_utils.common.text.converters import (
container_to_text as json_dict_bytes_to_unicode,
)
from ansible.module_utils.common.arg_spec import ModuleArgumentSpecValidator
from ansible.module_utils.common.text.formatters import (
lenient_lowercase,
bytes_to_human,
@ -155,25 +157,15 @@ from ansible.module_utils.common.sys_info import (
)
from ansible.module_utils.pycompat24 import get_exception, literal_eval
from ansible.module_utils.common.parameters import (
_remove_values_conditions,
_sanitize_keys_conditions,
sanitize_keys,
env_fallback,
get_unsupported_parameters,
get_type_validator,
handle_aliases,
list_deprecations,
list_no_log_values,
remove_values,
set_defaults,
set_fallbacks,
validate_argument_types,
AnsibleFallbackNotFound,
sanitize_keys,
DEFAULT_TYPE_VALIDATORS,
PASS_VARS,
PASS_BOOLS,
)
from ansible.module_utils.errors import AnsibleFallbackNotFound, AnsibleValidationErrorMultiple, UnsupportedError
from ansible.module_utils.six import (
PY2,
PY3,
@ -187,24 +179,6 @@ from ansible.module_utils.six import (
from ansible.module_utils.six.moves import map, reduce, shlex_quote
from ansible.module_utils.common.validation import (
check_missing_parameters,
check_mutually_exclusive,
check_required_arguments,
check_required_by,
check_required_if,
check_required_one_of,
check_required_together,
count_terms,
check_type_bool,
check_type_bits,
check_type_bytes,
check_type_float,
check_type_int,
check_type_jsonarg,
check_type_list,
check_type_dict,
check_type_path,
check_type_raw,
check_type_str,
safe_eval,
)
from ansible.module_utils.common._utils import get_all_subclasses as _get_all_subclasses
@ -507,49 +481,44 @@ class AnsibleModule(object):
# Save parameter values that should never be logged
self.no_log_values = set()
self._load_params()
self._set_fallbacks()
# append to legal_inputs and then possibly check against them
try:
self.aliases = self._handle_aliases()
except (ValueError, TypeError) as e:
# Use exceptions here because it isn't safe to call fail_json until no_log is processed
print('\n{"failed": true, "msg": "Module alias error: %s"}' % to_native(e))
sys.exit(1)
self._handle_no_log_values()
# check the locale as set by the current environment, and reset to
# a known valid (LANG=C) if it's an invalid/unavailable locale
self._check_locale()
self._load_params()
self._set_internal_properties()
self._check_arguments()
# check exclusive early
if not bypass_checks:
self._check_mutually_exclusive(mutually_exclusive)
self.validator = ModuleArgumentSpecValidator(self.argument_spec,
self.mutually_exclusive,
self.required_together,
self.required_one_of,
self.required_if,
self.required_by,
)
self._set_defaults(pre=True)
self.validation_result = self.validator.validate(self.params)
self.params.update(self.validation_result.validated_parameters)
self.no_log_values.update(self.validation_result._no_log_values)
try:
error = self.validation_result.errors[0]
except IndexError:
error = None
# Fail for validation errors, even in check mode
if error:
msg = self.validation_result.errors.msg
if isinstance(error, UnsupportedError):
msg = "Unsupported parameters for ({name}) {kind}: {msg}".format(name=self._name, kind='module', msg=msg)
self.fail_json(msg=msg)
if self.check_mode and not self.supports_check_mode:
self.exit_json(skipped=True, msg="remote module (%s) does not support check mode" % self._name)
# This is for backwards compatibility only.
self._CHECK_ARGUMENT_TYPES_DISPATCHER = DEFAULT_TYPE_VALIDATORS
if not bypass_checks:
self._check_required_arguments()
self._check_argument_types()
self._check_argument_values()
self._check_required_together(required_together)
self._check_required_one_of(required_one_of)
self._check_required_if(required_if)
self._check_required_by(required_by)
self._set_defaults(pre=False)
# deal with options sub-spec
self._handle_options()
if not self.no_log:
self._log_invocation()
@ -1274,42 +1243,6 @@ class AnsibleModule(object):
self.fail_json(msg="An unknown error was encountered while attempting to validate the locale: %s" %
to_native(e), exception=traceback.format_exc())
def _handle_aliases(self, spec=None, param=None, option_prefix=''):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
# this uses exceptions as it happens before we can safely call fail_json
alias_warnings = []
alias_deprecations = []
alias_results, self._legal_inputs = handle_aliases(spec, param, alias_warnings, alias_deprecations)
for option, alias in alias_warnings:
warn('Both option %s and its alias %s are set.' % (option_prefix + option, option_prefix + alias))
for deprecation in alias_deprecations:
deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'],
version=deprecation.get('version'), date=deprecation.get('date'),
collection_name=deprecation.get('collection_name'))
return alias_results
def _handle_no_log_values(self, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
try:
self.no_log_values.update(list_no_log_values(spec, param))
except TypeError as te:
self.fail_json(msg="Failure when processing no_log parameters. Module invocation will be hidden. "
"%s" % to_native(te), invocation={'module_args': 'HIDDEN DUE TO FAILURE'})
for message in list_deprecations(spec, param):
deprecate(message['msg'], version=message.get('version'), date=message.get('date'),
collection_name=message.get('collection_name'))
def _set_internal_properties(self, argument_spec=None, module_parameters=None):
if argument_spec is None:
argument_spec = self.argument_spec
@ -1333,344 +1266,9 @@ class AnsibleModule(object):
if not hasattr(self, PASS_VARS[k][0]):
setattr(self, PASS_VARS[k][0], PASS_VARS[k][1])
def _check_arguments(self, spec=None, param=None, legal_inputs=None):
unsupported_parameters = set()
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
if legal_inputs is None:
legal_inputs = self._legal_inputs
unsupported_parameters = get_unsupported_parameters(spec, param, legal_inputs)
if unsupported_parameters:
msg = "Unsupported parameters for (%s) module: %s" % (self._name, ', '.join(sorted(list(unsupported_parameters))))
if self._options_context:
msg += " found in %s." % " -> ".join(self._options_context)
supported_parameters = list()
for key in sorted(spec.keys()):
if 'aliases' in spec[key] and spec[key]['aliases']:
supported_parameters.append("%s (%s)" % (key, ', '.join(sorted(spec[key]['aliases']))))
else:
supported_parameters.append(key)
msg += " Supported parameters include: %s" % (', '.join(supported_parameters))
self.fail_json(msg=msg)
if self.check_mode and not self.supports_check_mode:
self.exit_json(skipped=True, msg="remote module (%s) does not support check mode" % self._name)
def _count_terms(self, check, param=None):
if param is None:
param = self.params
return count_terms(check, param)
def _check_mutually_exclusive(self, spec, param=None):
if param is None:
param = self.params
try:
check_mutually_exclusive(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_one_of(self, spec, param=None):
if spec is None:
return
if param is None:
param = self.params
try:
check_required_one_of(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_together(self, spec, param=None):
if spec is None:
return
if param is None:
param = self.params
try:
check_required_together(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_by(self, spec, param=None):
if spec is None:
return
if param is None:
param = self.params
try:
check_required_by(spec, param)
except TypeError as e:
self.fail_json(msg=to_native(e))
def _check_required_arguments(self, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
try:
check_required_arguments(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_if(self, spec, param=None):
''' ensure that parameters which conditionally required are present '''
if spec is None:
return
if param is None:
param = self.params
try:
check_required_if(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_argument_values(self, spec=None, param=None):
''' ensure all arguments have the requested values, and there are no stray arguments '''
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
for (k, v) in spec.items():
choices = v.get('choices', None)
if choices is None:
continue
if isinstance(choices, SEQUENCETYPE) and not isinstance(choices, (binary_type, text_type)):
if k in param:
# Allow one or more when type='list' param with choices
if isinstance(param[k], list):
diff_list = ", ".join([item for item in param[k] if item not in choices])
if diff_list:
choices_str = ", ".join([to_native(c) for c in choices])
msg = "value of %s must be one or more of: %s. Got no match for: %s" % (k, choices_str, diff_list)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
elif param[k] not in choices:
# PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking
# the value. If we can't figure this out, module author is responsible.
lowered_choices = None
if param[k] == 'False':
lowered_choices = lenient_lowercase(choices)
overlap = BOOLEANS_FALSE.intersection(choices)
if len(overlap) == 1:
# Extract from a set
(param[k],) = overlap
if param[k] == 'True':
if lowered_choices is None:
lowered_choices = lenient_lowercase(choices)
overlap = BOOLEANS_TRUE.intersection(choices)
if len(overlap) == 1:
(param[k],) = overlap
if param[k] not in choices:
choices_str = ", ".join([to_native(c) for c in choices])
msg = "value of %s must be one of: %s, got: %s" % (k, choices_str, param[k])
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
else:
msg = "internal error: choices for argument %s are not iterable: %s" % (k, choices)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def safe_eval(self, value, locals=None, include_exceptions=False):
return safe_eval(value, locals, include_exceptions)
def _check_type_str(self, value, param=None, prefix=''):
opts = {
'error': False,
'warn': False,
'ignore': True
}
# Ignore, warn, or error when converting to a string.
allow_conversion = opts.get(self._string_conversion_action, True)
try:
return check_type_str(value, allow_conversion)
except TypeError:
common_msg = 'quote the entire value to ensure it does not change.'
from_msg = '{0!r}'.format(value)
to_msg = '{0!r}'.format(to_text(value))
if param is not None:
if prefix:
param = '{0}{1}'.format(prefix, param)
from_msg = '{0}: {1!r}'.format(param, value)
to_msg = '{0}: {1!r}'.format(param, to_text(value))
if self._string_conversion_action == 'error':
msg = common_msg.capitalize()
raise TypeError(to_native(msg))
elif self._string_conversion_action == 'warn':
msg = ('The value "{0}" (type {1.__class__.__name__}) was converted to "{2}" (type string). '
'If this does not look like what you expect, {3}').format(from_msg, value, to_msg, common_msg)
self.warn(to_native(msg))
return to_native(value, errors='surrogate_or_strict')
def _check_type_list(self, value):
return check_type_list(value)
def _check_type_dict(self, value):
return check_type_dict(value)
def _check_type_bool(self, value):
return check_type_bool(value)
def _check_type_int(self, value):
return check_type_int(value)
def _check_type_float(self, value):
return check_type_float(value)
def _check_type_path(self, value):
return check_type_path(value)
def _check_type_jsonarg(self, value):
return check_type_jsonarg(value)
def _check_type_raw(self, value):
return check_type_raw(value)
def _check_type_bytes(self, value):
return check_type_bytes(value)
def _check_type_bits(self, value):
return check_type_bits(value)
def _handle_options(self, argument_spec=None, params=None, prefix=''):
''' deal with options to create sub spec '''
if argument_spec is None:
argument_spec = self.argument_spec
if params is None:
params = self.params
for (k, v) in argument_spec.items():
wanted = v.get('type', None)
if wanted == 'dict' or (wanted == 'list' and v.get('elements', '') == 'dict'):
spec = v.get('options', None)
if v.get('apply_defaults', False):
if spec is not None:
if params.get(k) is None:
params[k] = {}
else:
continue
elif spec is None or k not in params or params[k] is None:
continue
self._options_context.append(k)
if isinstance(params[k], dict):
elements = [params[k]]
else:
elements = params[k]
for idx, param in enumerate(elements):
if not isinstance(param, dict):
self.fail_json(msg="value of %s must be of type dict or list of dict" % k)
new_prefix = prefix + k
if wanted == 'list':
new_prefix += '[%d]' % idx
new_prefix += '.'
self._set_fallbacks(spec, param)
options_aliases = self._handle_aliases(spec, param, option_prefix=new_prefix)
options_legal_inputs = list(spec.keys()) + list(options_aliases.keys())
self._check_arguments(spec, param, options_legal_inputs)
# check exclusive early
if not self.bypass_checks:
self._check_mutually_exclusive(v.get('mutually_exclusive', None), param)
self._set_defaults(pre=True, spec=spec, param=param)
if not self.bypass_checks:
self._check_required_arguments(spec, param)
self._check_argument_types(spec, param, new_prefix)
self._check_argument_values(spec, param)
self._check_required_together(v.get('required_together', None), param)
self._check_required_one_of(v.get('required_one_of', None), param)
self._check_required_if(v.get('required_if', None), param)
self._check_required_by(v.get('required_by', None), param)
self._set_defaults(pre=False, spec=spec, param=param)
# handle multi level options (sub argspec)
self._handle_options(spec, param, new_prefix)
self._options_context.pop()
def _get_wanted_type(self, wanted, k):
# Use the private method for 'str' type to handle the string conversion warning.
if wanted == 'str':
type_checker, wanted = self._check_type_str, 'str'
else:
type_checker, wanted = get_type_validator(wanted)
if type_checker is None:
self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k))
return type_checker, wanted
def _check_argument_types(self, spec=None, param=None, prefix=''):
''' ensure all arguments have the requested type '''
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
errors = []
validate_argument_types(spec, param, errors=errors)
if errors:
self.fail_json(msg=errors[0])
def _set_defaults(self, pre=True, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
# The interface for set_defaults is different than _set_defaults()
# The third parameter controls whether or not defaults are actually set.
set_default = not pre
self.no_log_values.update(set_defaults(spec, param, set_default))
def _set_fallbacks(self, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
self.no_log_values.update(set_fallbacks(spec, param))
def _load_params(self):
''' read the input and set the params attribute.

View file

@ -5,71 +5,146 @@
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from copy import deepcopy
from ansible.module_utils.common._collections_compat import (
Sequence,
)
from ansible.module_utils.common.parameters import (
get_unsupported_parameters,
handle_aliases,
list_no_log_values,
remove_values,
set_defaults,
_ADDITIONAL_CHECKS,
_get_legal_inputs,
_get_unsupported_parameters,
_handle_aliases,
_list_no_log_values,
_set_defaults,
_validate_argument_types,
_validate_argument_values,
_validate_sub_spec,
set_fallbacks,
validate_argument_types,
validate_argument_values,
validate_sub_spec,
)
from ansible.module_utils.common.text.converters import to_native
from ansible.module_utils.common.warnings import deprecate, warn
from ansible.module_utils.common.validation import (
check_mutually_exclusive,
check_required_arguments,
check_required_by,
check_required_if,
check_required_one_of,
check_required_together,
)
from ansible.module_utils.six import string_types
from ansible.module_utils.errors import (
AliasError,
AnsibleValidationErrorMultiple,
MutuallyExclusiveError,
NoLogError,
RequiredByError,
RequiredDefaultError,
RequiredError,
RequiredIfError,
RequiredOneOfError,
RequiredTogetherError,
UnsupportedError,
)
class ArgumentSpecValidator():
"""Argument spec validation class"""
class ValidationResult:
"""Result of argument spec validation.
def __init__(self, argument_spec, parameters):
self._error_messages = []
:param parameters: Terms to be validated and coerced to the correct type.
:type parameters: dict
"""
def __init__(self, parameters):
self._no_log_values = set()
self.argument_spec = argument_spec
# Make a copy of the original parameters to avoid changing them
self._validated_parameters = deepcopy(parameters)
self._unsupported_parameters = set()
@property
def error_messages(self):
return self._error_messages
self._validated_parameters = deepcopy(parameters)
self._deprecations = []
self._warnings = []
self.errors = AnsibleValidationErrorMultiple()
@property
def validated_parameters(self):
return self._validated_parameters
def _add_error(self, error):
if isinstance(error, string_types):
self._error_messages.append(error)
elif isinstance(error, Sequence):
self._error_messages.extend(error)
else:
raise ValueError('Error messages must be a string or sequence not a %s' % type(error))
@property
def unsupported_parameters(self):
return self._unsupported_parameters
def _sanitize_error_messages(self):
self._error_messages = remove_values(self._error_messages, self._no_log_values)
@property
def error_messages(self):
return self.errors.messages
def validate(self, *args, **kwargs):
"""Validate module parameters against argument spec.
class ArgumentSpecValidator:
"""Argument spec validation class
Creates a validator based on the ``argument_spec`` that can be used to
validate a number of parameters using the ``validate()`` method.
:param argument_spec: Specification of valid parameters and their type. May
include nested argument specs.
:type argument_spec: dict
:param mutually_exclusive: List or list of lists of terms that should not
be provided together.
:type mutually_exclusive: list, optional
:param required_together: List of lists of terms that are required together.
:type required_together: list, optional
:param required_one_of: List of lists of terms, one of which in each list
is required.
:type required_one_of: list, optional
:param required_if: List of lists of ``[parameter, value, [parameters]]`` where
one of [parameters] is required if ``parameter`` == ``value``.
:type required_if: list, optional
:param required_by: Dictionary of parameter names that contain a list of
parameters required by each key in the dictionary.
:type required_by: dict, optional
"""
def __init__(self, argument_spec,
mutually_exclusive=None,
required_together=None,
required_one_of=None,
required_if=None,
required_by=None,
):
self._mutually_exclusive = mutually_exclusive
self._required_together = required_together
self._required_one_of = required_one_of
self._required_if = required_if
self._required_by = required_by
self._valid_parameter_names = set()
self.argument_spec = argument_spec
for key in sorted(self.argument_spec.keys()):
aliases = self.argument_spec[key].get('aliases')
if aliases:
self._valid_parameter_names.update(["{key} ({aliases})".format(key=key, aliases=", ".join(sorted(aliases)))])
else:
self._valid_parameter_names.update([key])
def validate(self, parameters, *args, **kwargs):
"""Validate module parameters against argument spec. Returns a
ValidationResult object.
Error messages in the ValidationResult may contain no_log values and should be
sanitized before logging or displaying.
:Example:
validator = ArgumentSpecValidator(argument_spec, parameters)
passeded = validator.validate()
validator = ArgumentSpecValidator(argument_spec)
result = validator.validate(parameters)
if result.error_messages:
sys.exit("Validation failed: {0}".format(", ".join(result.error_messages))
valid_params = result.validated_parameters
:param argument_spec: Specification of parameters, type, and valid values
:type argument_spec: dict
@ -77,58 +152,104 @@ class ArgumentSpecValidator():
:param parameters: Parameters provided to the role
:type parameters: dict
:returns: True if no errors were encountered, False if any errors were encountered.
:rtype: bool
:return: Object containing validated parameters.
:rtype: ValidationResult
"""
self._no_log_values.update(set_fallbacks(self.argument_spec, self._validated_parameters))
result = ValidationResult(parameters)
result._no_log_values.update(set_fallbacks(self.argument_spec, result._validated_parameters))
alias_warnings = []
alias_deprecations = []
try:
alias_results, legal_inputs = handle_aliases(self.argument_spec, self._validated_parameters, alias_warnings, alias_deprecations)
aliases = _handle_aliases(self.argument_spec, result._validated_parameters, alias_warnings, alias_deprecations)
except (TypeError, ValueError) as e:
alias_results = {}
legal_inputs = None
self._add_error(to_native(e))
aliases = {}
result.errors.append(AliasError(to_native(e)))
legal_inputs = _get_legal_inputs(self.argument_spec, result._validated_parameters, aliases)
for option, alias in alias_warnings:
warn('Both option %s and its alias %s are set.' % (option, alias))
result._warnings.append({'option': option, 'alias': alias})
for deprecation in alias_deprecations:
deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'],
version=deprecation.get('version'), date=deprecation.get('date'),
collection_name=deprecation.get('collection_name'))
self._no_log_values.update(list_no_log_values(self.argument_spec, self._validated_parameters))
if legal_inputs is None:
legal_inputs = list(alias_results.keys()) + list(self.argument_spec.keys())
self._unsupported_parameters.update(get_unsupported_parameters(self.argument_spec, self._validated_parameters, legal_inputs))
self._no_log_values.update(set_defaults(self.argument_spec, self._validated_parameters, False))
result._deprecations.append({
'name': deprecation['name'],
'version': deprecation.get('version'),
'date': deprecation.get('date'),
'collection_name': deprecation.get('collection_name'),
})
try:
check_required_arguments(self.argument_spec, self._validated_parameters)
result._no_log_values.update(_list_no_log_values(self.argument_spec, result._validated_parameters))
except TypeError as te:
result.errors.append(NoLogError(to_native(te)))
try:
result._unsupported_parameters.update(_get_unsupported_parameters(self.argument_spec, result._validated_parameters, legal_inputs))
except TypeError as te:
result.errors.append(RequiredDefaultError(to_native(te)))
except ValueError as ve:
result.errors.append(AliasError(to_native(ve)))
try:
check_mutually_exclusive(self._mutually_exclusive, result._validated_parameters)
except TypeError as te:
result.errors.append(MutuallyExclusiveError(to_native(te)))
result._no_log_values.update(_set_defaults(self.argument_spec, result._validated_parameters, False))
try:
check_required_arguments(self.argument_spec, result._validated_parameters)
except TypeError as e:
self._add_error(to_native(e))
result.errors.append(RequiredError(to_native(e)))
validate_argument_types(self.argument_spec, self._validated_parameters, errors=self._error_messages)
validate_argument_values(self.argument_spec, self._validated_parameters, errors=self._error_messages)
_validate_argument_types(self.argument_spec, result._validated_parameters, errors=result.errors)
_validate_argument_values(self.argument_spec, result._validated_parameters, errors=result.errors)
self._no_log_values.update(set_defaults(self.argument_spec, self._validated_parameters))
for check in _ADDITIONAL_CHECKS:
try:
check['func'](getattr(self, "_{attr}".format(attr=check['attr'])), result._validated_parameters)
except TypeError as te:
result.errors.append(check['err'](to_native(te)))
validate_sub_spec(self.argument_spec, self._validated_parameters,
errors=self._error_messages,
no_log_values=self._no_log_values,
unsupported_parameters=self._unsupported_parameters)
result._no_log_values.update(_set_defaults(self.argument_spec, result._validated_parameters))
if self._unsupported_parameters:
self._add_error('Unsupported parameters: %s' % ', '.join(sorted(list(self._unsupported_parameters))))
_validate_sub_spec(self.argument_spec, result._validated_parameters,
errors=result.errors,
no_log_values=result._no_log_values,
unsupported_parameters=result._unsupported_parameters)
self._sanitize_error_messages()
if result._unsupported_parameters:
flattened_names = []
for item in result._unsupported_parameters:
if isinstance(item, tuple):
flattened_names.append(".".join(item))
else:
flattened_names.append(item)
if self.error_messages:
return False
else:
return True
unsupported_string = ", ".join(sorted(list(flattened_names)))
supported_string = ", ".join(self._valid_parameter_names)
result.errors.append(
UnsupportedError("{0}. Supported parameters include: {1}.".format(unsupported_string, supported_string)))
return result
class ModuleArgumentSpecValidator(ArgumentSpecValidator):
def __init__(self, *args, **kwargs):
super(ModuleArgumentSpecValidator, self).__init__(*args, **kwargs)
def validate(self, parameters):
result = super(ModuleArgumentSpecValidator, self).validate(parameters)
for d in result._deprecations:
deprecate("Alias '{name}' is deprecated. See the module docs for more information".format(name=d['name']),
version=d.get('version'), date=d.get('date'),
collection_name=d.get('collection_name'))
for w in result._warnings:
warn('Both option {option} and its alias {alias} are set.'.format(option=w['option'], alias=w['alias']))
return result

File diff suppressed because it is too large Load diff

View file

@ -39,7 +39,35 @@ def count_terms(terms, parameters):
return len(set(terms).intersection(parameters))
def check_mutually_exclusive(terms, parameters):
def safe_eval(value, locals=None, include_exceptions=False):
# do not allow method calls to modules
if not isinstance(value, string_types):
# already templated to a datavaluestructure, perhaps?
if include_exceptions:
return (value, None)
return value
if re.search(r'\w\.\w+\(', value):
if include_exceptions:
return (value, None)
return value
# do not allow imports
if re.search(r'import \w+', value):
if include_exceptions:
return (value, None)
return value
try:
result = literal_eval(value)
if include_exceptions:
return (result, None)
else:
return result
except Exception as e:
if include_exceptions:
return (value, e)
return value
def check_mutually_exclusive(terms, parameters, options_context=None):
"""Check mutually exclusive terms against argument parameters
Accepts a single list or list of lists that are groups of terms that should be
@ -63,12 +91,14 @@ def check_mutually_exclusive(terms, parameters):
if results:
full_list = ['|'.join(check) for check in results]
msg = "parameters are mutually exclusive: %s" % ', '.join(full_list)
if options_context:
msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
raise TypeError(to_native(msg))
return results
def check_required_one_of(terms, parameters):
def check_required_one_of(terms, parameters, options_context=None):
"""Check each list of terms to ensure at least one exists in the given module
parameters
@ -93,12 +123,14 @@ def check_required_one_of(terms, parameters):
if results:
for term in results:
msg = "one of the following is required: %s" % ', '.join(term)
if options_context:
msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
raise TypeError(to_native(msg))
return results
def check_required_together(terms, parameters):
def check_required_together(terms, parameters, options_context=None):
"""Check each list of terms to ensure every parameter in each list exists
in the given parameters
@ -125,12 +157,14 @@ def check_required_together(terms, parameters):
if results:
for term in results:
msg = "parameters are required together: %s" % ', '.join(term)
if options_context:
msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
raise TypeError(to_native(msg))
return results
def check_required_by(requirements, parameters):
def check_required_by(requirements, parameters, options_context=None):
"""For each key in requirements, check the corresponding list to see if they
exist in parameters
@ -161,12 +195,14 @@ def check_required_by(requirements, parameters):
for key, missing in result.items():
if len(missing) > 0:
msg = "missing parameter(s) required by '%s': %s" % (key, ', '.join(missing))
if options_context:
msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
raise TypeError(to_native(msg))
return result
def check_required_arguments(argument_spec, parameters):
def check_required_arguments(argument_spec, parameters, options_context=None):
"""Check all paramaters in argument_spec and return a list of parameters
that are required but not present in parameters
@ -190,12 +226,14 @@ def check_required_arguments(argument_spec, parameters):
if missing:
msg = "missing required arguments: %s" % ", ".join(sorted(missing))
if options_context:
msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
raise TypeError(to_native(msg))
return missing
def check_required_if(requirements, parameters):
def check_required_if(requirements, parameters, options_context=None):
"""Check parameters that are conditionally required
Raises TypeError if the check fails
@ -272,6 +310,8 @@ def check_required_if(requirements, parameters):
for missing in results:
msg = "%s is %s but %s of the following are missing: %s" % (
missing['parameter'], missing['value'], missing['requires'], ', '.join(missing['missing']))
if options_context:
msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
raise TypeError(to_native(msg))
return results
@ -304,34 +344,6 @@ def check_missing_parameters(parameters, required_parameters=None):
return missing_params
def safe_eval(value, locals=None, include_exceptions=False):
# do not allow method calls to modules
if not isinstance(value, string_types):
# already templated to a datavaluestructure, perhaps?
if include_exceptions:
return (value, None)
return value
if re.search(r'\w\.\w+\(', value):
if include_exceptions:
return (value, None)
return value
# do not allow imports
if re.search(r'import \w+', value):
if include_exceptions:
return (value, None)
return value
try:
result = literal_eval(value)
if include_exceptions:
return (result, None)
else:
return result
except Exception as e:
if include_exceptions:
return (value, e)
return value
# FIXME: The param and prefix parameters here are coming from AnsibleModule._check_type_string()
# which is using those for the warning messaged based on string conversion warning settings.
# Not sure how to deal with that here since we don't have config state to query.

View file

@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Ansible Project
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
class AnsibleFallbackNotFound(Exception):
"""Fallback validator was not found"""
class AnsibleValidationError(Exception):
"""Single argument spec validation error"""
def __init__(self, message):
super(AnsibleValidationError, self).__init__(message)
self.error_message = message
@property
def msg(self):
return self.args[0]
class AnsibleValidationErrorMultiple(AnsibleValidationError):
"""Multiple argument spec validation errors"""
def __init__(self, errors=None):
self.errors = errors[:] if errors else []
def __getitem__(self, key):
return self.errors[key]
def __setitem__(self, key, value):
self.errors[key] = value
def __delitem__(self, key):
del self.errors[key]
@property
def msg(self):
return self.errors[0].args[0]
@property
def messages(self):
return [err.msg for err in self.errors]
def append(self, error):
self.errors.append(error)
def extend(self, errors):
self.errors.extend(errors)
class AliasError(AnsibleValidationError):
"""Error handling aliases"""
class ArgumentTypeError(AnsibleValidationError):
"""Error with parameter type"""
class ArgumentValueError(AnsibleValidationError):
"""Error with parameter value"""
class ElementError(AnsibleValidationError):
"""Error when validating elements"""
class MutuallyExclusiveError(AnsibleValidationError):
"""Mutually exclusive parameters were supplied"""
class NoLogError(AnsibleValidationError):
"""Error converting no_log values"""
class RequiredByError(AnsibleValidationError):
"""Error with parameters that are required by other parameters"""
class RequiredDefaultError(AnsibleValidationError):
"""A required parameter was assigned a default value"""
class RequiredError(AnsibleValidationError):
"""Missing a required parameter"""
class RequiredIfError(AnsibleValidationError):
"""Error with conditionally required parameters"""
class RequiredOneOfError(AnsibleValidationError):
"""Error with parameters where at least one is required"""
class RequiredTogetherError(AnsibleValidationError):
"""Error with parameters that are required together"""
class SubParameterTypeError(AnsibleValidationError):
"""Incorrect type for subparameter"""
class UnsupportedError(AnsibleValidationError):
"""Unsupported parameters were supplied"""

View file

@ -8,6 +8,7 @@ from ansible.errors import AnsibleError
from ansible.plugins.action import ActionBase
from ansible.module_utils.six import iteritems, string_types
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
from ansible.module_utils.errors import AnsibleValidationErrorMultiple
class ActionModule(ActionBase):
@ -82,13 +83,14 @@ class ActionModule(ActionBase):
args_from_vars = self.get_args_from_task_vars(argument_spec_data, task_vars)
provided_arguments.update(args_from_vars)
validator = ArgumentSpecValidator(argument_spec_data, provided_arguments)
validator = ArgumentSpecValidator(argument_spec_data)
validation_result = validator.validate(provided_arguments)
if not validator.validate():
if validation_result.error_messages:
result['failed'] = True
result['msg'] = 'Validation of arguments failed:\n%s' % '\n'.join(validator.error_messages)
result['msg'] = 'Validation of arguments failed:\n%s' % '\n'.join(validation_result.error_messages)
result['argument_spec_data'] = argument_spec_data
result['argument_errors'] = validator.error_messages
result['argument_errors'] = validation_result.error_messages
return result
result['changed'] = False

View file

@ -29,7 +29,6 @@ from io import BytesIO
import ansible.errors
from ansible.executor.module_common import recursive_finder
from ansible.module_utils.six import PY2
# These are the modules that are brought in by module_utils/basic.py This may need to be updated
@ -58,12 +57,14 @@ MODULE_UTILS_BASIC_FILES = frozenset(('ansible/__init__.py',
'ansible/module_utils/common/text/formatters.py',
'ansible/module_utils/common/validation.py',
'ansible/module_utils/common/_utils.py',
'ansible/module_utils/common/arg_spec.py',
'ansible/module_utils/compat/__init__.py',
'ansible/module_utils/compat/_selectors2.py',
'ansible/module_utils/compat/selectors.py',
'ansible/module_utils/compat/selinux.py',
'ansible/module_utils/distro/__init__.py',
'ansible/module_utils/distro/_distro.py',
'ansible/module_utils/errors.py',
'ansible/module_utils/parsing/__init__.py',
'ansible/module_utils/parsing/convert_bool.py',
'ansible/module_utils/pycompat24.py',

View file

@ -84,9 +84,9 @@ INVALID_SPECS = (
({'arg': {'type': 'list', 'elements': MOCK_VALIDATOR_FAIL}}, {'arg': [1, "bad"]}, "bad conversion"),
# unknown parameter
({'arg': {'type': 'int'}}, {'other': 'bad', '_ansible_module_name': 'ansible_unittest'},
'Unsupported parameters for (ansible_unittest) module: other Supported parameters include: arg'),
'Unsupported parameters for (ansible_unittest) module: other. Supported parameters include: arg.'),
({'arg': {'type': 'int', 'aliases': ['argument']}}, {'other': 'bad', '_ansible_module_name': 'ansible_unittest'},
'Unsupported parameters for (ansible_unittest) module: other Supported parameters include: arg (argument)'),
'Unsupported parameters for (ansible_unittest) module: other. Supported parameters include: arg (argument).'),
# parameter is required
({'arg': {'required': True}}, {}, 'missing required arguments: arg'),
)
@ -496,7 +496,7 @@ class TestComplexOptions:
# Missing required option
({'foobar': [{}]}, 'missing required arguments: foo found in foobar'),
# Invalid option
({'foobar': [{"foo": "hello", "bam": "good", "invalid": "bad"}]}, 'module: invalid found in foobar. Supported parameters include'),
({'foobar': [{"foo": "hello", "bam": "good", "invalid": "bad"}]}, 'module: foobar.invalid. Supported parameters include'),
# Mutually exclusive options found
({'foobar': [{"foo": "test", "bam": "bad", "bam1": "bad", "baz": "req_to"}]},
'parameters are mutually exclusive: bam|bam1 found in foobar'),
@ -520,7 +520,7 @@ class TestComplexOptions:
({'foobar': {}}, 'missing required arguments: foo found in foobar'),
# Invalid option
({'foobar': {"foo": "hello", "bam": "good", "invalid": "bad"}},
'module: invalid found in foobar. Supported parameters include'),
'module: foobar.invalid. Supported parameters include'),
# Mutually exclusive options found
({'foobar': {"foo": "test", "bam": "bad", "bam1": "bad", "baz": "req_to"}},
'parameters are mutually exclusive: bam|bam1 found in foobar'),

View file

@ -1,28 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
def test_add_sequence():
v = ArgumentSpecValidator({}, {})
errors = [
'one error',
'another error',
]
v._add_error(errors)
assert v.error_messages == errors
def test_invalid_error_message():
v = ArgumentSpecValidator({}, {})
with pytest.raises(ValueError, match="Error messages must be a string or sequence not a"):
v._add_error(None)

View file

@ -7,10 +7,11 @@ __metaclass__ = type
import pytest
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
from ansible.module_utils.errors import AnsibleValidationError, AnsibleValidationErrorMultiple
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult
from ansible.module_utils.common.warnings import get_deprecation_messages, get_warning_messages
# id, argument spec, parameters, expected parameters, expected pass/fail, error, deprecation, warning
# id, argument spec, parameters, expected parameters, deprecation, warning
ALIAS_TEST_CASES = [
(
"alias",
@ -20,29 +21,6 @@ ALIAS_TEST_CASES = [
'dir': '/tmp',
'path': '/tmp',
},
True,
"",
"",
"",
),
(
"alias-invalid",
{'path': {'aliases': 'bad'}},
{},
{'path': None},
False,
"internal error: aliases must be a list or tuple",
"",
"",
),
(
# This isn't related to aliases, but it exists in the alias handling code
"default-and-required",
{'name': {'default': 'ray', 'required': True}},
{},
{'name': 'ray'},
False,
"internal error: required and default are mutually exclusive for name",
"",
"",
),
@ -58,10 +36,8 @@ ALIAS_TEST_CASES = [
'directory': '/tmp',
'path': '/tmp',
},
True,
"",
"",
"Both option path and its alias directory are set",
{'alias': 'directory', 'option': 'path'},
),
(
"deprecated-alias",
@ -81,39 +57,66 @@ ALIAS_TEST_CASES = [
'path': '/tmp',
'not_yo_path': '/tmp',
},
True,
"",
"Alias 'not_yo_path' is deprecated.",
{'version': '1.7', 'date': None, 'collection_name': None, 'name': 'not_yo_path'},
"",
)
]
# id, argument spec, parameters, expected parameters, error
ALIAS_TEST_CASES_INVALID = [
(
"alias-invalid",
{'path': {'aliases': 'bad'}},
{},
{'path': None},
"internal error: aliases must be a list or tuple",
),
(
# This isn't related to aliases, but it exists in the alias handling code
"default-and-required",
{'name': {'default': 'ray', 'required': True}},
{},
{'name': 'ray'},
"internal error: required and default are mutually exclusive for name",
),
]
@pytest.mark.parametrize(
('arg_spec', 'parameters', 'expected', 'passfail', 'error', 'deprecation', 'warning'),
((i[1], i[2], i[3], i[4], i[5], i[6], i[7]) for i in ALIAS_TEST_CASES),
('arg_spec', 'parameters', 'expected', 'deprecation', 'warning'),
((i[1:]) for i in ALIAS_TEST_CASES),
ids=[i[0] for i in ALIAS_TEST_CASES]
)
def test_aliases(arg_spec, parameters, expected, passfail, error, deprecation, warning):
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
def test_aliases(arg_spec, parameters, expected, deprecation, warning):
v = ArgumentSpecValidator(arg_spec)
result = v.validate(parameters)
assert passed is passfail
assert v.validated_parameters == expected
assert isinstance(result, ValidationResult)
assert result.validated_parameters == expected
assert result.error_messages == []
if not error:
assert v.error_messages == []
if deprecation:
assert deprecation == result._deprecations[0]
else:
assert error in v.error_messages[0]
assert result._deprecations == []
deprecations = get_deprecation_messages()
if not deprecations:
assert deprecations == ()
if warning:
assert warning == result._warnings[0]
else:
assert deprecation in get_deprecation_messages()[0]['msg']
assert result._warnings == []
warnings = get_warning_messages()
if not warning:
assert warnings == ()
else:
assert warning in warnings[0]
@pytest.mark.parametrize(
('arg_spec', 'parameters', 'expected', 'error'),
((i[1:]) for i in ALIAS_TEST_CASES_INVALID),
ids=[i[0] for i in ALIAS_TEST_CASES_INVALID]
)
def test_aliases_invalid(arg_spec, parameters, expected, error):
v = ArgumentSpecValidator(arg_spec)
result = v.validate(parameters)
assert isinstance(result, ValidationResult)
assert error in result.error_messages
assert isinstance(result.errors.errors[0], AnsibleValidationError)
assert isinstance(result.errors, AnsibleValidationErrorMultiple)

View file

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import ansible.module_utils.common.warnings as warnings
from ansible.module_utils.common.arg_spec import ModuleArgumentSpecValidator, ValidationResult
def test_module_validate():
arg_spec = {'name': {}}
parameters = {'name': 'larry'}
expected = {'name': 'larry'}
v = ModuleArgumentSpecValidator(arg_spec)
result = v.validate(parameters)
assert isinstance(result, ValidationResult)
assert result.error_messages == []
assert result._deprecations == []
assert result._warnings == []
assert result.validated_parameters == expected
def test_module_alias_deprecations_warnings():
arg_spec = {
'path': {
'aliases': ['source', 'src', 'flamethrower'],
'deprecated_aliases': [{'name': 'flamethrower', 'date': '2020-03-04'}],
},
}
parameters = {'flamethrower': '/tmp', 'source': '/tmp'}
expected = {
'path': '/tmp',
'flamethrower': '/tmp',
'source': '/tmp',
}
v = ModuleArgumentSpecValidator(arg_spec)
result = v.validate(parameters)
assert result.validated_parameters == expected
assert result._deprecations == [
{
'collection_name': None,
'date': '2020-03-04',
'name': 'flamethrower',
'version': None,
}
]
assert "Alias 'flamethrower' is deprecated" in warnings._global_deprecations[0]['msg']
assert result._warnings == [{'alias': 'flamethrower', 'option': 'path'}]
assert "Both option path and its alias flamethrower are set" in warnings._global_warnings[0]

View file

@ -5,7 +5,7 @@
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult
def test_sub_spec():
@ -39,12 +39,12 @@ def test_sub_spec():
}
}
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
v = ArgumentSpecValidator(arg_spec)
result = v.validate(parameters)
assert passed is True
assert v.error_messages == []
assert v.validated_parameters == expected
assert isinstance(result, ValidationResult)
assert result.validated_parameters == expected
assert result.error_messages == []
def test_nested_sub_spec():
@ -98,9 +98,9 @@ def test_nested_sub_spec():
}
}
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
v = ArgumentSpecValidator(arg_spec)
result = v.validate(parameters)
assert passed is True
assert v.error_messages == []
assert v.validated_parameters == expected
assert isinstance(result, ValidationResult)
assert result.validated_parameters == expected
assert result.error_messages == []

View file

@ -7,17 +7,19 @@ __metaclass__ = type
import pytest
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult
from ansible.module_utils.errors import AnsibleValidationErrorMultiple
from ansible.module_utils.six import PY2
# Each item is id, argument_spec, parameters, expected, error test string
# Each item is id, argument_spec, parameters, expected, unsupported parameters, error test string
INVALID_SPECS = [
(
'invalid-list',
{'packages': {'type': 'list'}},
{'packages': {'key': 'value'}},
{'packages': {'key': 'value'}},
set(),
"unable to convert to list: <class 'dict'> cannot be converted to a list",
),
(
@ -25,6 +27,7 @@ INVALID_SPECS = [
{'users': {'type': 'dict'}},
{'users': ['one', 'two']},
{'users': ['one', 'two']},
set(),
"unable to convert to dict: <class 'list'> cannot be converted to a dict",
),
(
@ -32,6 +35,7 @@ INVALID_SPECS = [
{'bool': {'type': 'bool'}},
{'bool': {'k': 'v'}},
{'bool': {'k': 'v'}},
set(),
"unable to convert to bool: <class 'dict'> cannot be converted to a bool",
),
(
@ -39,6 +43,7 @@ INVALID_SPECS = [
{'float': {'type': 'float'}},
{'float': 'hello'},
{'float': 'hello'},
set(),
"unable to convert to float: <class 'str'> cannot be converted to a float",
),
(
@ -46,6 +51,7 @@ INVALID_SPECS = [
{'bytes': {'type': 'bytes'}},
{'bytes': 'one'},
{'bytes': 'one'},
set(),
"unable to convert to bytes: <class 'str'> cannot be converted to a Byte value",
),
(
@ -53,6 +59,7 @@ INVALID_SPECS = [
{'bits': {'type': 'bits'}},
{'bits': 'one'},
{'bits': 'one'},
set(),
"unable to convert to bits: <class 'str'> cannot be converted to a Bit value",
),
(
@ -60,6 +67,7 @@ INVALID_SPECS = [
{'some_json': {'type': 'jsonarg'}},
{'some_json': set()},
{'some_json': set()},
set(),
"unable to convert to jsonarg: <class 'set'> cannot be converted to a json string",
),
(
@ -74,13 +82,15 @@ INVALID_SPECS = [
'badparam': '',
'another': '',
},
"Unsupported parameters: another, badparam",
set(('another', 'badparam')),
"another, badparam. Supported parameters include: name.",
),
(
'invalid-elements',
{'numbers': {'type': 'list', 'elements': 'int'}},
{'numbers': [55, 33, 34, {'key': 'value'}]},
{'numbers': [55, 33, 34]},
set(),
"Elements value for option 'numbers' is of type <class 'dict'> and we were unable to convert to int: <class 'dict'> cannot be converted to an int"
),
(
@ -88,23 +98,29 @@ INVALID_SPECS = [
{'req': {'required': True}},
{},
{'req': None},
set(),
"missing required arguments: req"
)
]
@pytest.mark.parametrize(
('arg_spec', 'parameters', 'expected', 'error'),
((i[1], i[2], i[3], i[4]) for i in INVALID_SPECS),
('arg_spec', 'parameters', 'expected', 'unsupported', 'error'),
(i[1:] for i in INVALID_SPECS),
ids=[i[0] for i in INVALID_SPECS]
)
def test_invalid_spec(arg_spec, parameters, expected, error):
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
def test_invalid_spec(arg_spec, parameters, expected, unsupported, error):
v = ArgumentSpecValidator(arg_spec)
result = v.validate(parameters)
with pytest.raises(AnsibleValidationErrorMultiple) as exc_info:
raise result.errors
if PY2:
error = error.replace('class', 'type')
assert error in v.error_messages[0]
assert v.validated_parameters == expected
assert passed is False
assert isinstance(result, ValidationResult)
assert error in exc_info.value.msg
assert error in result.error_messages[0]
assert result.unsupported_parameters == unsupported
assert result.validated_parameters == expected

View file

@ -7,45 +7,53 @@ __metaclass__ = type
import pytest
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
import ansible.module_utils.common.warnings as warnings
# Each item is id, argument_spec, parameters, expected
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult
# Each item is id, argument_spec, parameters, expected, valid parameter names
VALID_SPECS = [
(
'str-no-type-specified',
{'name': {}},
{'name': 'rey'},
{'name': 'rey'},
set(('name',)),
),
(
'str',
{'name': {'type': 'str'}},
{'name': 'rey'},
{'name': 'rey'},
set(('name',)),
),
(
'str-convert',
{'name': {'type': 'str'}},
{'name': 5},
{'name': '5'},
set(('name',)),
),
(
'list',
{'packages': {'type': 'list'}},
{'packages': ['vim', 'python']},
{'packages': ['vim', 'python']},
set(('packages',)),
),
(
'list-comma-string',
{'packages': {'type': 'list'}},
{'packages': 'vim,python'},
{'packages': ['vim', 'python']},
set(('packages',)),
),
(
'list-comma-string-space',
{'packages': {'type': 'list'}},
{'packages': 'vim, python'},
{'packages': ['vim', ' python']},
set(('packages',)),
),
(
'dict',
@ -64,6 +72,7 @@ VALID_SPECS = [
'last': 'skywalker',
}
},
set(('user',)),
),
(
'dict-k=v',
@ -76,6 +85,7 @@ VALID_SPECS = [
'last': 'skywalker',
}
},
set(('user',)),
),
(
'dict-k=v-spaces',
@ -88,6 +98,7 @@ VALID_SPECS = [
'last': 'skywalker',
}
},
set(('user',)),
),
(
'bool',
@ -103,6 +114,7 @@ VALID_SPECS = [
'enabled': True,
'disabled': False,
},
set(('enabled', 'disabled')),
),
(
'bool-ints',
@ -118,6 +130,7 @@ VALID_SPECS = [
'enabled': True,
'disabled': False,
},
set(('enabled', 'disabled')),
),
(
'bool-true-false',
@ -133,6 +146,7 @@ VALID_SPECS = [
'enabled': True,
'disabled': False,
},
set(('enabled', 'disabled')),
),
(
'bool-yes-no',
@ -148,6 +162,7 @@ VALID_SPECS = [
'enabled': True,
'disabled': False,
},
set(('enabled', 'disabled')),
),
(
'bool-y-n',
@ -163,6 +178,7 @@ VALID_SPECS = [
'enabled': True,
'disabled': False,
},
set(('enabled', 'disabled')),
),
(
'bool-on-off',
@ -178,6 +194,7 @@ VALID_SPECS = [
'enabled': True,
'disabled': False,
},
set(('enabled', 'disabled')),
),
(
'bool-1-0',
@ -193,6 +210,7 @@ VALID_SPECS = [
'enabled': True,
'disabled': False,
},
set(('enabled', 'disabled')),
),
(
'bool-float',
@ -208,89 +226,112 @@ VALID_SPECS = [
'enabled': True,
'disabled': False,
},
set(('enabled', 'disabled')),
),
(
'float',
{'digit': {'type': 'float'}},
{'digit': 3.14159},
{'digit': 3.14159},
set(('digit',)),
),
(
'float-str',
{'digit': {'type': 'float'}},
{'digit': '3.14159'},
{'digit': 3.14159},
set(('digit',)),
),
(
'path',
{'path': {'type': 'path'}},
{'path': '~/bin'},
{'path': '/home/ansible/bin'},
set(('path',)),
),
(
'raw',
{'raw': {'type': 'raw'}},
{'raw': 0x644},
{'raw': 0x644},
set(('raw',)),
),
(
'bytes',
{'bytes': {'type': 'bytes'}},
{'bytes': '2K'},
{'bytes': 2048},
set(('bytes',)),
),
(
'bits',
{'bits': {'type': 'bits'}},
{'bits': '1Mb'},
{'bits': 1048576},
set(('bits',)),
),
(
'jsonarg',
{'some_json': {'type': 'jsonarg'}},
{'some_json': '{"users": {"bob": {"role": "accountant"}}}'},
{'some_json': '{"users": {"bob": {"role": "accountant"}}}'},
set(('some_json',)),
),
(
'jsonarg-list',
{'some_json': {'type': 'jsonarg'}},
{'some_json': ['one', 'two']},
{'some_json': '["one", "two"]'},
set(('some_json',)),
),
(
'jsonarg-dict',
{'some_json': {'type': 'jsonarg'}},
{'some_json': {"users": {"bob": {"role": "accountant"}}}},
{'some_json': '{"users": {"bob": {"role": "accountant"}}}'},
set(('some_json',)),
),
(
'defaults',
{'param': {'default': 'DEFAULT'}},
{},
{'param': 'DEFAULT'},
set(('param',)),
),
(
'elements',
{'numbers': {'type': 'list', 'elements': 'int'}},
{'numbers': [55, 33, 34, '22']},
{'numbers': [55, 33, 34, 22]},
set(('numbers',)),
),
(
'aliases',
{'src': {'aliases': ['path', 'source']}},
{'src': '/tmp'},
{'src': '/tmp'},
set(('src (path, source)',)),
)
]
@pytest.mark.parametrize(
('arg_spec', 'parameters', 'expected'),
((i[1], i[2], i[3]) for i in VALID_SPECS),
('arg_spec', 'parameters', 'expected', 'valid_params'),
(i[1:] for i in VALID_SPECS),
ids=[i[0] for i in VALID_SPECS]
)
def test_valid_spec(arg_spec, parameters, expected, mocker):
def test_valid_spec(arg_spec, parameters, expected, valid_params, mocker):
mocker.patch('ansible.module_utils.common.validation.os.path.expanduser', return_value='/home/ansible/bin')
mocker.patch('ansible.module_utils.common.validation.os.path.expandvars', return_value='/home/ansible/bin')
v = ArgumentSpecValidator(arg_spec, parameters)
passed = v.validate()
v = ArgumentSpecValidator(arg_spec)
result = v.validate(parameters)
assert v.validated_parameters == expected
assert v.error_messages == []
assert passed is True
assert isinstance(result, ValidationResult)
assert result.validated_parameters == expected
assert result.unsupported_parameters == set()
assert result.error_messages == []
assert v._valid_parameter_names == valid_params
# Again to check caching
assert v._valid_parameter_names == valid_params

View file

@ -8,7 +8,7 @@ __metaclass__ = type
import pytest
from ansible.module_utils.common.parameters import get_unsupported_parameters
from ansible.module_utils.common.parameters import _get_unsupported_parameters
@pytest.fixture
@ -19,32 +19,6 @@ def argument_spec():
}
def mock_handle_aliases(*args):
aliases = {}
legal_inputs = [
'_ansible_check_mode',
'_ansible_debug',
'_ansible_diff',
'_ansible_keep_remote_files',
'_ansible_module_name',
'_ansible_no_log',
'_ansible_remote_tmp',
'_ansible_selinux_special_fs',
'_ansible_shell_executable',
'_ansible_socket',
'_ansible_string_conversion_action',
'_ansible_syslog_facility',
'_ansible_tmpdir',
'_ansible_verbosity',
'_ansible_version',
'state',
'status',
'enabled',
]
return aliases, legal_inputs
@pytest.mark.parametrize(
('module_parameters', 'legal_inputs', 'expected'),
(
@ -59,7 +33,6 @@ def mock_handle_aliases(*args):
)
)
def test_check_arguments(argument_spec, module_parameters, legal_inputs, expected, mocker):
mocker.patch('ansible.module_utils.common.parameters.handle_aliases', side_effect=mock_handle_aliases)
result = get_unsupported_parameters(argument_spec, module_parameters, legal_inputs)
result = _get_unsupported_parameters(argument_spec, module_parameters, legal_inputs)
assert result == expected

View file

@ -8,27 +8,9 @@ __metaclass__ = type
import pytest
from ansible.module_utils.common.parameters import handle_aliases
from ansible.module_utils.common.parameters import _handle_aliases
from ansible.module_utils._text import to_native
DEFAULT_LEGAL_INPUTS = [
'_ansible_check_mode',
'_ansible_debug',
'_ansible_diff',
'_ansible_keep_remote_files',
'_ansible_module_name',
'_ansible_no_log',
'_ansible_remote_tmp',
'_ansible_selinux_special_fs',
'_ansible_shell_executable',
'_ansible_socket',
'_ansible_string_conversion_action',
'_ansible_syslog_facility',
'_ansible_tmpdir',
'_ansible_verbosity',
'_ansible_version',
]
def test_handle_aliases_no_aliases():
argument_spec = {
@ -40,14 +22,9 @@ def test_handle_aliases_no_aliases():
'path': 'bar'
}
expected = (
{},
DEFAULT_LEGAL_INPUTS + ['name'],
)
expected[1].sort()
expected = {}
result = _handle_aliases(argument_spec, params)
result = handle_aliases(argument_spec, params)
result[1].sort()
assert expected == result
@ -63,14 +40,9 @@ def test_handle_aliases_basic():
'nick': 'foo',
}
expected = (
{'surname': 'name', 'nick': 'name'},
DEFAULT_LEGAL_INPUTS + ['name', 'surname', 'nick'],
)
expected[1].sort()
expected = {'surname': 'name', 'nick': 'name'}
result = _handle_aliases(argument_spec, params)
result = handle_aliases(argument_spec, params)
result[1].sort()
assert expected == result
@ -84,7 +56,7 @@ def test_handle_aliases_value_error():
}
with pytest.raises(ValueError) as ve:
handle_aliases(argument_spec, params)
_handle_aliases(argument_spec, params)
assert 'internal error: aliases must be a list or tuple' == to_native(ve.error)
@ -98,5 +70,5 @@ def test_handle_aliases_type_error():
}
with pytest.raises(TypeError) as te:
handle_aliases(argument_spec, params)
_handle_aliases(argument_spec, params)
assert 'internal error: required and default are mutually exclusive' in to_native(te.error)

View file

@ -7,7 +7,7 @@ __metaclass__ = type
import pytest
from ansible.module_utils.common.parameters import list_deprecations
from ansible.module_utils.common.parameters import _list_deprecations
@pytest.fixture
@ -33,7 +33,7 @@ def test_list_deprecations():
'foo': {'old': 'value'},
'bar': [{'old': 'value'}, {}],
}
result = list_deprecations(argument_spec, params)
result = _list_deprecations(argument_spec, params)
assert len(result) == 3
result.sort(key=lambda entry: entry['msg'])
assert result[0]['msg'] == """Param 'bar["old"]' is deprecated. See the module docs for more information"""

View file

@ -7,7 +7,7 @@ __metaclass__ = type
import pytest
from ansible.module_utils.common.parameters import list_no_log_values
from ansible.module_utils.common.parameters import _list_no_log_values
@pytest.fixture
@ -55,12 +55,12 @@ def test_list_no_log_values_no_secrets(module_parameters):
'value': {'type': 'int'},
}
expected = set()
assert expected == list_no_log_values(argument_spec, module_parameters)
assert expected == _list_no_log_values(argument_spec, module_parameters)
def test_list_no_log_values(argument_spec, module_parameters):
expected = set(('under', 'makeshift'))
assert expected == list_no_log_values(argument_spec(), module_parameters())
assert expected == _list_no_log_values(argument_spec(), module_parameters())
@pytest.mark.parametrize('extra_params', [
@ -81,7 +81,7 @@ def test_list_no_log_values_invalid_suboptions(argument_spec, module_parameters,
with pytest.raises(TypeError, match=r"(Value '.*?' in the sub parameter field '.*?' must by a dict, not '.*?')"
r"|(dictionary requested, could not parse JSON or key=value)"):
list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
_list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
def test_list_no_log_values_suboptions(argument_spec, module_parameters):
@ -103,7 +103,7 @@ def test_list_no_log_values_suboptions(argument_spec, module_parameters):
}
expected = set(('under', 'makeshift', 'bagel'))
assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
def test_list_no_log_values_sub_suboptions(argument_spec, module_parameters):
@ -136,7 +136,7 @@ def test_list_no_log_values_sub_suboptions(argument_spec, module_parameters):
}
expected = set(('under', 'makeshift', 'saucy', 'corporate'))
assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
def test_list_no_log_values_suboptions_list(argument_spec, module_parameters):
@ -164,7 +164,7 @@ def test_list_no_log_values_suboptions_list(argument_spec, module_parameters):
}
expected = set(('under', 'makeshift', 'playroom', 'luxury'))
assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
def test_list_no_log_values_sub_suboptions_list(argument_spec, module_parameters):
@ -204,7 +204,7 @@ def test_list_no_log_values_sub_suboptions_list(argument_spec, module_parameters
}
expected = set(('under', 'makeshift', 'playroom', 'luxury', 'basis', 'gave', 'composure', 'thumping'))
assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
@pytest.mark.parametrize('extra_params, expected', (
@ -225,4 +225,4 @@ def test_string_suboptions_as_string(argument_spec, module_parameters, extra_par
result = set(('under', 'makeshift'))
result.update(expected)
assert result == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
assert result == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))