diff --git a/lib/ansible/arguments.py b/lib/ansible/arguments.py new file mode 100644 index 00000000000..9ea2597aefa --- /dev/null +++ b/lib/ansible/arguments.py @@ -0,0 +1,83 @@ +# Copyright: (c) 2018, Toshio Kuratomi +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +""" +Hold command line arguments for use in other modules +""" + +from abc import ABCMeta + +from ansible.module_utils.common._collections_compat import (Container, Mapping, Sequence, Set) +from ansible.module_utils.common.collections import ImmutableDict +from ansible.module_utils.six import add_metaclass, binary_type, text_type +from ansible.utils.singleton import Singleton + + +def _make_immutable(obj): + """Recursively convert a container and objects inside of it into immutable data types""" + if isinstance(obj, (text_type, binary_type)): + # Strings first because they are also sequences + return obj + elif isinstance(obj, Mapping): + temp_dict = {} + for key, value in obj.items(): + if isinstance(value, Container): + temp_dict[key] = _make_immutable(value) + else: + temp_dict[key] = value + return ImmutableDict(temp_dict) + elif isinstance(obj, Set): + temp_set = set() + for value in obj: + if isinstance(value, Container): + temp_set.add(_make_immutable(value)) + else: + temp_set.add(value) + return frozenset(temp_set) + elif isinstance(obj, Sequence): + temp_sequence = [] + for value in obj: + if isinstance(value, Container): + temp_sequence.append(_make_immutable(value)) + else: + temp_sequence.append(value) + return tuple(temp_sequence) + + return obj + + +class _ABCSingleton(Singleton, ABCMeta): + """ + Combine ABCMeta based classes with Singleton based classes + + Combine Singleton and ABCMeta so we have a metaclass that unambiguously knows which can override + the other. Useful for making new types of containers which are also Singletons. + """ + pass + + +class CLIArgs(ImmutableDict): + """Hold a parsed copy of cli arguments""" + def __init__(self, mapping): + toplevel = {} + for key, value in mapping.items(): + toplevel[key] = _make_immutable(value) + super(CLIArgs, self).__init__(toplevel) + + @classmethod + def from_options(cls, options): + return cls(vars(options)) + + +@add_metaclass(_ABCSingleton) +class GlobalCLIArgs(CLIArgs): + """ + Globally hold a parsed copy of cli arguments. + + Only one of these exist per program as it is for global context + """ + pass diff --git a/lib/ansible/module_utils/common/collections.py b/lib/ansible/module_utils/common/collections.py index cffe75ea86f..0a166cd4cf7 100644 --- a/lib/ansible/module_utils/common/collections.py +++ b/lib/ansible/module_utils/common/collections.py @@ -1,4 +1,5 @@ -# Copyright (c), Sviatoslav Sydorenko 2018 +# Copyright: (c) 2018, Sviatoslav Sydorenko +# Copyright: (c) 2018, Ansible Project # Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) """Collection of low-level utility functions.""" @@ -7,7 +8,52 @@ __metaclass__ = type from ansible.module_utils.six import binary_type, text_type -from ansible.module_utils.common._collections_compat import Sequence +from ansible.module_utils.common._collections_compat import Hashable, Mapping, Sequence + + +class ImmutableDict(Hashable, Mapping): + """Dictionary that cannot be updated""" + def __init__(self, *args, **kwargs): + self._store = dict(*args, **kwargs) + + def __getitem__(self, key): + return self._store[key] + + def __iter__(self): + return self._store.__iter__() + + def __len__(self): + return self._store.__len__() + + def __hash__(self): + return hash(frozenset(self.items())) + + def __repr__(self): + return 'ImmutableDict({0})'.format(repr(self._store)) + + def union(self, overriding_mapping): + """ + Create an ImmutableDict as a combination of the original and overriding_mapping + + :arg overriding_mapping: A Mapping of replacement and additional items + :return: A copy of the ImmutableDict with key-value pairs from the overriding_mapping added + + If any of the keys in overriding_mapping are already present in the original ImmutableDict, + the overriding_mapping item replaces the one in the original ImmutableDict. + """ + return ImmutableDict(self._store, **overriding_mapping) + + def difference(self, subtractive_iterable): + """ + Create an ImmutableDict as a combination of the original minus keys in subtractive_iterable + + :arg subtractive_iterable: Any iterable containing keys that should not be present in the + new ImmutableDict + :return: A copy of the ImmutableDict with keys from the subtractive_iterable removed + """ + remove_keys = frozenset(subtractive_iterable) + keys = (k for k in self._store.keys() if k not in remove_keys) + return ImmutableDict((k, self._store[k]) for k in keys) def is_string(seq): diff --git a/test/units/module_utils/common/collections.py b/test/units/module_utils/common/collections.py index cf7be6183d0..fcd867f0261 100644 --- a/test/units/module_utils/common/collections.py +++ b/test/units/module_utils/common/collections.py @@ -9,7 +9,7 @@ __metaclass__ = type import pytest from ansible.module_utils.common._collections_compat import Sequence -from ansible.module_utils.common.collections import is_iterable, is_sequence +from ansible.module_utils.common.collections import ImmutableDict, is_iterable, is_sequence class SeqStub: @@ -100,3 +100,57 @@ def test_iterable_including_strings(string_input): @pytest.mark.parametrize('string_input', TEST_STRINGS) def test_iterable_excluding_strings(string_input): assert not is_iterable(string_input, include_strings=False) + + +class TestImmutableDict: + def test_scalar(self): + imdict = ImmutableDict({1: 2}) + assert imdict[1] == 2 + + def test_string(self): + imdict = ImmutableDict({u'café': u'くらとみ'}) + assert imdict[u'café'] == u'くらとみ' + + def test_container(self): + imdict = ImmutableDict({(1, 2): ['1', '2']}) + assert imdict[(1, 2)] == ['1', '2'] + + def test_from_tuples(self): + imdict = ImmutableDict((('a', 1), ('b', 2))) + assert frozenset(imdict.items()) == frozenset((('a', 1), ('b', 2))) + + def test_from_kwargs(self): + imdict = ImmutableDict(a=1, b=2) + assert frozenset(imdict.items()) == frozenset((('a', 1), ('b', 2))) + + def test_immutable(self): + imdict = ImmutableDict({1: 2}) + + with pytest.raises(TypeError) as exc_info: + imdict[1] = 3 + assert exc_info.value.args[0] == "'ImmutableDict' object does not support item assignment" + + with pytest.raises(TypeError) as exc_info: + imdict[5] = 3 + assert exc_info.value.args[0] == "'ImmutableDict' object does not support item assignment" + + def test_hashable(self): + # ImmutableDict is hashable when all of its values are hashable + imdict = ImmutableDict({u'café': u'くらとみ'}) + assert hash(imdict) + + def test_nonhashable(self): + # ImmutableDict is unhashable when one of its values is unhashable + imdict = ImmutableDict({u'café': u'くらとみ', 1: [1, 2]}) + + with pytest.raises(TypeError) as exc_info: + hash(imdict) + assert exc_info.value.args[0] == "unhashable type: 'list'" + + def test_len(self): + imdict = ImmutableDict({1: 2, 'a': 'b'}) + assert len(imdict) == 2 + + def test_repr(self): + imdict = ImmutableDict({1: 2, 'a': 'b'}) + assert repr(imdict) == "ImmutableDict({1: 2, 'a': 'b'})" diff --git a/test/units/test_arguments.py b/test/units/test_arguments.py new file mode 100644 index 00000000000..2aecf0c52a6 --- /dev/null +++ b/test/units/test_arguments.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# Copyright: (c) 2018, Toshio Kuratomi +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division) +__metaclass__ = type + +try: + import argparse +except ImportError: + argparse = None + +import optparse + +import pytest + +from ansible import arguments + + +MAKE_IMMUTABLE_DATA = ((u'くらとみ', u'くらとみ'), + (42, 42), + ({u'café': u'くらとみ'}, arguments.ImmutableDict({u'café': u'くらとみ'})), + ([1, u'café', u'くらとみ'], (1, u'café', u'くらとみ')), + (set((1, u'café', u'くらとみ')), frozenset((1, u'café', u'くらとみ'))), + ({u'café': [1, set(u'ñ')]}, + arguments.ImmutableDict({u'café': (1, frozenset(u'ñ'))})), + ([set((1, 2)), {u'くらとみ': 3}], + (frozenset((1, 2)), arguments.ImmutableDict({u'くらとみ': 3}))), + ) + + +@pytest.mark.parametrize('data, expected', MAKE_IMMUTABLE_DATA) +def test_make_immutable(data, expected): + assert arguments._make_immutable(data) == expected + + +def test_cliargs(): + class FakeOptions: + pass + options = FakeOptions() + options.tags = [u'production', u'webservers'] + options.check_mode = True + options.start_at_task = u'Start with くらとみ' + + expected = frozenset((('tags', (u'production', u'webservers')), + ('check_mode', True), + ('start_at_task', u'Start with くらとみ'))) + + assert frozenset(arguments.CLIArgs(options).items()) == expected + + +@pytest.mark.skipIf(argparse is None) +def test_cliargs_argparse(): + parser = argparse.ArgumentParser(description='Process some integers.') + parser.add_argument('integers', metavar='N', type=int, nargs='+', + help='an integer for the accumulator') + parser.add_argument('--sum', dest='accumulate', action='store_const', + const=sum, default=max, + help='sum the integers (default: find the max)') + args = parser.parse_args([u'--sum', u'1', u'2']) + + expected = frozenset((('accumulate', sum), ('integers', (1, 2)))) + + assert frozenset(arguments.CLIArgs.from_options(args).items()) == expected + + +# Can get rid of this test when we port ansible.cli from optparse to argparse +def test_cliargs_optparse(): + parser = optparse.OptionParser(description='Process some integers.') + parser.add_option('--sum', dest='accumulate', action='store_const', + const=sum, default=max, + help='sum the integers (default: find the max)') + opts, args = parser.parse_args([u'--sum', u'1', u'2']) + opts.integers = args + + expected = frozenset((('accumulate', sum), ('integers', (u'1', u'2')))) + + assert frozenset(arguments.CLIArgs.from_options(opts).items()) == expected