rearranged math filters

This commit is contained in:
Brian Coca 2015-02-18 20:03:05 -05:00
parent 94aca71546
commit 8872bba21f
3 changed files with 60 additions and 58 deletions

View file

@ -23,7 +23,6 @@ import types
import pipes import pipes
import glob import glob
import re import re
import collections
import crypt import crypt
import hashlib import hashlib
import string import string
@ -182,51 +181,6 @@ def ternary(value, true_val, false_val):
else: else:
return false_val return false_val
def unique(a):
if isinstance(a,collections.Hashable):
c = set(a)
else:
c = []
for x in a:
if x not in c:
c.append(x)
return c
def intersect(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) & set(b)
else:
c = unique(filter(lambda x: x in b, a))
return c
def difference(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) - set(b)
else:
c = unique(filter(lambda x: x not in b, a))
return c
def symmetric_difference(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) ^ set(b)
else:
c = unique(filter(lambda x: x not in intersect(a,b), union(a,b)))
return c
def union(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) | set(b)
else:
c = unique(a + b)
return c
def min(a):
_min = __builtins__.get('min')
return _min(a);
def max(a):
_max = __builtins__.get('max')
return _max(a);
def version_compare(value, version, operator='eq', strict=False): def version_compare(value, version, operator='eq', strict=False):
''' Perform a version comparison on a value ''' ''' Perform a version comparison on a value '''
@ -386,14 +340,6 @@ class FilterModule(object):
'ternary': ternary, 'ternary': ternary,
# list # list
'unique' : unique,
'intersect': intersect,
'difference': difference,
'symmetric_difference': symmetric_difference,
'union': union,
'min' : min,
'max' : max,
# version comparison # version comparison
'version_compare': version_compare, 'version_compare': version_compare,

View file

@ -15,11 +15,56 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import math import math
import collections
from ansible import errors from ansible import errors
def unique(a):
if isinstance(a,collections.Hashable):
c = set(a)
else:
c = []
for x in a:
if x not in c:
c.append(x)
return c
def intersect(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) & set(b)
else:
c = unique(filter(lambda x: x in b, a))
return c
def difference(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) - set(b)
else:
c = unique(filter(lambda x: x not in b, a))
return c
def symmetric_difference(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) ^ set(b)
else:
c = unique(filter(lambda x: x not in intersect(a,b), union(a,b)))
return c
def union(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) | set(b)
else:
c = unique(a + b)
return c
def min(a):
_min = __builtins__.get('min')
return _min(a);
def max(a):
_max = __builtins__.get('max')
return _max(a);
def isnotanumber(x): def isnotanumber(x):
try: try:
return math.isnan(x) return math.isnan(x)
@ -61,9 +106,19 @@ class FilterModule(object):
return { return {
# general math # general math
'isnan': isnotanumber, 'isnan': isnotanumber,
'min' : min,
'max' : max,
# exponents and logarithms # exponents and logarithms
'log': logarithm, 'log': logarithm,
'pow': power, 'pow': power,
'root': inversepower, 'root': inversepower,
# set theory
'unique' : unique,
'intersect': intersect,
'difference': difference,
'symmetric_difference': symmetric_difference,
'union': union,
} }

View file

@ -6,6 +6,7 @@ import os.path
import unittest, tempfile, shutil import unittest, tempfile, shutil
from ansible import playbook, inventory, callbacks from ansible import playbook, inventory, callbacks
import ansible.runner.filter_plugins.core import ansible.runner.filter_plugins.core
import ansible.runner.filter_plugins.mathstuff
INVENTORY = inventory.Inventory(['localhost']) INVENTORY = inventory.Inventory(['localhost'])
@ -182,9 +183,9 @@ class TestFilters(unittest.TestCase):
self.assertTrue(ansible.runner.filter_plugins.core.version_compare('12.04', 12, 'ge')) self.assertTrue(ansible.runner.filter_plugins.core.version_compare('12.04', 12, 'ge'))
def test_min(self): def test_min(self):
a = ansible.runner.filter_plugins.core.min([3, 2, 5, 4]) a = ansible.runner.filter_plugins.mathstuff.min([3, 2, 5, 4])
assert a == 2 assert a == 2
def test_max(self): def test_max(self):
a = ansible.runner.filter_plugins.core.max([3, 2, 5, 4]) a = ansible.runner.filter_plugins.mathstuff.max([3, 2, 5, 4])
assert a == 5 assert a == 5