Move type checking methods out of basic.py and add unit tests (#53687)

* Move check_type_str() out of basic.py

* Move check_type_list() out of basic.py

* Move safe_eval() out of basic.py

* Move check_type_dict() out of basic.py

* Move json importing code to common location

* Move check_type_bool() out of basic.py

* Move _check_type_int() out of basic.py

* Move _check_type_float() out of basic.py

* Move _check_type_path() out of basic.py

* Move _check_type_raw() out of basic.py

* Move _check_type_bytes() out of basic.py

* Move _check_type_bits() out of basic.py

* Create text.formatters.py

Move human_to_bytes, bytes_to_human, and _lenient_lowercase out of basic.py into text.formatters.py
Change references in modules to point to function at new location

* Move _check_type_jsonarg() out of basic.py

* Rename json related functions and put them in common.text.converters

Move formatters.py to common.text.formatters.py and update references in modules.

* Rework check_type_str()

Add allow_conversion option to make the function more self-contained.
Move the messaging back to basic.py since those error messages are more relevant to using this function in the context of AnsibleModule and not when using the function in isolation.

* Add unit tests for type checking functions

* Change _lenient_lowercase to lenient_lowercase per feedback
This commit is contained in:
Sam Doran 2019-03-21 09:40:19 -04:00 committed by GitHub
parent bb61d7527f
commit ff88bd82b5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 957 additions and 326 deletions

View file

@ -4,18 +4,6 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
SIZE_RANGES = {
'Y': 1 << 80,
'Z': 1 << 70,
'E': 1 << 60,
'P': 1 << 50,
'T': 1 << 40,
'G': 1 << 30,
'M': 1 << 20,
'K': 1 << 10,
'B': 1,
}
FILE_ATTRIBUTES = { FILE_ATTRIBUTES = {
'A': 'noatime', 'A': 'noatime',
'a': 'append', 'a': 'append',
@ -93,18 +81,27 @@ except ImportError:
# Python2 & 3 way to get NoneType # Python2 & 3 way to get NoneType
NoneType = type(None) NoneType = type(None)
from ansible.module_utils._text import to_native, to_bytes, to_text
from ansible.module_utils.common.text.converters import (
jsonify,
container_to_bytes as json_dict_unicode_to_bytes,
container_to_text as json_dict_bytes_to_unicode,
)
from ansible.module_utils.common.text.formatters import (
lenient_lowercase,
bytes_to_human,
human_to_bytes,
SIZE_RANGES,
)
try: try:
import json from ansible.module_utils.common._json_compat import json
# Detect the python-json library which is incompatible except ImportError as e:
try: print('\n{{"msg": "Error: ansible requires the stdlib json: {0}", "failed": true}}'.format(to_native(e)))
if not isinstance(json.loads, types.FunctionType) or not isinstance(json.dumps, types.FunctionType):
raise ImportError
except AttributeError:
raise ImportError
except ImportError:
print('\n{"msg": "Error: ansible requires the stdlib json and was not found!", "failed": true}')
sys.exit(1) sys.exit(1)
AVAILABLE_HASH_ALGORITHMS = dict() AVAILABLE_HASH_ALGORITHMS = dict()
try: try:
import hashlib import hashlib
@ -182,12 +179,22 @@ from ansible.module_utils.common.validation import (
check_required_one_of, check_required_one_of,
check_required_together, check_required_together,
count_terms, count_terms,
check_type_bool,
check_type_bits,
check_type_bytes,
check_type_float,
check_type_int,
check_type_jsonarg,
check_type_list,
check_type_dict,
check_type_path,
check_type_raw,
check_type_str,
safe_eval,
) )
from ansible.module_utils._text import to_native, to_bytes, to_text
from ansible.module_utils.common._utils import get_all_subclasses as _get_all_subclasses from ansible.module_utils.common._utils import get_all_subclasses as _get_all_subclasses
from ansible.module_utils.parsing.convert_bool import BOOLEANS, BOOLEANS_FALSE, BOOLEANS_TRUE, boolean from ansible.module_utils.parsing.convert_bool import BOOLEANS, BOOLEANS_FALSE, BOOLEANS_TRUE, boolean
# Note: When getting Sequence from collections, it matches with strings. If # Note: When getting Sequence from collections, it matches with strings. If
# this matters, make sure to check for strings before checking for sequencetype # this matters, make sure to check for strings before checking for sequencetype
SEQUENCETYPE = frozenset, KeysView, Sequence SEQUENCETYPE = frozenset, KeysView, Sequence
@ -306,45 +313,6 @@ def get_all_subclasses(cls):
# End compat shims # End compat shims
def json_dict_unicode_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'):
''' Recursively convert dict keys and values to byte str
Specialized for json return because this only handles, lists, tuples,
and dict container types (the containers that the json module returns)
'''
if isinstance(d, text_type):
return to_bytes(d, encoding=encoding, errors=errors)
elif isinstance(d, dict):
return dict(map(json_dict_unicode_to_bytes, iteritems(d), repeat(encoding), repeat(errors)))
elif isinstance(d, list):
return list(map(json_dict_unicode_to_bytes, d, repeat(encoding), repeat(errors)))
elif isinstance(d, tuple):
return tuple(map(json_dict_unicode_to_bytes, d, repeat(encoding), repeat(errors)))
else:
return d
def json_dict_bytes_to_unicode(d, encoding='utf-8', errors='surrogate_or_strict'):
''' Recursively convert dict keys and values to byte str
Specialized for json return because this only handles, lists, tuples,
and dict container types (the containers that the json module returns)
'''
if isinstance(d, binary_type):
# Warning, can traceback
return to_text(d, encoding=encoding, errors=errors)
elif isinstance(d, dict):
return dict(map(json_dict_bytes_to_unicode, iteritems(d), repeat(encoding), repeat(errors)))
elif isinstance(d, list):
return list(map(json_dict_bytes_to_unicode, d, repeat(encoding), repeat(errors)))
elif isinstance(d, tuple):
return tuple(map(json_dict_bytes_to_unicode, d, repeat(encoding), repeat(errors)))
else:
return d
def _remove_values_conditions(value, no_log_strings, deferred_removals): def _remove_values_conditions(value, no_log_strings, deferred_removals):
""" """
Helper function for :meth:`remove_values`. Helper function for :meth:`remove_values`.
@ -528,73 +496,6 @@ def heuristic_log_sanitize(data, no_log_values=None):
return output return output
def bytes_to_human(size, isbits=False, unit=None):
base = 'Bytes'
if isbits:
base = 'bits'
suffix = ''
for suffix, limit in sorted(iteritems(SIZE_RANGES), key=lambda item: -item[1]):
if (unit is None and size >= limit) or unit is not None and unit.upper() == suffix[0]:
break
if limit != 1:
suffix += base[0]
else:
suffix = base
return '%.2f %s' % (size / limit, suffix)
def human_to_bytes(number, default_unit=None, isbits=False):
'''
Convert number in string format into bytes (ex: '2K' => 2048) or using unit argument.
example: human_to_bytes('10M') <=> human_to_bytes(10, 'M')
'''
m = re.search(r'^\s*(\d*\.?\d*)\s*([A-Za-z]+)?', str(number), flags=re.IGNORECASE)
if m is None:
raise ValueError("human_to_bytes() can't interpret following string: %s" % str(number))
try:
num = float(m.group(1))
except Exception:
raise ValueError("human_to_bytes() can't interpret following number: %s (original input string: %s)" % (m.group(1), number))
unit = m.group(2)
if unit is None:
unit = default_unit
if unit is None:
''' No unit given, returning raw number '''
return int(round(num))
range_key = unit[0].upper()
try:
limit = SIZE_RANGES[range_key]
except Exception:
raise ValueError("human_to_bytes() failed to convert %s (unit = %s). The suffix must be one of %s" % (number, unit, ", ".join(SIZE_RANGES.keys())))
# default value
unit_class = 'B'
unit_class_name = 'byte'
# handling bits case
if isbits:
unit_class = 'b'
unit_class_name = 'bit'
# check unit value if more than one character (KB, MB)
if len(unit) > 1:
expect_message = 'expect %s%s or %s' % (range_key, unit_class, range_key)
if range_key == 'B':
expect_message = 'expect %s or %s' % (unit_class, unit_class_name)
if unit_class_name in unit.lower():
pass
elif unit[1] != unit_class:
raise ValueError("human_to_bytes() failed to convert %s. Value is not a valid string (%s)" % (number, expect_message))
return int(round(num * limit))
def _load_params(): def _load_params():
''' read the modules parameters and store them globally. ''' read the modules parameters and store them globally.
@ -659,44 +560,6 @@ def env_fallback(*args, **kwargs):
raise AnsibleFallbackNotFound raise AnsibleFallbackNotFound
def _lenient_lowercase(lst):
"""Lowercase elements of a list.
If an element is not a string, pass it through untouched.
"""
lowered = []
for value in lst:
try:
lowered.append(value.lower())
except AttributeError:
lowered.append(value)
return lowered
def _json_encode_fallback(obj):
if isinstance(obj, Set):
return list(obj)
elif isinstance(obj, datetime.datetime):
return obj.isoformat()
raise TypeError("Cannot json serialize %s" % to_native(obj))
def jsonify(data, **kwargs):
for encoding in ("utf-8", "latin-1"):
try:
return json.dumps(data, encoding=encoding, default=_json_encode_fallback, **kwargs)
# Old systems using old simplejson module does not support encoding keyword.
except TypeError:
try:
new_data = json_dict_bytes_to_unicode(data, encoding=encoding)
except UnicodeDecodeError:
continue
return json.dumps(new_data, default=_json_encode_fallback, **kwargs)
except UnicodeDecodeError:
continue
raise UnicodeError('Invalid unicode encoding encountered')
def missing_required_lib(library, reason=None, url=None): def missing_required_lib(library, reason=None, url=None):
hostname = platform.node() hostname = platform.node()
msg = "Failed to import the required Python library (%s) on %s's Python %s." % (library, hostname, sys.executable) msg = "Failed to import the required Python library (%s) on %s's Python %s." % (library, hostname, sys.executable)
@ -1714,7 +1577,7 @@ class AnsibleModule(object):
# the value. If we can't figure this out, module author is responsible. # the value. If we can't figure this out, module author is responsible.
lowered_choices = None lowered_choices = None
if param[k] == 'False': if param[k] == 'False':
lowered_choices = _lenient_lowercase(choices) lowered_choices = lenient_lowercase(choices)
overlap = BOOLEANS_FALSE.intersection(choices) overlap = BOOLEANS_FALSE.intersection(choices)
if len(overlap) == 1: if len(overlap) == 1:
# Extract from a set # Extract from a set
@ -1722,7 +1585,7 @@ class AnsibleModule(object):
if param[k] == 'True': if param[k] == 'True':
if lowered_choices is None: if lowered_choices is None:
lowered_choices = _lenient_lowercase(choices) lowered_choices = lenient_lowercase(choices)
overlap = BOOLEANS_TRUE.intersection(choices) overlap = BOOLEANS_TRUE.intersection(choices)
if len(overlap) == 1: if len(overlap) == 1:
(param[k],) = overlap (param[k],) = overlap
@ -1740,160 +1603,59 @@ class AnsibleModule(object):
self.fail_json(msg=msg) self.fail_json(msg=msg)
def safe_eval(self, value, locals=None, include_exceptions=False): def safe_eval(self, value, locals=None, include_exceptions=False):
return safe_eval(value, locals, include_exceptions)
# do not allow method calls to modules
if not isinstance(value, string_types):
# already templated to a datavaluestructure, perhaps?
if include_exceptions:
return (value, None)
return value
if re.search(r'\w\.\w+\(', value):
if include_exceptions:
return (value, None)
return value
# do not allow imports
if re.search(r'import \w+', value):
if include_exceptions:
return (value, None)
return value
try:
result = literal_eval(value)
if include_exceptions:
return (result, None)
else:
return result
except Exception as e:
if include_exceptions:
return (value, e)
return value
def _check_type_str(self, value): def _check_type_str(self, value):
if isinstance(value, string_types): opts = {
return value 'error': False,
'warn': False,
'ignore': True
}
# Ignore, warn, or error when converting to a string. # Ignore, warn, or error when converting to a string.
# The current default is to warn. Change this in Anisble 2.12 to error. allow_conversion = opts.get(self._string_conversion_action, True)
try:
return check_type_str(value, allow_conversion)
except TypeError:
common_msg = 'quote the entire value to ensure it does not change.' common_msg = 'quote the entire value to ensure it does not change.'
if self._string_conversion_action == 'error': if self._string_conversion_action == 'error':
msg = common_msg.capitalize() msg = common_msg.capitalize()
raise TypeError(msg) raise TypeError(to_native(msg))
elif self._string_conversion_action == 'warn': elif self._string_conversion_action == 'warn':
msg = ('The value {0!r} (type {0.__class__.__name__}) in a string field was converted to {1!r} (type string). ' msg = ('The value {0!r} (type {0.__class__.__name__}) in a string field was converted to {1!r} (type string). '
'If this does not look like what you expect, {2}').format(value, to_text(value), common_msg) 'If this does not look like what you expect, {2}').format(value, to_text(value), common_msg)
self.warn(msg) self.warn(to_native(msg))
return to_native(value, errors='surrogate_or_strict') return to_native(value, errors='surrogate_or_strict')
def _check_type_list(self, value): def _check_type_list(self, value):
if isinstance(value, list): return check_type_list(value)
return value
if isinstance(value, string_types):
return value.split(",")
elif isinstance(value, int) or isinstance(value, float):
return [str(value)]
raise TypeError('%s cannot be converted to a list' % type(value))
def _check_type_dict(self, value): def _check_type_dict(self, value):
if isinstance(value, dict): return check_type_dict(value)
return value
if isinstance(value, string_types):
if value.startswith("{"):
try:
return json.loads(value)
except Exception:
(result, exc) = self.safe_eval(value, dict(), include_exceptions=True)
if exc is not None:
raise TypeError('unable to evaluate string as dictionary')
return result
elif '=' in value:
fields = []
field_buffer = []
in_quote = False
in_escape = False
for c in value.strip():
if in_escape:
field_buffer.append(c)
in_escape = False
elif c == '\\':
in_escape = True
elif not in_quote and c in ('\'', '"'):
in_quote = c
elif in_quote and in_quote == c:
in_quote = False
elif not in_quote and c in (',', ' '):
field = ''.join(field_buffer)
if field:
fields.append(field)
field_buffer = []
else:
field_buffer.append(c)
field = ''.join(field_buffer)
if field:
fields.append(field)
return dict(x.split("=", 1) for x in fields)
else:
raise TypeError("dictionary requested, could not parse JSON or key=value")
raise TypeError('%s cannot be converted to a dict' % type(value))
def _check_type_bool(self, value): def _check_type_bool(self, value):
if isinstance(value, bool): return check_type_bool(value)
return value
if isinstance(value, string_types) or isinstance(value, int):
return self.boolean(value)
raise TypeError('%s cannot be converted to a bool' % type(value))
def _check_type_int(self, value): def _check_type_int(self, value):
if isinstance(value, integer_types): return check_type_int(value)
return value
if isinstance(value, string_types):
return int(value)
raise TypeError('%s cannot be converted to an int' % type(value))
def _check_type_float(self, value): def _check_type_float(self, value):
if isinstance(value, float): return check_type_float(value)
return value
if isinstance(value, (binary_type, text_type, int)):
return float(value)
raise TypeError('%s cannot be converted to a float' % type(value))
def _check_type_path(self, value): def _check_type_path(self, value):
value = self._check_type_str(value) return check_type_path(value)
return os.path.expanduser(os.path.expandvars(value))
def _check_type_jsonarg(self, value): def _check_type_jsonarg(self, value):
# Return a jsonified string. Sometimes the controller turns a json return check_type_jsonarg(value)
# string into a dict/list so transform it back into json here
if isinstance(value, (text_type, binary_type)):
return value.strip()
else:
if isinstance(value, (list, tuple, dict)):
return self.jsonify(value)
raise TypeError('%s cannot be converted to a json string' % type(value))
def _check_type_raw(self, value): def _check_type_raw(self, value):
return value return check_type_raw(value)
def _check_type_bytes(self, value): def _check_type_bytes(self, value):
try: return check_type_bytes(value)
self.human_to_bytes(value)
except ValueError:
raise TypeError('%s cannot be converted to a Byte value' % type(value))
def _check_type_bits(self, value): def _check_type_bits(self, value):
try: return check_type_bits(value)
self.human_to_bytes(value, isbits=True)
except ValueError:
raise TypeError('%s cannot be converted to a Bit value' % type(value))
def _handle_options(self, argument_spec=None, params=None): def _handle_options(self, argument_spec=None, params=None):
''' deal with options to create sub spec ''' ''' deal with options to create sub spec '''

View file

@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import types
import json
# Detect the python-json library which is incompatible
try:
if not isinstance(json.loads, types.FunctionType) or not isinstance(json.dumps, types.FunctionType):
raise ImportError('json.loads or json.dumps were not found in the imported json library.')
except AttributeError:
raise ImportError('python-json was detected, which is incompatible.')

View file

@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import datetime
import json
from itertools import repeat
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.module_utils.common._collections_compat import Set
from ansible.module_utils.six import (
binary_type,
iteritems,
text_type,
)
from ansible.module_utils.six.moves import map
def _json_encode_fallback(obj):
if isinstance(obj, Set):
return list(obj)
elif isinstance(obj, datetime.datetime):
return obj.isoformat()
raise TypeError("Cannot json serialize %s" % to_native(obj))
def jsonify(data, **kwargs):
for encoding in ("utf-8", "latin-1"):
try:
return json.dumps(data, encoding=encoding, default=_json_encode_fallback, **kwargs)
# Old systems using old simplejson module does not support encoding keyword.
except TypeError:
try:
new_data = container_to_text(data, encoding=encoding)
except UnicodeDecodeError:
continue
return json.dumps(new_data, default=_json_encode_fallback, **kwargs)
except UnicodeDecodeError:
continue
raise UnicodeError('Invalid unicode encoding encountered')
def container_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'):
''' Recursively convert dict keys and values to byte str
Specialized for json return because this only handles, lists, tuples,
and dict container types (the containers that the json module returns)
'''
if isinstance(d, text_type):
return to_bytes(d, encoding=encoding, errors=errors)
elif isinstance(d, dict):
return dict(map(container_to_bytes, iteritems(d), repeat(encoding), repeat(errors)))
elif isinstance(d, list):
return list(map(container_to_bytes, d, repeat(encoding), repeat(errors)))
elif isinstance(d, tuple):
return tuple(map(container_to_bytes, d, repeat(encoding), repeat(errors)))
else:
return d
def container_to_text(d, encoding='utf-8', errors='surrogate_or_strict'):
"""Recursively convert dict keys and values to byte str
Specialized for json return because this only handles, lists, tuples,
and dict container types (the containers that the json module returns)
"""
if isinstance(d, binary_type):
# Warning, can traceback
return to_text(d, encoding=encoding, errors=errors)
elif isinstance(d, dict):
return dict(map(container_to_text, iteritems(d), repeat(encoding), repeat(errors)))
elif isinstance(d, list):
return list(map(container_to_text, d, repeat(encoding), repeat(errors)))
elif isinstance(d, tuple):
return tuple(map(container_to_text, d, repeat(encoding), repeat(errors)))
else:
return d

View file

@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import re
from ansible.module_utils.six import iteritems
SIZE_RANGES = {
'Y': 1 << 80,
'Z': 1 << 70,
'E': 1 << 60,
'P': 1 << 50,
'T': 1 << 40,
'G': 1 << 30,
'M': 1 << 20,
'K': 1 << 10,
'B': 1,
}
def lenient_lowercase(lst):
"""Lowercase elements of a list.
If an element is not a string, pass it through untouched.
"""
lowered = []
for value in lst:
try:
lowered.append(value.lower())
except AttributeError:
lowered.append(value)
return lowered
def human_to_bytes(number, default_unit=None, isbits=False):
"""Convert number in string format into bytes (ex: '2K' => 2048) or using unit argument.
example: human_to_bytes('10M') <=> human_to_bytes(10, 'M')
"""
m = re.search(r'^\s*(\d*\.?\d*)\s*([A-Za-z]+)?', str(number), flags=re.IGNORECASE)
if m is None:
raise ValueError("human_to_bytes() can't interpret following string: %s" % str(number))
try:
num = float(m.group(1))
except Exception:
raise ValueError("human_to_bytes() can't interpret following number: %s (original input string: %s)" % (m.group(1), number))
unit = m.group(2)
if unit is None:
unit = default_unit
if unit is None:
''' No unit given, returning raw number '''
return int(round(num))
range_key = unit[0].upper()
try:
limit = SIZE_RANGES[range_key]
except Exception:
raise ValueError("human_to_bytes() failed to convert %s (unit = %s). The suffix must be one of %s" % (number, unit, ", ".join(SIZE_RANGES.keys())))
# default value
unit_class = 'B'
unit_class_name = 'byte'
# handling bits case
if isbits:
unit_class = 'b'
unit_class_name = 'bit'
# check unit value if more than one character (KB, MB)
if len(unit) > 1:
expect_message = 'expect %s%s or %s' % (range_key, unit_class, range_key)
if range_key == 'B':
expect_message = 'expect %s or %s' % (unit_class, unit_class_name)
if unit_class_name in unit.lower():
pass
elif unit[1] != unit_class:
raise ValueError("human_to_bytes() failed to convert %s. Value is not a valid string (%s)" % (number, expect_message))
return int(round(num * limit))
def bytes_to_human(size, isbits=False, unit=None):
base = 'Bytes'
if isbits:
base = 'bits'
suffix = ''
for suffix, limit in sorted(iteritems(SIZE_RANGES), key=lambda item: -item[1]):
if (unit is None and size >= limit) or unit is not None and unit.upper() == suffix[0]:
break
if limit != 1:
suffix += base[0]
else:
suffix = base
return '%.2f %s' % (size / limit, suffix)

View file

@ -1,13 +1,26 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright (c) 2018 Ansible Project # Copyright (c) 2019 Ansible Project
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) # Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
__metaclass__ = type __metaclass__ = type
from ansible.module_utils._text import to_native import os
import re
from ansible.module_utils._text import to_native, to_text
from ansible.module_utils.common._json_compat import json
from ansible.module_utils.common.collections import is_iterable from ansible.module_utils.common.collections import is_iterable
from ansible.module_utils.six import string_types from ansible.module_utils.common.text.converters import jsonify
from ansible.module_utils.common.text.formatters import human_to_bytes
from ansible.module_utils.parsing.convert_bool import boolean
from ansible.module_utils.pycompat24 import literal_eval
from ansible.module_utils.six import (
binary_type,
integer_types,
string_types,
text_type,
)
def count_terms(terms, module_parameters): def count_terms(terms, module_parameters):
@ -27,8 +40,9 @@ def count_terms(terms, module_parameters):
def check_mutually_exclusive(terms, module_parameters): def check_mutually_exclusive(terms, module_parameters):
"""Check mutually exclusive terms against argument parameters. Accepts """Check mutually exclusive terms against argument parameters
a single list or list of lists that are groups of terms that should be
Accepts a single list or list of lists that are groups of terms that should be
mutually exclusive with one another mutually exclusive with one another
:arg terms: List of mutually exclusive module parameters :arg terms: List of mutually exclusive module parameters
@ -56,7 +70,9 @@ def check_mutually_exclusive(terms, module_parameters):
def check_required_one_of(terms, module_parameters): def check_required_one_of(terms, module_parameters):
"""Check each list of terms to ensure at least one exists in the given module """Check each list of terms to ensure at least one exists in the given module
parameters. Accepts a list of lists or tuples. parameters
Accepts a list of lists or tuples
:arg terms: List of lists of terms to check. For each list of terms, at :arg terms: List of lists of terms to check. For each list of terms, at
least one is required. least one is required.
@ -84,7 +100,9 @@ def check_required_one_of(terms, module_parameters):
def check_required_together(terms, module_parameters): def check_required_together(terms, module_parameters):
"""Check each list of terms to ensure every parameter in each list exists """Check each list of terms to ensure every parameter in each list exists
in the given module parameters. Accepts a list of lists or tuples. in the given module parameters
Accepts a list of lists or tuples
:arg terms: List of lists of terms to check. Each list should include :arg terms: List of lists of terms to check. Each list should include
parameters that are all required when at least one is specified parameters that are all required when at least one is specified
@ -114,8 +132,9 @@ def check_required_together(terms, module_parameters):
def check_required_by(requirements, module_parameters): def check_required_by(requirements, module_parameters):
"""For each key in requirements, check the corresponding list to see if they """For each key in requirements, check the corresponding list to see if they
exist in module_parameters. Accepts a single string or list of values for exist in module_parameters
each key.
Accepts a single string or list of values for each key
:arg requirements: Dictionary of requirements :arg requirements: Dictionary of requirements
:arg module_parameters: Dictionary of module parameters :arg module_parameters: Dictionary of module parameters
@ -149,9 +168,9 @@ def check_required_by(requirements, module_parameters):
def check_required_arguments(argument_spec, module_parameters): def check_required_arguments(argument_spec, module_parameters):
"""Check all paramaters in argument_spec and return a list of parameters """Check all paramaters in argument_spec and return a list of parameters
that are required by not present in module_parameters. that are required but not present in module_parameters
Raises AnsibleModuleParameterException if the check fails. Raises TypeError if the check fails
:arg argument_spec: Argument spec dicitionary containing all parameters :arg argument_spec: Argument spec dicitionary containing all parameters
and their specification and their specification
@ -177,9 +196,9 @@ def check_required_arguments(argument_spec, module_parameters):
def check_required_if(requirements, module_parameters): def check_required_if(requirements, module_parameters):
"""Check parameters that are conditionally required. """Check parameters that are conditionally required
Raises TypeError if the check fails. Raises TypeError if the check fails
:arg requirements: List of lists specifying a parameter, value, parameters :arg requirements: List of lists specifying a parameter, value, parameters
required when the given parameter is the specified value, and optionally required when the given parameter is the specified value, and optionally
@ -262,6 +281,8 @@ def check_missing_parameters(module_parameters, required_parameters=None):
"""This is for checking for required params when we can not check via """This is for checking for required params when we can not check via
argspec because we need more information than is simply given in the argspec. argspec because we need more information than is simply given in the argspec.
Raises TypeError if any required parameters are missing
:arg module_paramaters: Dictionary of module parameters :arg module_paramaters: Dictionary of module parameters
:arg required_parameters: List of parameters to look for in the given module :arg required_parameters: List of parameters to look for in the given module
parameters parameters
@ -281,3 +302,244 @@ def check_missing_parameters(module_parameters, required_parameters=None):
raise TypeError(to_native(msg)) raise TypeError(to_native(msg))
return missing_params return missing_params
def safe_eval(value, locals=None, include_exceptions=False):
# do not allow method calls to modules
if not isinstance(value, string_types):
# already templated to a datavaluestructure, perhaps?
if include_exceptions:
return (value, None)
return value
if re.search(r'\w\.\w+\(', value):
if include_exceptions:
return (value, None)
return value
# do not allow imports
if re.search(r'import \w+', value):
if include_exceptions:
return (value, None)
return value
try:
result = literal_eval(value)
if include_exceptions:
return (result, None)
else:
return result
except Exception as e:
if include_exceptions:
return (value, e)
return value
def check_type_str(value, allow_conversion=True):
"""Verify that the value is a string or convert to a string.
Since unexpected changes can sometimes happen when converting to a string,
``allow_conversion`` controls whether or not the value will be converted or a
TypeError will be raised if the value is not a string and would be converted
:arg value: Value to validate or convert to a string
:arg allow_conversion: Whether to convert the string and return it or raise
a TypeError
:returns: Original value if it is a string, the value converted to a string
if allow_conversion=True, or raises a TypeError if allow_conversion=False.
"""
if isinstance(value, string_types):
return value
if allow_conversion:
return to_native(value, errors='surrogate_or_strict')
msg = "'{0!r}' is not a string and conversion is not allowed".format(value)
raise TypeError(to_native(msg))
def check_type_list(value):
"""Verify that the value is a list or convert to a list
A comma separated string will be split into a list. Rases a TypeError if
unable to convert to a list.
:arg value: Value to validate or convert to a list
:returns: Original value if it is already a list, single item list if a
float, int or string without commas, or a multi-item list if a
comma-delimited string.
"""
if isinstance(value, list):
return value
if isinstance(value, string_types):
return value.split(",")
elif isinstance(value, int) or isinstance(value, float):
return [str(value)]
raise TypeError('%s cannot be converted to a list' % type(value))
def check_type_dict(value):
"""Verify that value is a dict or convert it to a dict and return it.
Raises TypeError if unable to convert to a dict
:arg value: Dict or string to convert to a dict. Accepts 'k1=v2, k2=v2'.
:returns: value converted to a dictionary
"""
if isinstance(value, dict):
return value
if isinstance(value, string_types):
if value.startswith("{"):
try:
return json.loads(value)
except Exception:
(result, exc) = safe_eval(value, dict(), include_exceptions=True)
if exc is not None:
raise TypeError('unable to evaluate string as dictionary')
return result
elif '=' in value:
fields = []
field_buffer = []
in_quote = False
in_escape = False
for c in value.strip():
if in_escape:
field_buffer.append(c)
in_escape = False
elif c == '\\':
in_escape = True
elif not in_quote and c in ('\'', '"'):
in_quote = c
elif in_quote and in_quote == c:
in_quote = False
elif not in_quote and c in (',', ' '):
field = ''.join(field_buffer)
if field:
fields.append(field)
field_buffer = []
else:
field_buffer.append(c)
field = ''.join(field_buffer)
if field:
fields.append(field)
return dict(x.split("=", 1) for x in fields)
else:
raise TypeError("dictionary requested, could not parse JSON or key=value")
raise TypeError('%s cannot be converted to a dict' % type(value))
def check_type_bool(value):
"""Verify that the value is a bool or convert it to a bool and return it.
Raises TypeError if unable to convert to a bool
:arg value: String, int, or float to convert to bool. Valid booleans include:
'1', 'on', 1, '0', 0, 'n', 'f', 'false', 'true', 'y', 't', 'yes', 'no', 'off'
:returns: Boolean True or False
"""
if isinstance(value, bool):
return value
if isinstance(value, string_types) or isinstance(value, (int, float)):
return boolean(value)
raise TypeError('%s cannot be converted to a bool' % type(value))
def check_type_int(value):
"""Verify that the value is an integer and return it or convert the value
to an integer and return it
Raises TypeError if unable to convert to an int
:arg value: String or int to convert of verify
:return: Int of given value
"""
if isinstance(value, integer_types):
return value
if isinstance(value, string_types):
try:
return int(value)
except ValueError:
pass
raise TypeError('%s cannot be converted to an int' % type(value))
def check_type_float(value):
"""Verify that value is a float or convert it to a float and return it
Raises TypeError if unable to convert to a float
:arg value: Float, int, str, or bytes to verify or convert and return.
:returns: Float of given value.
"""
if isinstance(value, float):
return value
if isinstance(value, (binary_type, text_type, int)):
try:
return float(value)
except ValueError:
pass
raise TypeError('%s cannot be converted to a float' % type(value))
def check_type_path(value,):
"""Verify the provided value is a string or convert it to a string,
then return the expanded path
"""
value = check_type_str(value)
return os.path.expanduser(os.path.expandvars(value))
def check_type_raw(value):
"""Returns the raw value
"""
return value
def check_type_bytes(value):
"""Convert a human-readable string value to bytes
Raises TypeError if unable to covert the value
"""
try:
return human_to_bytes(value)
except ValueError:
raise TypeError('%s cannot be converted to a Byte value' % type(value))
def check_type_bits(value):
"""Convert a human-readable string value to bits
Raises TypeError if unable to covert the value
"""
try:
return human_to_bytes(value, isbits=True)
except ValueError:
raise TypeError('%s cannot be converted to a Bit value' % type(value))
def check_type_jsonarg(value):
"""Return a jsonified string. Sometimes the controller turns a json string
into a dict/list so transform it back into json here
Raises TypeError if unable to covert the value
"""
if isinstance(value, (text_type, binary_type)):
return value.strip()
elif isinstance(value, (list, tuple, dict)):
return jsonify(value)
raise TypeError('%s cannot be converted to a json string' % type(value))

View file

@ -30,7 +30,7 @@ from multiprocessing.pool import ThreadPool
from ansible.module_utils._text import to_text from ansible.module_utils._text import to_text
from ansible.module_utils.six import iteritems from ansible.module_utils.six import iteritems
from ansible.module_utils.basic import bytes_to_human from ansible.module_utils.common.text.formatters import bytes_to_human
from ansible.module_utils.facts.hardware.base import Hardware, HardwareCollector from ansible.module_utils.facts.hardware.base import Hardware, HardwareCollector
from ansible.module_utils.facts.utils import get_file_content, get_file_lines, get_mount_size from ansible.module_utils.facts.utils import get_file_content, get_file_lines, get_mount_size

View file

@ -20,7 +20,7 @@ import re
from ansible.module_utils.six.moves import reduce from ansible.module_utils.six.moves import reduce
from ansible.module_utils.basic import bytes_to_human from ansible.module_utils.common.text.formatters import bytes_to_human
from ansible.module_utils.facts.utils import get_file_content, get_mount_size from ansible.module_utils.facts.utils import get_file_content, get_mount_size

View file

@ -927,7 +927,7 @@ import re
import shlex import shlex
from distutils.version import LooseVersion from distutils.version import LooseVersion
from ansible.module_utils.basic import human_to_bytes from ansible.module_utils.common.text.formatters import human_to_bytes
from ansible.module_utils.docker.common import ( from ansible.module_utils.docker.common import (
AnsibleDockerClient, AnsibleDockerClient,
DifferenceTracker, DifferenceTracker,

View file

@ -120,7 +120,8 @@ ansible_facts:
} }
''' '''
from ansible.module_utils.basic import AnsibleModule, bytes_to_human from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.formatters import bytes_to_human
from ansible.module_utils.vmware import PyVmomi, vmware_argument_spec, find_obj from ansible.module_utils.vmware import PyVmomi, vmware_argument_spec, find_obj
try: try:

View file

@ -29,7 +29,7 @@ import math
from jinja2.filters import environmentfilter from jinja2.filters import environmentfilter
from ansible.errors import AnsibleFilterError from ansible.errors import AnsibleFilterError
from ansible.module_utils import basic from ansible.module_utils.common.text import formatters
from ansible.module_utils.six import binary_type, text_type from ansible.module_utils.six import binary_type, text_type
from ansible.module_utils.six.moves import zip, zip_longest from ansible.module_utils.six.moves import zip, zip_longest
from ansible.module_utils.common._collections_compat import Hashable, Mapping, Iterable from ansible.module_utils.common._collections_compat import Hashable, Mapping, Iterable
@ -163,7 +163,7 @@ def inversepower(x, base=2):
def human_readable(size, isbits=False, unit=None): def human_readable(size, isbits=False, unit=None):
''' Return a human readable string ''' ''' Return a human readable string '''
try: try:
return basic.bytes_to_human(size, isbits, unit) return formatters.bytes_to_human(size, isbits, unit)
except Exception: except Exception:
raise AnsibleFilterError("human_readable() can't interpret following string: %s" % size) raise AnsibleFilterError("human_readable() can't interpret following string: %s" % size)
@ -171,7 +171,7 @@ def human_readable(size, isbits=False, unit=None):
def human_to_bytes(size, default_unit=None, isbits=False): def human_to_bytes(size, default_unit=None, isbits=False):
''' Return bytes count from a human readable string ''' ''' Return bytes count from a human readable string '''
try: try:
return basic.human_to_bytes(size, default_unit, isbits) return formatters.human_to_bytes(size, default_unit, isbits)
except Exception: except Exception:
raise AnsibleFilterError("human_to_bytes() can't interpret following string: %s" % size) raise AnsibleFilterError("human_to_bytes() can't interpret following string: %s" % size)

View file

@ -164,8 +164,8 @@
- name: assert assume role with invalid duration seconds - name: assert assume role with invalid duration seconds
assert: assert:
that: that:
- 'result.failed' - result is failed
- "'unable to convert to int: invalid literal for int()' in result.msg" - 'result.msg is search("argument \w+ is of type <.*> and we were unable to convert to int: <.*> cannot be converted to an int")'
# ============================================================ # ============================================================
- name: test assume role with invalid external id - name: test assume role with invalid external id

View file

@ -44,12 +44,15 @@ MODULE_UTILS_BASIC_IMPORTS = frozenset((('_text',),
('basic',), ('basic',),
('common', '__init__'), ('common', '__init__'),
('common', '_collections_compat'), ('common', '_collections_compat'),
('common', '_json_compat'),
('common', 'collections'), ('common', 'collections'),
('common', 'file'), ('common', 'file'),
('common', 'collections'),
('common', 'parameters'), ('common', 'parameters'),
('common', 'process'), ('common', 'process'),
('common', 'sys_info'), ('common', 'sys_info'),
('common', 'text', '__init__'),
('common', 'text', 'converters'),
('common', 'text', 'formatters'),
('common', 'validation'), ('common', 'validation'),
('common', '_utils'), ('common', '_utils'),
('distro', '__init__'), ('distro', '__init__'),
@ -65,6 +68,7 @@ MODULE_UTILS_BASIC_FILES = frozenset(('ansible/module_utils/_text.py',
'ansible/module_utils/six/__init__.py', 'ansible/module_utils/six/__init__.py',
'ansible/module_utils/_text.py', 'ansible/module_utils/_text.py',
'ansible/module_utils/common/_collections_compat.py', 'ansible/module_utils/common/_collections_compat.py',
'ansible/module_utils/common/_json_compat.py',
'ansible/module_utils/common/collections.py', 'ansible/module_utils/common/collections.py',
'ansible/module_utils/common/parameters.py', 'ansible/module_utils/common/parameters.py',
'ansible/module_utils/parsing/convert_bool.py', 'ansible/module_utils/parsing/convert_bool.py',
@ -72,6 +76,9 @@ MODULE_UTILS_BASIC_FILES = frozenset(('ansible/module_utils/_text.py',
'ansible/module_utils/common/file.py', 'ansible/module_utils/common/file.py',
'ansible/module_utils/common/process.py', 'ansible/module_utils/common/process.py',
'ansible/module_utils/common/sys_info.py', 'ansible/module_utils/common/sys_info.py',
'ansible/module_utils/common/text/__init__.py',
'ansible/module_utils/common/text/converters.py',
'ansible/module_utils/common/text/formatters.py',
'ansible/module_utils/common/validation.py', 'ansible/module_utils/common/validation.py',
'ansible/module_utils/common/_utils.py', 'ansible/module_utils/common/_utils.py',
'ansible/module_utils/distro/__init__.py', 'ansible/module_utils/distro/__init__.py',

View file

@ -69,9 +69,10 @@ VALID_SPECS = (
INVALID_SPECS = ( INVALID_SPECS = (
# Type is int; unable to convert this string # Type is int; unable to convert this string
({'arg': {'type': 'int'}}, {'arg': "bad"}, "invalid literal for int() with base 10: 'bad'"), ({'arg': {'type': 'int'}}, {'arg': "wolf"}, "is of type {0} and we were unable to convert to int: {0} cannot be converted to an int".format(type('bad'))),
# Type is list elements is int; unable to convert this string # Type is list elements is int; unable to convert this string
({'arg': {'type': 'list', 'elements': 'int'}}, {'arg': [1, "bad"]}, "invalid literal for int() with base 10: 'bad'"), ({'arg': {'type': 'list', 'elements': 'int'}}, {'arg': [1, "bad"]}, "is of type {0} and we were unable to convert to int: {0} cannot be converted to "
"an int".format(type('int'))),
# Type is int; unable to convert float # Type is int; unable to convert float
({'arg': {'type': 'int'}}, {'arg': 42.1}, "'float'> cannot be converted to an int"), ({'arg': {'type': 'int'}}, {'arg': 42.1}, "'float'> cannot be converted to an int"),
# Type is list, elements is int; unable to convert float # Type is list, elements is int; unable to convert float

View file

@ -64,13 +64,12 @@ class TestImports(ModuleTestCase):
@patch.object(builtins, '__import__') @patch.object(builtins, '__import__')
def test_module_utils_basic_import_json(self, mock_import): def test_module_utils_basic_import_json(self, mock_import):
def _mock_import(name, *args, **kwargs): def _mock_import(name, *args, **kwargs):
if name == 'json': if name == 'ansible.module_utils.common._json_compat':
raise ImportError raise ImportError
return realimport(name, *args, **kwargs) return realimport(name, *args, **kwargs)
self.clear_modules(['json', 'ansible.module_utils.basic']) self.clear_modules(['json', 'ansible.module_utils.basic'])
builtins.__import__('ansible.module_utils.basic') builtins.__import__('ansible.module_utils.basic')
self.clear_modules(['json', 'ansible.module_utils.basic']) self.clear_modules(['json', 'ansible.module_utils.basic'])
mock_import.side_effect = _mock_import mock_import.side_effect = _mock_import
with self.assertRaises(SystemExit): with self.assertRaises(SystemExit):

View file

@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils._text import to_native
from ansible.module_utils.common.validation import check_type_bits
def test_check_type_bits():
test_cases = (
('1', 1),
(99, 99),
(1.5, 2),
('1.5', 2),
('2b', 2),
('2k', 2048),
('2K', 2048),
('1m', 1048576),
('1M', 1048576),
('1g', 1073741824),
('1G', 1073741824),
(1073741824, 1073741824),
)
for case in test_cases:
assert case[1] == check_type_bits(case[0])
def test_check_type_bits_fail():
test_cases = (
'foo',
'2KB',
'1MB',
'1GB',
)
for case in test_cases:
with pytest.raises(TypeError) as e:
check_type_bits(case)
assert 'cannot be converted to a Bit value' in to_native(e.value)

View file

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils._text import to_native
from ansible.module_utils.common.validation import check_type_bool
def test_check_type_bool():
test_cases = (
(True, True),
(False, False),
('1', True),
('on', True),
(1, True),
('0', False),
(0, False),
('n', False),
('f', False),
('false', False),
('true', True),
('y', True),
('t', True),
('yes', True),
('no', False),
('off', False),
)
for case in test_cases:
assert case[1] == check_type_bool(case[0])
def test_check_type_bool_fail():
default_test_msg = 'cannot be converted to a bool'
test_cases = (
({'k1': 'v1'}, 'is not a valid bool'),
(3.14159, default_test_msg),
(-1, default_test_msg),
(-90810398401982340981023948192349081, default_test_msg),
(90810398401982340981023948192349081, default_test_msg),
)
for case in test_cases:
with pytest.raises(TypeError) as e:
check_type_bool(case)
assert 'cannot be converted to a bool' in to_native(e.value)

View file

@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils._text import to_native
from ansible.module_utils.common.validation import check_type_bytes
def test_check_type_bytes():
test_cases = (
('1', 1),
(99, 99),
(1.5, 2),
('1.5', 2),
('2b', 2),
('2B', 2),
('2k', 2048),
('2K', 2048),
('2KB', 2048),
('1m', 1048576),
('1M', 1048576),
('1MB', 1048576),
('1g', 1073741824),
('1G', 1073741824),
('1GB', 1073741824),
(1073741824, 1073741824),
)
for case in test_cases:
assert case[1] == check_type_bytes(case[0])
def test_check_type_bytes_fail():
test_cases = (
'foo',
'2kb',
'2Kb',
'1mb',
'1Mb',
'1gb',
'1Gb',
)
for case in test_cases:
with pytest.raises(TypeError) as e:
check_type_bytes(case)
assert 'cannot be converted to a Byte value' in to_native(e.value)

View file

@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils.common.validation import check_type_dict
def test_check_type_dict():
test_cases = (
({'k1': 'v1'}, {'k1': 'v1'}),
('k1=v1,k2=v2', {'k1': 'v1', 'k2': 'v2'}),
('k1=v1, k2=v2', {'k1': 'v1', 'k2': 'v2'}),
('k1=v1, k2=v2, k3=v3', {'k1': 'v1', 'k2': 'v2', 'k3': 'v3'}),
('{"key": "value", "list": ["one", "two"]}', {'key': 'value', 'list': ['one', 'two']})
)
for case in test_cases:
assert case[1] == check_type_dict(case[0])
def test_check_type_dict_fail():
test_cases = (
1,
3.14159,
[1, 2],
'a',
)
for case in test_cases:
with pytest.raises(TypeError):
check_type_dict(case)

View file

@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils._text import to_native
from ansible.module_utils.common.validation import check_type_float
def test_check_type_float():
test_cases = (
('1.5', 1.5),
('''1.5''', 1.5),
(u'1.5', 1.5),
(1002, 1002.0),
(1.0, 1.0),
(3.141592653589793, 3.141592653589793),
('3.141592653589793', 3.141592653589793),
(b'3.141592653589793', 3.141592653589793),
)
for case in test_cases:
assert case[1] == check_type_float(case[0])
def test_check_type_float_fail():
test_cases = (
{'k1': 'v1'},
['a', 'b'],
'b',
)
for case in test_cases:
with pytest.raises(TypeError) as e:
check_type_float(case)
assert 'cannot be converted to a float' in to_native(e.value)

View file

@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils._text import to_native
from ansible.module_utils.common.validation import check_type_int
def test_check_type_int():
test_cases = (
('1', 1),
(u'1', 1),
(1002, 1002),
)
for case in test_cases:
assert case[1] == check_type_int(case[0])
def test_check_type_int_fail():
test_cases = (
{'k1': 'v1'},
(b'1', 1),
(3.14159, 3),
'b',
)
for case in test_cases:
with pytest.raises(TypeError) as e:
check_type_int(case)
assert 'cannot be converted to an int' in to_native(e)

View file

@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils._text import to_native
from ansible.module_utils.common.validation import check_type_jsonarg
def test_check_type_jsonarg():
test_cases = (
('a', 'a'),
('a ', 'a'),
(b'99', b'99'),
(b'99 ', b'99'),
({'k1': 'v1'}, '{"k1": "v1"}'),
([1, 'a'], '[1, "a"]'),
((1, 2, 'three'), '[1, 2, "three"]'),
)
for case in test_cases:
assert case[1] == check_type_jsonarg(case[0])
def test_check_type_jsonarg_fail():
test_cases = (
1.5,
910313498012384012341982374109384098,
)
for case in test_cases:
with pytest.raises(TypeError) as e:
check_type_jsonarg(case)
assert 'cannot be converted to a json string' in to_native(e.value)

View file

@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils.common.validation import check_type_list
def test_check_type_list():
test_cases = (
([1, 2], [1, 2]),
(1, ['1']),
(['a', 'b'], ['a', 'b']),
('a', ['a']),
(3.14159, ['3.14159']),
('a,b,1,2', ['a', 'b', '1', '2'])
)
for case in test_cases:
assert case[1] == check_type_list(case[0])
def test_check_type_list_failure():
test_cases = (
{'k1': 'v1'},
)
for case in test_cases:
with pytest.raises(TypeError):
check_type_list(case)

View file

@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import re
import os
from ansible.module_utils.common.validation import check_type_path
def mock_expand(value):
return re.sub(r'~|\$HOME', '/home/testuser', value)
def test_check_type_path(monkeypatch):
monkeypatch.setattr(os.path, 'expandvars', mock_expand)
monkeypatch.setattr(os.path, 'expanduser', mock_expand)
test_cases = (
('~/foo', '/home/testuser/foo'),
('$HOME/foo', '/home/testuser/foo'),
('/home/jane', '/home/jane'),
(u'/home/jané', u'/home/jané'),
)
for case in test_cases:
assert case[1] == check_type_path(case[0])

View file

@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from ansible.module_utils.common.validation import check_type_raw
def test_check_type_raw():
test_cases = (
(1, 1),
('1', '1'),
('a', 'a'),
({'k1': 'v1'}, {'k1': 'v1'}),
([1, 2], [1, 2]),
(b'42', b'42'),
(u'42', u'42'),
)
for case in test_cases:
assert case[1] == check_type_raw(case[0])

View file

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils._text import to_native
from ansible.module_utils.common.validation import check_type_str
TEST_CASES = (
('string', 'string'),
(100, '100'),
(1.5, '1.5'),
({'k1': 'v1'}, "{'k1': 'v1'}"),
([1, 2, 'three'], "[1, 2, 'three']"),
((1, 2,), '(1, 2)'),
)
@pytest.mark.parametrize('value, expected', TEST_CASES)
def test_check_type_str(value, expected):
assert expected == check_type_str(value)
@pytest.mark.parametrize('value, expected', TEST_CASES[1:])
def test_check_type_str_no_conversion(value, expected):
with pytest.raises(TypeError) as e:
check_type_str(value, allow_conversion=False)
assert 'is not a string and conversion is not allowed' in to_native(e.value)