allow jinja2 unique filter compat (#45637)
* allow jinja2 unique filter compat * detect if unique is provided, fallback with warning * handle j2 specific params * now all filters using unique must pass environment * added env to tests also normalized on how we normally import and use exceptoins
This commit is contained in:
parent
f4f5d941e5
commit
32ec69d827
2 changed files with 81 additions and 39 deletions
|
@ -27,54 +27,93 @@ import collections
|
|||
import itertools
|
||||
import math
|
||||
|
||||
from ansible import errors
|
||||
from jinja2.filters import environmentfilter
|
||||
|
||||
from ansible.errors import AnsibleFilterError
|
||||
from ansible.module_utils import basic
|
||||
from ansible.module_utils.six import binary_type, text_type
|
||||
from ansible.module_utils.six.moves import zip, zip_longest
|
||||
from ansible.module_utils._text import to_native
|
||||
from ansible.module_utils._text import to_native, to_text
|
||||
|
||||
try:
|
||||
from jinja2.filters import do_unique
|
||||
HAS_UNIQUE = True
|
||||
except ImportError:
|
||||
HAS_UNIQUE = False
|
||||
|
||||
try:
|
||||
from __main__ import display
|
||||
except ImportError:
|
||||
from ansible.utils.display import Display
|
||||
display = Display()
|
||||
|
||||
|
||||
def unique(a):
|
||||
if isinstance(a, collections.Hashable):
|
||||
c = set(a)
|
||||
else:
|
||||
c = []
|
||||
for x in a:
|
||||
if x not in c:
|
||||
c.append(x)
|
||||
@environmentfilter
|
||||
def unique(environment, a, case_sensitive=False, attribute=None):
|
||||
|
||||
error = None
|
||||
try:
|
||||
if HAS_UNIQUE:
|
||||
c = set(do_unique(environment, a, case_sensitive=case_sensitive, attribute=attribute))
|
||||
except Exception as e:
|
||||
if case_sensitive or attribute:
|
||||
raise AnsibleFilterError("Jinja2's unique filter failed and we cannot fall back to Ansible's version "
|
||||
"as it does not support the parameters supplied", orig_exc=e)
|
||||
else:
|
||||
display.warning('Falling back to Ansible unique filter as Jinaj2 one failed: %s' % to_text(e))
|
||||
error = e
|
||||
|
||||
if not HAS_UNIQUE or error:
|
||||
|
||||
# handle Jinja2 specific attributes when using Ansible's version
|
||||
if case_sensitive or attribute:
|
||||
raise AnsibleFilterError("Ansible's unique filter does not support case_sensitive nor attribute parameters, "
|
||||
"you need a newer version of Jinja2 that provides their version of the filter.")
|
||||
|
||||
if isinstance(a, collections.Hashable):
|
||||
c = set(a)
|
||||
else:
|
||||
c = []
|
||||
for x in a:
|
||||
if x not in c:
|
||||
c.append(x)
|
||||
return c
|
||||
|
||||
|
||||
def intersect(a, b):
|
||||
@environmentfilter
|
||||
def intersect(environment, a, b):
|
||||
if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable):
|
||||
c = set(a) & set(b)
|
||||
else:
|
||||
c = unique([x for x in a if x in b])
|
||||
c = unique(environment, [x for x in a if x in b])
|
||||
return c
|
||||
|
||||
|
||||
def difference(a, b):
|
||||
@environmentfilter
|
||||
def difference(environment, a, b):
|
||||
if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable):
|
||||
c = set(a) - set(b)
|
||||
else:
|
||||
c = unique([x for x in a if x not in b])
|
||||
c = unique(environment, [x for x in a if x not in b])
|
||||
return c
|
||||
|
||||
|
||||
def symmetric_difference(a, b):
|
||||
@environmentfilter
|
||||
def symmetric_difference(environment, a, b):
|
||||
if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable):
|
||||
c = set(a) ^ set(b)
|
||||
else:
|
||||
isect = intersect(a, b)
|
||||
c = [x for x in union(a, b) if x not in isect]
|
||||
isect = intersect(environment, a, b)
|
||||
c = [x for x in union(environment, a, b) if x not in isect]
|
||||
return c
|
||||
|
||||
|
||||
def union(a, b):
|
||||
@environmentfilter
|
||||
def union(environment, a, b):
|
||||
if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable):
|
||||
c = set(a) | set(b)
|
||||
else:
|
||||
c = unique(a + b)
|
||||
c = unique(environment, a + b)
|
||||
return c
|
||||
|
||||
|
||||
|
@ -95,14 +134,14 @@ def logarithm(x, base=math.e):
|
|||
else:
|
||||
return math.log(x, base)
|
||||
except TypeError as e:
|
||||
raise errors.AnsibleFilterError('log() can only be used on numbers: %s' % str(e))
|
||||
raise AnsibleFilterError('log() can only be used on numbers: %s' % str(e))
|
||||
|
||||
|
||||
def power(x, y):
|
||||
try:
|
||||
return math.pow(x, y)
|
||||
except TypeError as e:
|
||||
raise errors.AnsibleFilterError('pow() can only be used on numbers: %s' % str(e))
|
||||
raise AnsibleFilterError('pow() can only be used on numbers: %s' % str(e))
|
||||
|
||||
|
||||
def inversepower(x, base=2):
|
||||
|
@ -112,7 +151,7 @@ def inversepower(x, base=2):
|
|||
else:
|
||||
return math.pow(x, 1.0 / float(base))
|
||||
except (ValueError, TypeError) as e:
|
||||
raise errors.AnsibleFilterError('root() can only be used on numbers: %s' % str(e))
|
||||
raise AnsibleFilterError('root() can only be used on numbers: %s' % str(e))
|
||||
|
||||
|
||||
def human_readable(size, isbits=False, unit=None):
|
||||
|
@ -120,7 +159,7 @@ def human_readable(size, isbits=False, unit=None):
|
|||
try:
|
||||
return basic.bytes_to_human(size, isbits, unit)
|
||||
except Exception:
|
||||
raise errors.AnsibleFilterError("human_readable() can't interpret following string: %s" % size)
|
||||
raise AnsibleFilterError("human_readable() can't interpret following string: %s" % size)
|
||||
|
||||
|
||||
def human_to_bytes(size, default_unit=None, isbits=False):
|
||||
|
@ -128,7 +167,7 @@ def human_to_bytes(size, default_unit=None, isbits=False):
|
|||
try:
|
||||
return basic.human_to_bytes(size, default_unit, isbits)
|
||||
except Exception:
|
||||
raise errors.AnsibleFilterError("human_to_bytes() can't interpret following string: %s" % size)
|
||||
raise AnsibleFilterError("human_to_bytes() can't interpret following string: %s" % size)
|
||||
|
||||
|
||||
def rekey_on_member(data, key, duplicates='error'):
|
||||
|
@ -141,7 +180,7 @@ def rekey_on_member(data, key, duplicates='error'):
|
|||
value would be duplicated or to overwrite previous entries if that's the case.
|
||||
"""
|
||||
if duplicates not in ('error', 'overwrite'):
|
||||
raise errors.AnsibleFilterError("duplicates parameter to rekey_on_member has unknown value: {0}".format(duplicates))
|
||||
raise AnsibleFilterError("duplicates parameter to rekey_on_member has unknown value: {0}".format(duplicates))
|
||||
|
||||
new_obj = {}
|
||||
|
||||
|
@ -150,24 +189,24 @@ def rekey_on_member(data, key, duplicates='error'):
|
|||
elif isinstance(data, collections.Iterable) and not isinstance(data, (text_type, binary_type)):
|
||||
iterate_over = data
|
||||
else:
|
||||
raise errors.AnsibleFilterError("Type is not a valid list, set, or dict")
|
||||
raise AnsibleFilterError("Type is not a valid list, set, or dict")
|
||||
|
||||
for item in iterate_over:
|
||||
if not isinstance(item, collections.Mapping):
|
||||
raise errors.AnsibleFilterError("List item is not a valid dict")
|
||||
raise AnsibleFilterError("List item is not a valid dict")
|
||||
|
||||
try:
|
||||
key_elem = item[key]
|
||||
except KeyError:
|
||||
raise errors.AnsibleFilterError("Key {0} was not found".format(key))
|
||||
raise AnsibleFilterError("Key {0} was not found".format(key))
|
||||
except Exception as e:
|
||||
raise errors.AnsibleFilterError(to_native(e))
|
||||
raise AnsibleFilterError(to_native(e))
|
||||
|
||||
# Note: if new_obj[key_elem] exists it will always be a non-empty dict (it will at
|
||||
# minimun contain {key: key_elem}
|
||||
if new_obj.get(key_elem, None):
|
||||
if duplicates == 'error':
|
||||
raise errors.AnsibleFilterError("Key {0} is not unique, cannot correctly turn into dict".format(key_elem))
|
||||
raise AnsibleFilterError("Key {0} is not unique, cannot correctly turn into dict".format(key_elem))
|
||||
elif duplicates == 'overwrite':
|
||||
new_obj[key_elem] = item
|
||||
else:
|
||||
|
|
|
@ -4,9 +4,10 @@
|
|||
# Make coding more python3-ish
|
||||
from __future__ import (absolute_import, division, print_function)
|
||||
__metaclass__ = type
|
||||
|
||||
import pytest
|
||||
|
||||
from jinja2 import Environment
|
||||
|
||||
import ansible.plugins.filter.mathstuff as ms
|
||||
from ansible.errors import AnsibleFilterError
|
||||
|
||||
|
@ -22,41 +23,43 @@ TWO_SETS_DATA = (([1, 2], [3, 4], ([], sorted([1, 2]), sorted([1, 2, 3, 4]), sor
|
|||
(['a', 'b', 'c'], ['d', 'c', 'e'], (['c'], sorted(['a', 'b']), sorted(['a', 'b', 'd', 'e']), sorted(['a', 'b', 'c', 'e', 'd']))),
|
||||
)
|
||||
|
||||
env = Environment()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('data, expected', UNIQUE_DATA)
|
||||
class TestUnique:
|
||||
def test_unhashable(self, data, expected):
|
||||
assert sorted(ms.unique(list(data))) == expected
|
||||
assert sorted(ms.unique(env, list(data))) == expected
|
||||
|
||||
def test_hashable(self, data, expected):
|
||||
assert sorted(ms.unique(tuple(data))) == expected
|
||||
assert sorted(ms.unique(env, tuple(data))) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
|
||||
class TestIntersect:
|
||||
def test_unhashable(self, dataset1, dataset2, expected):
|
||||
assert sorted(ms.intersect(list(dataset1), list(dataset2))) == expected[0]
|
||||
assert sorted(ms.intersect(env, list(dataset1), list(dataset2))) == expected[0]
|
||||
|
||||
def test_hashable(self, dataset1, dataset2, expected):
|
||||
assert sorted(ms.intersect(tuple(dataset1), tuple(dataset2))) == expected[0]
|
||||
assert sorted(ms.intersect(env, tuple(dataset1), tuple(dataset2))) == expected[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
|
||||
class TestDifference:
|
||||
def test_unhashable(self, dataset1, dataset2, expected):
|
||||
assert sorted(ms.difference(list(dataset1), list(dataset2))) == expected[1]
|
||||
assert sorted(ms.difference(env, list(dataset1), list(dataset2))) == expected[1]
|
||||
|
||||
def test_hashable(self, dataset1, dataset2, expected):
|
||||
assert sorted(ms.difference(tuple(dataset1), tuple(dataset2))) == expected[1]
|
||||
assert sorted(ms.difference(env, tuple(dataset1), tuple(dataset2))) == expected[1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
|
||||
class TestSymmetricDifference:
|
||||
def test_unhashable(self, dataset1, dataset2, expected):
|
||||
assert sorted(ms.symmetric_difference(list(dataset1), list(dataset2))) == expected[2]
|
||||
assert sorted(ms.symmetric_difference(env, list(dataset1), list(dataset2))) == expected[2]
|
||||
|
||||
def test_hashable(self, dataset1, dataset2, expected):
|
||||
assert sorted(ms.symmetric_difference(tuple(dataset1), tuple(dataset2))) == expected[2]
|
||||
assert sorted(ms.symmetric_difference(env, tuple(dataset1), tuple(dataset2))) == expected[2]
|
||||
|
||||
|
||||
class TestMin:
|
||||
|
|
Loading…
Reference in a new issue