Only pass kwargs to our string checker not callable checkers (#70151)

Since only check_type_str() accepts extra param, only pass to our checker and
do not pass kwargs to custom checkers.

* Add unit tests
This commit is contained in:
Sam Doran 2020-06-19 09:52:05 -04:00 committed by GitHub
parent 87406890cf
commit bc05415109
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 3 deletions

View file

@ -0,0 +1,4 @@
bugfixes:
- >-
if the ``type`` for a module parameter in the argument spec is callable,
do not pass ``kwargs`` to avoid errors (https://github.com/ansible/ansible/issues/70017)

View file

@ -1764,8 +1764,9 @@ class AnsibleModule(object):
type_checker, wanted_name = self._get_wanted_type(wanted, param) type_checker, wanted_name = self._get_wanted_type(wanted, param)
validated_params = [] validated_params = []
# Get param name for strings so we can later display this value in a useful error message if needed # Get param name for strings so we can later display this value in a useful error message if needed
# Only pass 'kwargs' to our checkers and ignore custom callable checkers
kwargs = {} kwargs = {}
if wanted_name == 'str': if wanted_name == 'str' and isinstance(wanted, string_types):
if isinstance(param, string_types): if isinstance(param, string_types):
kwargs['param'] = param kwargs['param'] = param
elif isinstance(param, dict): elif isinstance(param, dict):
@ -1800,8 +1801,9 @@ class AnsibleModule(object):
type_checker, wanted_name = self._get_wanted_type(wanted, k) type_checker, wanted_name = self._get_wanted_type(wanted, k)
# Get param name for strings so we can later display this value in a useful error message if needed # Get param name for strings so we can later display this value in a useful error message if needed
# Only pass 'kwargs' to our checkers and ignore custom callable checkers
kwargs = {} kwargs = {}
if wanted_name == 'str': if wanted_name == 'str' and isinstance(type_checker, string_types):
kwargs['param'] = list(param.keys())[0] kwargs['param'] = list(param.keys())[0]
# Get the name of the parent key if this is a nested option # Get the name of the parent key if this is a nested option

View file

@ -15,7 +15,7 @@ import pytest
from units.compat.mock import MagicMock from units.compat.mock import MagicMock
from ansible.module_utils import basic from ansible.module_utils import basic
from ansible.module_utils.common.warnings import get_deprecation_messages, get_warning_messages from ansible.module_utils.common.warnings import get_deprecation_messages, get_warning_messages
from ansible.module_utils.six import integer_types from ansible.module_utils.six import integer_types, string_types
from ansible.module_utils.six.moves import builtins from ansible.module_utils.six.moves import builtins
@ -101,6 +101,7 @@ def complex_argspec():
baz=dict(fallback=(basic.env_fallback, ['BAZ'])), baz=dict(fallback=(basic.env_fallback, ['BAZ'])),
bar1=dict(type='bool'), bar1=dict(type='bool'),
bar3=dict(type='list', elements='path'), bar3=dict(type='list', elements='path'),
bar_str=dict(type='list', elements=str),
zardoz=dict(choices=['one', 'two']), zardoz=dict(choices=['one', 'two']),
zardoz2=dict(type='list', choices=['one', 'two', 'three']), zardoz2=dict(type='list', choices=['one', 'two', 'three']),
zardoz3=dict(type='str', aliases=['zodraz'], deprecated_aliases=[dict(name='zodraz', version='9.99')]), zardoz3=dict(type='str', aliases=['zodraz'], deprecated_aliases=[dict(name='zodraz', version='9.99')]),
@ -212,6 +213,16 @@ def test_validator_function(mocker, stdin):
assert am.params['arg'] == 27 assert am.params['arg'] == 27
@pytest.mark.parametrize('stdin', [{'arg': '123'}, {'arg': 123}], indirect=['stdin'])
def test_validator_string_type(mocker, stdin):
# Custom callable that is 'str'
argspec = {'arg': {'type': str}}
am = basic.AnsibleModule(argspec)
assert isinstance(am.params['arg'], string_types)
assert am.params['arg'] == '123'
@pytest.mark.parametrize('argspec, expected, stdin', [(s[0], s[2], s[1]) for s in INVALID_SPECS], @pytest.mark.parametrize('argspec, expected, stdin', [(s[0], s[2], s[1]) for s in INVALID_SPECS],
indirect=['stdin']) indirect=['stdin'])
def test_validator_fail(stdin, capfd, argspec, expected): def test_validator_fail(stdin, capfd, argspec, expected):
@ -342,6 +353,16 @@ class TestComplexArgSpecs:
assert "Alias 'zodraz' is deprecated." in get_deprecation_messages()[0]['msg'] assert "Alias 'zodraz' is deprecated." in get_deprecation_messages()[0]['msg']
assert get_deprecation_messages()[0]['version'] == '9.99' assert get_deprecation_messages()[0]['version'] == '9.99'
@pytest.mark.parametrize('stdin', [{'foo': 'hello', 'bar_str': [867, '5309']}], indirect=['stdin'])
def test_list_with_elements_callable_str(self, capfd, mocker, stdin, complex_argspec):
"""Test choices with list"""
am = basic.AnsibleModule(**complex_argspec)
assert isinstance(am.params['bar_str'], list)
assert isinstance(am.params['bar_str'][0], string_types)
assert isinstance(am.params['bar_str'][1], string_types)
assert am.params['bar_str'][0] == '867'
assert am.params['bar_str'][1] == '5309'
class TestComplexOptions: class TestComplexOptions:
"""Test arg spec options""" """Test arg spec options"""