Fixing filter plugins directory from switch

This commit is contained in:
James Cammarata 2015-05-04 01:33:10 -05:00
parent 249fd2a7e1
commit 803fb397f3
14 changed files with 1166 additions and 33 deletions

View file

@ -180,7 +180,8 @@ class TaskExecutor:
final_items = []
for item in items:
variables['item'] = item
if self._task.evaluate_conditional(variables):
templar = Templar(loader=self._loader, shared_loader_obj=self._shared_loader_obj, variables=variables)
if self._task.evaluate_conditional(templar, variables):
final_items.append(item)
return [",".join(final_items)]
else:
@ -208,13 +209,13 @@ class TaskExecutor:
# get the connection and the handler for this execution
self._connection = self._get_connection(variables)
self._handler = self._get_action_handler(connection=self._connection)
self._handler = self._get_action_handler(connection=self._connection, templar=templar)
# Evaluate the conditional (if any) for this task, which we do before running
# the final task post-validation. We do this before the post validation due to
# the fact that the conditional may specify that the task be skipped due to a
# variable not being present which would otherwise cause validation to fail
if not self._task.evaluate_conditional(variables):
if not self._task.evaluate_conditional(templar, variables):
debug("when evaulation failed, skipping this task")
return dict(changed=False, skipped=True, skip_reason='Conditional check failed')
@ -268,7 +269,7 @@ class TaskExecutor:
return dict(failed=True, msg="The async task did not return valid JSON: %s" % str(e))
if self._task.poll > 0:
result = self._poll_async_result(result=result)
result = self._poll_async_result(result=result, templar=templar)
# update the local copy of vars with the registered value, if specified,
# or any facts which may have been generated by the module execution
@ -284,15 +285,15 @@ class TaskExecutor:
# FIXME: make sure until is mutually exclusive with changed_when/failed_when
if self._task.until:
cond.when = self._task.until
if cond.evaluate_conditional(vars_copy):
if cond.evaluate_conditional(templar, vars_copy):
break
elif (self._task.changed_when or self._task.failed_when) and 'skipped' not in result:
if self._task.changed_when:
cond.when = [ self._task.changed_when ]
result['changed'] = cond.evaluate_conditional(vars_copy)
result['changed'] = cond.evaluate_conditional(templar, vars_copy)
if self._task.failed_when:
cond.when = [ self._task.failed_when ]
failed_when_result = cond.evaluate_conditional(vars_copy)
failed_when_result = cond.evaluate_conditional(templar, vars_copy)
result['failed_when_result'] = result['failed'] = failed_when_result
if failed_when_result:
break
@ -315,7 +316,7 @@ class TaskExecutor:
debug("attempt loop complete, returning result")
return result
def _poll_async_result(self, result):
def _poll_async_result(self, result, templar):
'''
Polls for the specified JID to be complete
'''
@ -339,6 +340,7 @@ class TaskExecutor:
connection=self._connection,
connection_info=self._connection_info,
loader=self._loader,
templar=templar,
shared_loader_obj=self._shared_loader_obj,
)
@ -391,7 +393,7 @@ class TaskExecutor:
return connection
def _get_action_handler(self, connection):
def _get_action_handler(self, connection, templar):
'''
Returns the correct action plugin to handle the requestion task action
'''
@ -411,6 +413,7 @@ class TaskExecutor:
connection=connection,
connection_info=self._connection_info,
loader=self._loader,
templar=templar,
shared_loader_obj=self._shared_loader_obj,
)

View file

@ -225,21 +225,21 @@ class Block(Base, Become, Conditional, Taggable):
ti.deserialize(ti_data)
self._task_include = ti
def evaluate_conditional(self, all_vars):
def evaluate_conditional(self, templar, all_vars):
if len(self._dep_chain):
for dep in self._dep_chain:
if not dep.evaluate_conditional(all_vars):
if not dep.evaluate_conditional(templar, all_vars):
return False
if self._task_include is not None:
if not self._task_include.evaluate_conditional(all_vars):
if not self._task_include.evaluate_conditional(templar, all_vars):
return False
if self._parent_block is not None:
if not self._parent_block.evaluate_conditional(all_vars):
if not self._parent_block.evaluate_conditional(templar, all_vars):
return False
elif self._role is not None:
if not self._role.evaluate_conditional(all_vars):
if not self._role.evaluate_conditional(templar, all_vars):
return False
return super(Block, self).evaluate_conditional(all_vars)
return super(Block, self).evaluate_conditional(templar, all_vars)
def set_loader(self, loader):
self._loader = loader

View file

@ -47,16 +47,16 @@ class Conditional:
if not isinstance(value, list):
setattr(self, name, [ value ])
def evaluate_conditional(self, all_vars):
def evaluate_conditional(self, templar, all_vars):
'''
Loops through the conditionals set on this object, returning
False if any of them evaluate as such.
'''
templar = Templar(loader=self._loader, variables=all_vars, fail_on_undefined=False)
for conditional in self.when:
if not self._check_conditional(conditional, templar, all_vars):
return False
return True
def _check_conditional(self, conditional, templar, all_vars):

View file

@ -266,14 +266,14 @@ class Task(Base, Conditional, Taggable, Become):
super(Task, self).deserialize(data)
def evaluate_conditional(self, all_vars):
def evaluate_conditional(self, templar, all_vars):
if self._block is not None:
if not self._block.evaluate_conditional(all_vars):
if not self._block.evaluate_conditional(templar, all_vars):
return False
if self._task_include is not None:
if not self._task_include.evaluate_conditional(all_vars):
if not self._task_include.evaluate_conditional(templar, all_vars):
return False
return super(Task, self).evaluate_conditional(all_vars)
return super(Task, self).evaluate_conditional(templar, all_vars)
def set_loader(self, loader):
'''

View file

@ -44,11 +44,12 @@ class ActionBase:
action in use.
'''
def __init__(self, task, connection, connection_info, loader, shared_loader_obj):
def __init__(self, task, connection, connection_info, loader, templar, shared_loader_obj):
self._task = task
self._connection = connection
self._connection_info = connection_info
self._loader = loader
self._templar = templar
self._shared_loader_obj = shared_loader_obj
self._shell = self.get_shell()

View file

@ -48,7 +48,7 @@ class ActionModule(ActionBase):
cond = Conditional(loader=self._loader)
for that in thats:
cond.when = [ that ]
test_result = cond.evaluate_conditional(all_vars=task_vars)
test_result = cond.evaluate_conditional(templar=self._templar, all_vars=task_vars)
if not test_result:
result = dict(
failed = True,

View file

@ -19,7 +19,6 @@ __metaclass__ = type
from ansible.plugins.action import ActionBase
from ansible.utils.boolean import boolean
from ansible.template import Templar
class ActionModule(ActionBase):
''' Print statements during execution '''
@ -35,8 +34,7 @@ class ActionModule(ActionBase):
result = dict(msg=self._task.args['msg'])
# FIXME: move the LOOKUP_REGEX somewhere else
elif 'var' in self._task.args: # and not utils.LOOKUP_REGEX.search(self._task.args['var']):
templar = Templar(loader=self._loader, shared_loader_obj=self._shared_loader_obj, variables=task_vars)
results = templar.template(self._task.args['var'], convert_bare=True)
results = self._templar.template(self._task.args['var'], convert_bare=True)
result = dict()
result[self._task.args['var']] = results
else:

View file

@ -19,7 +19,6 @@ __metaclass__ = type
from ansible.errors import AnsibleError
from ansible.plugins.action import ActionBase
from ansible.template import Templar
from ansible.utils.boolean import boolean
class ActionModule(ActionBase):
@ -27,11 +26,10 @@ class ActionModule(ActionBase):
TRANSFERS_FILES = False
def run(self, tmp=None, task_vars=dict()):
templar = Templar(loader=self._loader, variables=task_vars)
facts = dict()
if self._task.args:
for (k, v) in self._task.args.iteritems():
k = templar.template(k)
k = self._templar.template(k)
if isinstance(v, basestring) and v.lower() in ('true', 'false', 'yes', 'no'):
v = boolean(v)
facts[k] = v

View file

@ -21,7 +21,6 @@ import base64
import os
from ansible.plugins.action import ActionBase
from ansible.template import Templar
from ansible.utils.hashing import checksum_s
class ActionModule(ActionBase):
@ -99,11 +98,10 @@ class ActionModule(ActionBase):
dest = os.path.join(dest, base)
# template the source data locally & get ready to transfer
templar = Templar(loader=self._loader, variables=task_vars)
try:
with open(source, 'r') as f:
template_data = f.read()
resultant = templar.template(template_data, preserve_trailing_newlines=True)
resultant = self._templar.template(template_data, preserve_trailing_newlines=True)
except Exception as e:
return dict(failed=True, msg=type(e).__name__ + ": " + str(e))

View file

@ -1 +0,0 @@
../../../lib/ansible/runner/filter_plugins

View file

View file

@ -0,0 +1,351 @@
# (c) 2012, Jeroen Hoekx <jeroen@hoekx.be>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import sys
import base64
import json
import os.path
import types
import pipes
import glob
import re
import crypt
import hashlib
import string
from functools import partial
import operator as py_operator
from random import SystemRandom, shuffle
import uuid
import yaml
from jinja2.filters import environmentfilter
from distutils.version import LooseVersion, StrictVersion
from ansible import errors
from ansible.utils.hashing import md5s, checksum_s
from ansible.utils.unicode import unicode_wrap, to_unicode
UUID_NAMESPACE_ANSIBLE = uuid.UUID('361E6D51-FAEC-444A-9079-341386DA8E2E')
def to_nice_yaml(*a, **kw):
'''Make verbose, human readable yaml'''
transformed = yaml.safe_dump(*a, indent=4, allow_unicode=True, default_flow_style=False, **kw)
return to_unicode(transformed)
def to_json(a, *args, **kw):
''' Convert the value to JSON '''
return json.dumps(a, *args, **kw)
def to_nice_json(a, *args, **kw):
'''Make verbose, human readable JSON'''
# python-2.6's json encoder is buggy (can't encode hostvars)
if sys.version_info < (2, 7):
try:
import simplejson
except ImportError:
pass
else:
try:
major = int(simplejson.__version__.split('.')[0])
except:
pass
else:
if major >= 2:
return simplejson.dumps(a, indent=4, sort_keys=True, *args, **kw)
# Fallback to the to_json filter
return to_json(a, *args, **kw)
return json.dumps(a, indent=4, sort_keys=True, *args, **kw)
def failed(*a, **kw):
''' Test if task result yields failed '''
item = a[0]
if type(item) != dict:
raise errors.AnsibleFilterError("|failed expects a dictionary")
rc = item.get('rc',0)
failed = item.get('failed',False)
if rc != 0 or failed:
return True
else:
return False
def success(*a, **kw):
''' Test if task result yields success '''
return not failed(*a, **kw)
def changed(*a, **kw):
''' Test if task result yields changed '''
item = a[0]
if type(item) != dict:
raise errors.AnsibleFilterError("|changed expects a dictionary")
if not 'changed' in item:
changed = False
if ('results' in item # some modules return a 'results' key
and type(item['results']) == list
and type(item['results'][0]) == dict):
for result in item['results']:
changed = changed or result.get('changed', False)
else:
changed = item.get('changed', False)
return changed
def skipped(*a, **kw):
''' Test if task result yields skipped '''
item = a[0]
if type(item) != dict:
raise errors.AnsibleFilterError("|skipped expects a dictionary")
skipped = item.get('skipped', False)
return skipped
def mandatory(a):
''' Make a variable mandatory '''
try:
a
except NameError:
raise errors.AnsibleFilterError('Mandatory variable not defined.')
else:
return a
def bool(a):
''' return a bool for the arg '''
if a is None or type(a) == bool:
return a
if type(a) in types.StringTypes:
a = a.lower()
if a in ['yes', 'on', '1', 'true', 1]:
return True
else:
return False
def quote(a):
''' return its argument quoted for shell usage '''
return pipes.quote(a)
def fileglob(pathname):
''' return list of matched files for glob '''
return glob.glob(pathname)
def regex(value='', pattern='', ignorecase=False, match_type='search'):
''' Expose `re` as a boolean filter using the `search` method by default.
This is likely only useful for `search` and `match` which already
have their own filters.
'''
if ignorecase:
flags = re.I
else:
flags = 0
_re = re.compile(pattern, flags=flags)
_bool = __builtins__.get('bool')
return _bool(getattr(_re, match_type, 'search')(value))
def match(value, pattern='', ignorecase=False):
''' Perform a `re.match` returning a boolean '''
return regex(value, pattern, ignorecase, 'match')
def search(value, pattern='', ignorecase=False):
''' Perform a `re.search` returning a boolean '''
return regex(value, pattern, ignorecase, 'search')
def regex_replace(value='', pattern='', replacement='', ignorecase=False):
''' Perform a `re.sub` returning a string '''
if not isinstance(value, basestring):
value = str(value)
if ignorecase:
flags = re.I
else:
flags = 0
_re = re.compile(pattern, flags=flags)
return _re.sub(replacement, value)
def ternary(value, true_val, false_val):
''' value ? true_val : false_val '''
if value:
return true_val
else:
return false_val
def version_compare(value, version, operator='eq', strict=False):
''' Perform a version comparison on a value '''
op_map = {
'==': 'eq', '=': 'eq', 'eq': 'eq',
'<': 'lt', 'lt': 'lt',
'<=': 'le', 'le': 'le',
'>': 'gt', 'gt': 'gt',
'>=': 'ge', 'ge': 'ge',
'!=': 'ne', '<>': 'ne', 'ne': 'ne'
}
if strict:
Version = StrictVersion
else:
Version = LooseVersion
if operator in op_map:
operator = op_map[operator]
else:
raise errors.AnsibleFilterError('Invalid operator type')
try:
method = getattr(py_operator, operator)
return method(Version(str(value)), Version(str(version)))
except Exception, e:
raise errors.AnsibleFilterError('Version comparison: %s' % e)
@environmentfilter
def rand(environment, end, start=None, step=None):
r = SystemRandom()
if isinstance(end, (int, long)):
if not start:
start = 0
if not step:
step = 1
return r.randrange(start, end, step)
elif hasattr(end, '__iter__'):
if start or step:
raise errors.AnsibleFilterError('start and step can only be used with integer values')
return r.choice(end)
else:
raise errors.AnsibleFilterError('random can only be used on sequences and integers')
def randomize_list(mylist):
try:
mylist = list(mylist)
shuffle(mylist)
except:
pass
return mylist
def get_hash(data, hashtype='sha1'):
try: # see if hash is supported
h = hashlib.new(hashtype)
except:
return None
h.update(data)
return h.hexdigest()
def get_encrypted_password(password, hashtype='sha512', salt=None):
# TODO: find a way to construct dynamically from system
cryptmethod= {
'md5': '1',
'blowfish': '2a',
'sha256': '5',
'sha512': '6',
}
hastype = hashtype.lower()
if hashtype in cryptmethod:
if salt is None:
r = SystemRandom()
salt = ''.join([r.choice(string.ascii_letters + string.digits) for _ in range(16)])
saltstring = "$%s$%s" % (cryptmethod[hashtype],salt)
encrypted = crypt.crypt(password,saltstring)
return encrypted
return None
def to_uuid(string):
return str(uuid.uuid5(UUID_NAMESPACE_ANSIBLE, str(string)))
class FilterModule(object):
''' Ansible core jinja2 filters '''
def filters(self):
return {
# base 64
'b64decode': partial(unicode_wrap, base64.b64decode),
'b64encode': partial(unicode_wrap, base64.b64encode),
# uuid
'to_uuid': to_uuid,
# json
'to_json': to_json,
'to_nice_json': to_nice_json,
'from_json': json.loads,
# yaml
'to_yaml': yaml.safe_dump,
'to_nice_yaml': to_nice_yaml,
'from_yaml': yaml.safe_load,
# path
'basename': partial(unicode_wrap, os.path.basename),
'dirname': partial(unicode_wrap, os.path.dirname),
'expanduser': partial(unicode_wrap, os.path.expanduser),
'realpath': partial(unicode_wrap, os.path.realpath),
'relpath': partial(unicode_wrap, os.path.relpath),
# failure testing
'failed' : failed,
'success' : success,
# changed testing
'changed' : changed,
# skip testing
'skipped' : skipped,
# variable existence
'mandatory': mandatory,
# value as boolean
'bool': bool,
# quote string for shell usage
'quote': quote,
# hash filters
# md5 hex digest of string
'md5': md5s,
# sha1 hex digeset of string
'sha1': checksum_s,
# checksum of string as used by ansible for checksuming files
'checksum': checksum_s,
# generic hashing
'password_hash': get_encrypted_password,
'hash': get_hash,
# file glob
'fileglob': fileglob,
# regex
'match': match,
'search': search,
'regex': regex,
'regex_replace': regex_replace,
# ? : ;
'ternary': ternary,
# list
# version comparison
'version_compare': version_compare,
# random stuff
'random': rand,
'shuffle': randomize_list,
}

View file

@ -0,0 +1,659 @@
# (c) 2014, Maciej Delmanowski <drybjed@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from functools import partial
try:
import netaddr
except ImportError:
# in this case, we'll make the filters return error messages (see bottom)
netaddr = None
else:
class mac_linux(netaddr.mac_unix):
pass
mac_linux.word_fmt = '%.2x'
from ansible import errors
# ---- IP address and network query helpers ----
def _empty_ipaddr_query(v, vtype):
# We don't have any query to process, so just check what type the user
# expects, and return the IP address in a correct format
if v:
if vtype == 'address':
return str(v.ip)
elif vtype == 'network':
return str(v)
def _6to4_query(v, vtype, value):
if v.version == 4:
if v.size == 1:
ipconv = str(v.ip)
elif v.size > 1:
if v.ip != v.network:
ipconv = str(v.ip)
else:
ipconv = False
if ipaddr(ipconv, 'public'):
numbers = list(map(int, ipconv.split('.')))
try:
return '2002:{:02x}{:02x}:{:02x}{:02x}::1/48'.format(*numbers)
except:
return False
elif v.version == 6:
if vtype == 'address':
if ipaddr(str(v), '2002::/16'):
return value
elif vtype == 'network':
if v.ip != v.network:
if ipaddr(str(v.ip), '2002::/16'):
return value
else:
return False
def _ip_query(v):
if v.size == 1:
return str(v.ip)
if v.size > 1:
if v.ip != v.network:
return str(v.ip)
def _gateway_query(v):
if v.size > 1:
if v.ip != v.network:
return str(v.ip) + '/' + str(v.prefixlen)
def _bool_ipaddr_query(v):
if v:
return True
def _broadcast_query(v):
if v.size > 1:
return str(v.broadcast)
def _cidr_query(v):
return str(v)
def _cidr_lookup_query(v, iplist, value):
try:
if v in iplist:
return value
except:
return False
def _host_query(v):
if v.size == 1:
return str(v)
elif v.size > 1:
if v.ip != v.network:
return str(v.ip) + '/' + str(v.prefixlen)
def _hostmask_query(v):
return str(v.hostmask)
def _int_query(v, vtype):
if vtype == 'address':
return int(v.ip)
elif vtype == 'network':
return str(int(v.ip)) + '/' + str(int(v.prefixlen))
def _ipv4_query(v, value):
if v.version == 6:
try:
return str(v.ipv4())
except:
return False
else:
return value
def _ipv6_query(v, value):
if v.version == 4:
return str(v.ipv6())
else:
return value
def _link_local_query(v, value):
v_ip = netaddr.IPAddress(str(v.ip))
if v.version == 4:
if ipaddr(str(v_ip), '169.254.0.0/24'):
return value
elif v.version == 6:
if ipaddr(str(v_ip), 'fe80::/10'):
return value
def _loopback_query(v, value):
v_ip = netaddr.IPAddress(str(v.ip))
if v_ip.is_loopback():
return value
def _multicast_query(v, value):
if v.is_multicast():
return value
def _net_query(v):
if v.size > 1:
if v.ip == v.network:
return str(v.network) + '/' + str(v.prefixlen)
def _netmask_query(v):
if v.size > 1:
return str(v.netmask)
def _network_query(v):
if v.size > 1:
return str(v.network)
def _prefix_query(v):
return int(v.prefixlen)
def _private_query(v, value):
if v.is_private():
return value
def _public_query(v, value):
v_ip = netaddr.IPAddress(str(v.ip))
if v_ip.is_unicast() and not v_ip.is_private() and \
not v_ip.is_loopback() and not v_ip.is_netmask() and \
not v_ip.is_hostmask():
return value
def _revdns_query(v):
v_ip = netaddr.IPAddress(str(v.ip))
return v_ip.reverse_dns
def _size_query(v):
return v.size
def _subnet_query(v):
return str(v.cidr)
def _type_query(v):
if v.size == 1:
return 'address'
if v.size > 1:
if v.ip != v.network:
return 'address'
else:
return 'network'
def _unicast_query(v, value):
if v.is_unicast():
return value
def _version_query(v):
return v.version
def _wrap_query(v, vtype, value):
if v.version == 6:
if vtype == 'address':
return '[' + str(v.ip) + ']'
elif vtype == 'network':
return '[' + str(v.ip) + ']/' + str(v.prefixlen)
else:
return value
# ---- HWaddr query helpers ----
def _bare_query(v):
v.dialect = netaddr.mac_bare
return str(v)
def _bool_hwaddr_query(v):
if v:
return True
def _cisco_query(v):
v.dialect = netaddr.mac_cisco
return str(v)
def _empty_hwaddr_query(v, value):
if v:
return value
def _linux_query(v):
v.dialect = mac_linux
return str(v)
def _postgresql_query(v):
v.dialect = netaddr.mac_pgsql
return str(v)
def _unix_query(v):
v.dialect = netaddr.mac_unix
return str(v)
def _win_query(v):
v.dialect = netaddr.mac_eui48
return str(v)
# ---- IP address and network filters ----
def ipaddr(value, query = '', version = False, alias = 'ipaddr'):
''' Check if string is an IP address or network and filter it '''
query_func_extra_args = {
'': ('vtype',),
'6to4': ('vtype', 'value'),
'cidr_lookup': ('iplist', 'value'),
'int': ('vtype',),
'ipv4': ('value',),
'ipv6': ('value',),
'link-local': ('value',),
'loopback': ('value',),
'lo': ('value',),
'multicast': ('value',),
'private': ('value',),
'public': ('value',),
'unicast': ('value',),
'wrap': ('vtype', 'value'),
}
query_func_map = {
'': _empty_ipaddr_query,
'6to4': _6to4_query,
'address': _ip_query,
'address/prefix': _gateway_query,
'bool': _bool_ipaddr_query,
'broadcast': _broadcast_query,
'cidr': _cidr_query,
'cidr_lookup': _cidr_lookup_query,
'gateway': _gateway_query,
'gw': _gateway_query,
'host': _host_query,
'host/prefix': _gateway_query,
'hostmask': _hostmask_query,
'hostnet': _gateway_query,
'int': _int_query,
'ip': _ip_query,
'ipv4': _ipv4_query,
'ipv6': _ipv6_query,
'link-local': _link_local_query,
'lo': _loopback_query,
'loopback': _loopback_query,
'multicast': _multicast_query,
'net': _net_query,
'netmask': _netmask_query,
'network': _network_query,
'prefix': _prefix_query,
'private': _private_query,
'public': _public_query,
'revdns': _revdns_query,
'router': _gateway_query,
'size': _size_query,
'subnet': _subnet_query,
'type': _type_query,
'unicast': _unicast_query,
'v4': _ipv4_query,
'v6': _ipv6_query,
'version': _version_query,
'wrap': _wrap_query,
}
vtype = None
if not value:
return False
elif value == True:
return False
# Check if value is a list and parse each element
elif isinstance(value, (list, tuple)):
_ret = []
for element in value:
if ipaddr(element, str(query), version):
_ret.append(ipaddr(element, str(query), version))
if _ret:
return _ret
else:
return list()
# Check if value is a number and convert it to an IP address
elif str(value).isdigit():
# We don't know what IP version to assume, so let's check IPv4 first,
# then IPv6
try:
if ((not version) or (version and version == 4)):
v = netaddr.IPNetwork('0.0.0.0/0')
v.value = int(value)
v.prefixlen = 32
elif version and version == 6:
v = netaddr.IPNetwork('::/0')
v.value = int(value)
v.prefixlen = 128
# IPv4 didn't work the first time, so it definitely has to be IPv6
except:
try:
v = netaddr.IPNetwork('::/0')
v.value = int(value)
v.prefixlen = 128
# The value is too big for IPv6. Are you a nanobot?
except:
return False
# We got an IP address, let's mark it as such
value = str(v)
vtype = 'address'
# value has not been recognized, check if it's a valid IP string
else:
try:
v = netaddr.IPNetwork(value)
# value is a valid IP string, check if user specified
# CIDR prefix or just an IP address, this will indicate default
# output format
try:
address, prefix = value.split('/')
vtype = 'network'
except:
vtype = 'address'
# value hasn't been recognized, maybe it's a numerical CIDR?
except:
try:
address, prefix = value.split('/')
address.isdigit()
address = int(address)
prefix.isdigit()
prefix = int(prefix)
# It's not numerical CIDR, give up
except:
return False
# It is something, so let's try and build a CIDR from the parts
try:
v = netaddr.IPNetwork('0.0.0.0/0')
v.value = address
v.prefixlen = prefix
# It's not a valid IPv4 CIDR
except:
try:
v = netaddr.IPNetwork('::/0')
v.value = address
v.prefixlen = prefix
# It's not a valid IPv6 CIDR. Give up.
except:
return False
# We have a valid CIDR, so let's write it in correct format
value = str(v)
vtype = 'network'
# We have a query string but it's not in the known query types. Check if
# that string is a valid subnet, if so, we can check later if given IP
# address/network is inside that specific subnet
try:
### ?? 6to4 and link-local were True here before. Should they still?
if query and (query not in query_func_map or query == 'cidr_lookup') and ipaddr(query, 'network'):
iplist = netaddr.IPSet([netaddr.IPNetwork(query)])
query = 'cidr_lookup'
except:
pass
# This code checks if value maches the IP version the user wants, ie. if
# it's any version ("ipaddr()"), IPv4 ("ipv4()") or IPv6 ("ipv6()")
# If version does not match, return False
if version and v.version != version:
return False
extras = []
for arg in query_func_extra_args.get(query, tuple()):
extras.append(locals()[arg])
try:
return query_func_map[query](v, *extras)
except KeyError:
try:
float(query)
if v.size == 1:
if vtype == 'address':
return str(v.ip)
elif vtype == 'network':
return str(v)
elif v.size > 1:
try:
return str(v[query]) + '/' + str(v.prefixlen)
except:
return False
else:
return value
except:
raise errors.AnsibleFilterError(alias + ': unknown filter type: %s' % query)
return False
def ipwrap(value, query = ''):
try:
if isinstance(value, (list, tuple)):
_ret = []
for element in value:
if ipaddr(element, query, version = False, alias = 'ipwrap'):
_ret.append(ipaddr(element, 'wrap'))
else:
_ret.append(element)
return _ret
else:
_ret = ipaddr(value, query, version = False, alias = 'ipwrap')
if _ret:
return ipaddr(_ret, 'wrap')
else:
return value
except:
return value
def ipv4(value, query = ''):
return ipaddr(value, query, version = 4, alias = 'ipv4')
def ipv6(value, query = ''):
return ipaddr(value, query, version = 6, alias = 'ipv6')
# Split given subnet into smaller subnets or find out the biggest subnet of
# a given IP address with given CIDR prefix
# Usage:
#
# - address or address/prefix | ipsubnet
# returns CIDR subnet of a given input
#
# - address/prefix | ipsubnet(cidr)
# returns number of possible subnets for given CIDR prefix
#
# - address/prefix | ipsubnet(cidr, index)
# returns new subnet with given CIDR prefix
#
# - address | ipsubnet(cidr)
# returns biggest subnet with given CIDR prefix that address belongs to
#
# - address | ipsubnet(cidr, index)
# returns next indexed subnet which contains given address
def ipsubnet(value, query = '', index = 'x'):
''' Manipulate IPv4/IPv6 subnets '''
try:
vtype = ipaddr(value, 'type')
if vtype == 'address':
v = ipaddr(value, 'cidr')
elif vtype == 'network':
v = ipaddr(value, 'subnet')
value = netaddr.IPNetwork(v)
except:
return False
if not query:
return str(value)
elif str(query).isdigit():
vsize = ipaddr(v, 'size')
query = int(query)
try:
float(index)
index = int(index)
if vsize > 1:
try:
return str(list(value.subnet(query))[index])
except:
return False
elif vsize == 1:
try:
return str(value.supernet(query)[index])
except:
return False
except:
if vsize > 1:
try:
return str(len(list(value.subnet(query))))
except:
return False
elif vsize == 1:
try:
return str(value.supernet(query)[0])
except:
return False
return False
# Returns the nth host within a network described by value.
# Usage:
#
# - address or address/prefix | nthhost(nth)
# returns the nth host within the given network
def nthhost(value, query=''):
''' Get the nth host within a given network '''
try:
vtype = ipaddr(value, 'type')
if vtype == 'address':
v = ipaddr(value, 'cidr')
elif vtype == 'network':
v = ipaddr(value, 'subnet')
value = netaddr.IPNetwork(v)
except:
return False
if not query:
return False
try:
vsize = ipaddr(v, 'size')
nth = int(query)
if value.size > nth:
return value[nth]
except ValueError:
return False
return False
# ---- HWaddr / MAC address filters ----
def hwaddr(value, query = '', alias = 'hwaddr'):
''' Check if string is a HW/MAC address and filter it '''
query_func_extra_args = {
'': ('value',),
}
query_func_map = {
'': _empty_hwaddr_query,
'bare': _bare_query,
'bool': _bool_hwaddr_query,
'cisco': _cisco_query,
'eui48': _win_query,
'linux': _linux_query,
'pgsql': _postgresql_query,
'postgresql': _postgresql_query,
'psql': _postgresql_query,
'unix': _unix_query,
'win': _win_query,
}
try:
v = netaddr.EUI(value)
except:
if query and query != 'bool':
raise errors.AnsibleFilterError(alias + ': not a hardware address: %s' % value)
extras = []
for arg in query_func_extra_args.get(query, tuple()):
extras.append(locals()[arg])
try:
return query_func_map[query](v, *extras)
except KeyError:
raise errors.AnsibleFilterError(alias + ': unknown filter type: %s' % query)
return False
def macaddr(value, query = ''):
return hwaddr(value, query, alias = 'macaddr')
def _need_netaddr(f_name, *args, **kwargs):
raise errors.AnsibleFilterError('The {0} filter requires python-netaddr be'
' installed on the ansible controller'.format(f_name))
# ---- Ansible filters ----
class FilterModule(object):
''' IP address and network manipulation filters '''
filter_map = {
# IP addresses and networks
'ipaddr': ipaddr,
'ipwrap': ipwrap,
'ipv4': ipv4,
'ipv6': ipv6,
'ipsubnet': ipsubnet,
'nthhost': nthhost,
# MAC / HW addresses
'hwaddr': hwaddr,
'macaddr': macaddr
}
def filters(self):
if netaddr:
return self.filter_map
else:
# Need to install python-netaddr for these filters to work
return dict((f, partial(_need_netaddr, f)) for f in self.filter_map)

View file

@ -0,0 +1,126 @@
# (c) 2014, Brian Coca <bcoca@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import math
import collections
from ansible import errors
def unique(a):
if isinstance(a,collections.Hashable):
c = set(a)
else:
c = []
for x in a:
if x not in c:
c.append(x)
return c
def intersect(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) & set(b)
else:
c = unique(filter(lambda x: x in b, a))
return c
def difference(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) - set(b)
else:
c = unique(filter(lambda x: x not in b, a))
return c
def symmetric_difference(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) ^ set(b)
else:
c = unique(filter(lambda x: x not in intersect(a,b), union(a,b)))
return c
def union(a, b):
if isinstance(a,collections.Hashable) and isinstance(b,collections.Hashable):
c = set(a) | set(b)
else:
c = unique(a + b)
return c
def min(a):
_min = __builtins__.get('min')
return _min(a);
def max(a):
_max = __builtins__.get('max')
return _max(a);
def isnotanumber(x):
try:
return math.isnan(x)
except TypeError:
return False
def logarithm(x, base=math.e):
try:
if base == 10:
return math.log10(x)
else:
return math.log(x, base)
except TypeError, e:
raise errors.AnsibleFilterError('log() can only be used on numbers: %s' % str(e))
def power(x, y):
try:
return math.pow(x, y)
except TypeError, e:
raise errors.AnsibleFilterError('pow() can only be used on numbers: %s' % str(e))
def inversepower(x, base=2):
try:
if base == 2:
return math.sqrt(x)
else:
return math.pow(x, 1.0/float(base))
except TypeError, e:
raise errors.AnsibleFilterError('root() can only be used on numbers: %s' % str(e))
class FilterModule(object):
''' Ansible math jinja2 filters '''
def filters(self):
return {
# general math
'isnan': isnotanumber,
'min' : min,
'max' : max,
# exponents and logarithms
'log': logarithm,
'pow': power,
'root': inversepower,
# set theory
'unique' : unique,
'intersect': intersect,
'difference': difference,
'symmetric_difference': symmetric_difference,
'union': union,
}