adds new common functions for declarative intent modules (#25210)

* adds new common functions for declarative intent modules

* adds Entity and EntityCollection
* adds dict_diff and dict_combine

* update for CI  PEP8 compliance

* more CI PEP8 fixes

* more PEP8 CI clean up

* refactors the lambda assignments into top level classes

this is to be in compliant the PEP8 CI sanity checks

* one last pep8 ci fix
This commit is contained in:
Peter Sprygada 2017-06-16 10:16:20 -04:00 committed by GitHub
parent 43468b825d
commit 3aa41eda0b
2 changed files with 287 additions and 26 deletions

View file

@ -24,7 +24,10 @@
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from itertools import chain
from ansible.module_utils.six import iteritems
from ansible.module_utils.basic import AnsibleFallbackNotFound from ansible.module_utils.basic import AnsibleFallbackNotFound
from ansible.module_utils.six import iteritems from ansible.module_utils.six import iteritems
@ -38,7 +41,13 @@ def to_list(val):
return list() return list()
class ComplexDict(object): def sort_list(val):
if isinstance(val, list):
return sorted(val)
return val
class Entity(object):
"""Transforms a dict to with an argument spec """Transforms a dict to with an argument spec
This class will take a dict and apply an Ansible argument spec to the This class will take a dict and apply an Ansible argument spec to the
@ -52,7 +61,7 @@ class ComplexDict(object):
display=dict(default='text', choices=['text', 'json']), display=dict(default='text', choices=['text', 'json']),
validate=dict(type='bool') validate=dict(type='bool')
) )
transform = ComplexDict(argument_spec, module) transform = Entity(module, argument_spec)
value = dict(command='foo') value = dict(command='foo')
result = transform(value) result = transform(value)
print result print result
@ -66,31 +75,42 @@ class ComplexDict(object):
* fallback - implements fallback function * fallback - implements fallback function
* choices - set of valid options * choices - set of valid options
* default - default value * default - default value
""" """
def __init__(self, attrs, module): def __init__(self, module, attrs=None, args=[], keys=None, from_argspec=False):
self._attributes = attrs self._attributes = attrs or {}
self._module = module self._module = module
for arg in args:
self._attributes[arg] = dict()
if from_argspec:
self._attributes[arg]['read_from'] = arg
if keys and arg in keys:
self._attributes[arg]['key'] = True
self.attr_names = frozenset(self._attributes.keys()) self.attr_names = frozenset(self._attributes.keys())
self._has_key = False _has_key = False
for name, attr in iteritems(self._attributes): for name, attr in iteritems(self._attributes):
if attr.get('read_from'): if attr.get('read_from'):
if attr['read_from'] not in self._module.argument_spec:
module.fail_json(msg='argument %s does not exist' % attr['read_from'])
spec = self._module.argument_spec.get(attr['read_from']) spec = self._module.argument_spec.get(attr['read_from'])
if not spec:
raise ValueError('argument_spec %s does not exist' % attr['read_from'])
for key, value in iteritems(spec): for key, value in iteritems(spec):
if key not in attr: if key not in attr:
attr[key] = value attr[key] = value
if attr.get('key'): if attr.get('key'):
if self._has_key: if _has_key:
raise ValueError('only one key value can be specified') module.fail_json(msg='only one key value can be specified')
self._has_key = True _has_key = True
attr['required'] = True attr['required'] = True
def _dict(self, value): def serialize(self):
return self._attributes
def to_dict(self, value):
obj = {} obj = {}
for name, attr in iteritems(self._attributes): for name, attr in iteritems(self._attributes):
if attr.get('key'): if attr.get('key'):
@ -99,16 +119,17 @@ class ComplexDict(object):
obj[name] = attr.get('default') obj[name] = attr.get('default')
return obj return obj
def __call__(self, value): def __call__(self, value, strict=True):
if not isinstance(value, dict): if not isinstance(value, dict):
value = self._dict(value) value = self.to_dict(value)
if strict:
unknown = set(value).difference(self.attr_names) unknown = set(value).difference(self.attr_names)
if unknown: if unknown:
raise ValueError('invalid keys: %s' % ','.join(unknown)) self._module.fail_json(msg='invalid keys: %s' % ','.join(unknown))
for name, attr in iteritems(self._attributes): for name, attr in iteritems(self._attributes):
if not value.get(name): if value.get(name) is None:
value[name] = attr.get('default') value[name] = attr.get('default')
if attr.get('fallback') and not value.get(name): if attr.get('fallback') and not value.get(name):
@ -128,24 +149,135 @@ class ComplexDict(object):
continue continue
if attr.get('required') and value.get(name) is None: if attr.get('required') and value.get(name) is None:
raise ValueError('missing required attribute %s' % name) self._module.fail_json(msg='missing required attribute %s' % name)
if 'choices' in attr: if 'choices' in attr:
if value[name] not in attr['choices']: if value[name] not in attr['choices']:
raise ValueError('%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name])) self._module.fail_json(msg='%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name]))
if value[name] is not None: if value[name] is not None:
value_type = attr.get('type', 'str') value_type = attr.get('type', 'str')
type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type] type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type]
type_checker(value[name]) type_checker(value[name])
elif value.get(name):
value[name] = self._module.params[name]
return value return value
class ComplexList(ComplexDict): class EntityCollection(Entity):
"""Extends ```ComplexDict``` to handle a list of dicts """ """Extends ```Entity``` to handle a list of dicts """
def __call__(self, values): def __call__(self, iterable, strict=True):
if not isinstance(values, (list, tuple)): if iterable is None:
raise TypeError('value must be an ordered iterable') iterable = [super(EntityCollection, self).__call__(self._module.params, strict)]
return [(super(ComplexList, self).__call__(v)) for v in values]
if not isinstance(iterable, (list, tuple)):
module.fail_json(msg='value must be an iterable')
return [(super(EntityCollection, self).__call__(i, strict)) for i in iterable]
# these two are for backwards compatibility and can be removed once all of the
# modules that use them are updated
class ComplexDict(Entity):
def __init__(self, attrs, module, *args, **kwargs):
super(ComplexDict, self).__init__(module, attrs, *args, **kwargs)
class ComplexList(EntityCollection):
def __init__(self, attrs, module, *args, **kwargs):
super(ComplexList, self).__init__(module, attrs, *args, **kwargs)
def dict_diff(base, comparable):
""" Generate a dict object of differences
This function will compare two dict objects and return the difference
between them as a dict object. For scalar values, the key will reflect
the updated value. If the key does not exist in `comparable`, then then no
key will be returned. For lists, the value in comparable will wholly replace
the value in base for the key. For dicts, the returned value will only
return keys that are different.
:param base: dict object to base the diff on
:param comparable: dict object to compare against base
:returns: new dict object with differences
"""
assert isinstance(base, dict), "`base` must be of type <dict>"
assert isinstance(comparable, dict), "`comparable` must be of type <dict>"
updates = dict()
for key, value in iteritems(base):
if isinstance(value, dict):
item = comparable.get(key)
if item is not None:
updates[key] = dict_diff(value, comparable[key])
else:
comparable_value = comparable.get(key)
if comparable_value is not None:
if sort_list(base[key]) != sort_list(comparable_value):
updates[key] = comparable_value
for key in set(comparable.keys()).difference(base.keys()):
updates[key] = comparable.get(key)
return updates
def dict_combine(base, other):
""" Return a new dict object that combines base and other
This will create a new dict object that is a combination of the key/value
pairs from base and other. When both keys exist, the value will be
selected from other. If the value is a list object, the two lists will
be combined and duplicate entries removed.
:param base: dict object to serve as base
:param other: dict object to combine with base
:returns: new combined dict object
"""
assert isinstance(base, dict), "`base` must be of type <dict>"
assert isinstance(other, dict), "`other` must be of type <dict>"
combined = dict()
for key, value in iteritems(base):
if isinstance(value, dict):
if key in other:
item = other.get(key)
if item is not None:
combined[key] = dict_combine(value, other[key])
else:
combined[key] = item
else:
combined[key] = value
elif isinstance(value, list):
if key in other:
item = other.get(key)
if item is not None:
combined[key] = list(set(chain(value, item)))
else:
combined[key] = item
else:
combined[key] = value
else:
if key in other:
other_value = other.get(key)
if other_value is not None:
if sort_list(base[key]) != sort_list(other_value):
combined[key] = other_value
else:
combined[key] = value
else:
combined[key] = other_value
else:
combined[key] = value
for key in set(other.keys()).difference(base.keys()):
combined[key] = other.get(key)
return combined

View file

@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
#
# (c) 2017 Red Hat, Inc.
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division)
__metaclass__ = type
from ansible.compat.tests import unittest
from ansible.module_utils.network_common import to_list, sort_list
from ansible.module_utils.network_common import dict_diff, dict_combine
class TestModuleUtilsNetworkCommon(unittest.TestCase):
def test_to_list(self):
for scalar in ('string', 1, True, False, None):
self.assertTrue(isinstance(to_list(scalar), list))
for container in ([1, 2, 3], {'one': 1}):
self.assertTrue(isinstance(to_list(container), list))
test_list = [1, 2, 3]
self.assertNotEqual(id(test_list), id(to_list(test_list)))
def test_sort(self):
data = [3, 1, 2]
self.assertEqual([1, 2, 3], sort_list(data))
string_data = '123'
self.assertEqual(string_data, sort_list(string_data))
def test_dict_diff(self):
base = dict(obj2=dict(), b1=True, b2=False, b3=False,
one=1, two=2, three=3, obj1=dict(key1=1, key2=2),
l1=[1, 3], l2=[1, 2, 3], l4=[4],
nested=dict(n1=dict(n2=2)))
other = dict(b1=True, b2=False, b3=True, b4=True,
one=1, three=4, four=4, obj1=dict(key1=2),
l1=[2, 1], l2=[3, 2, 1], l3=[1],
nested=dict(n1=dict(n2=2, n3=3)))
result = dict_diff(base, other)
# string assertions
self.assertNotIn('one', result)
self.assertNotIn('two', result)
self.assertEqual(result['three'], 4)
self.assertEqual(result['four'], 4)
# dict assertions
self.assertIn('obj1', result)
self.assertIn('key1', result['obj1'])
self.assertNotIn('key2', result['obj1'])
# list assertions
self.assertEqual(result['l1'], [2, 1])
self.assertNotIn('l2', result)
self.assertEqual(result['l3'], [1])
self.assertNotIn('l4', result)
# nested assertions
self.assertIn('obj1', result)
self.assertEqual(result['obj1']['key1'], 2)
self.assertNotIn('key2', result['obj1'])
# bool assertions
self.assertNotIn('b1', result)
self.assertNotIn('b2', result)
self.assertTrue(result['b3'])
self.assertTrue(result['b4'])
def test_dict_combine(self):
base = dict(obj2=dict(), b1=True, b2=False, b3=False,
one=1, two=2, three=3, obj1=dict(key1=1, key2=2),
l1=[1, 3], l2=[1, 2, 3], l4=[4],
nested=dict(n1=dict(n2=2)))
other = dict(b1=True, b2=False, b3=True, b4=True,
one=1, three=4, four=4, obj1=dict(key1=2),
l1=[2, 1], l2=[3, 2, 1], l3=[1],
nested=dict(n1=dict(n2=2, n3=3)))
result = dict_combine(base, other)
# string assertions
self.assertIn('one', result)
self.assertIn('two', result)
self.assertEqual(result['three'], 4)
self.assertEqual(result['four'], 4)
# dict assertions
self.assertIn('obj1', result)
self.assertIn('key1', result['obj1'])
self.assertIn('key2', result['obj1'])
# list assertions
self.assertEqual(result['l1'], [1, 2, 3])
self.assertIn('l2', result)
self.assertEqual(result['l3'], [1])
self.assertIn('l4', result)
# nested assertions
self.assertIn('obj1', result)
self.assertEqual(result['obj1']['key1'], 2)
self.assertIn('key2', result['obj1'])
# bool assertions
self.assertIn('b1', result)
self.assertIn('b2', result)
self.assertTrue(result['b3'])
self.assertTrue(result['b4'])