diff --git a/lib/ansible/module_utils/netcfg.py b/lib/ansible/module_utils/netcfg.py index 04f3804c744..28bef48c033 100644 --- a/lib/ansible/module_utils/netcfg.py +++ b/lib/ansible/module_utils/netcfg.py @@ -4,7 +4,7 @@ # still belong to the author of the module, and may assign their own license # to the complete work. # -# (c) 2016 Red Hat Inc. +# Copyright (c) 2015 Peter Sprygada, # # Redistribution and use in source and binary forms, with or without modification, # are permitted provided that the following conditions are met: @@ -25,19 +25,57 @@ # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # + +import itertools import re -from ansible.module_utils.six.moves import zip +from ansible.module_utils.six import string_types +from ansible.module_utils.six.moves import zip, zip_longest DEFAULT_COMMENT_TOKENS = ['#', '!', '/*', '*/'] + +def to_list(val): + if isinstance(val, (list, tuple)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +class Config(object): + + def __init__(self, connection): + self.connection = connection + + def __call__(self, commands, **kwargs): + lines = to_list(commands) + return self.connection.configure(lines, **kwargs) + + def load_config(self, commands, **kwargs): + commands = to_list(commands) + return self.connection.load_config(commands, **kwargs) + + def get_config(self, **kwargs): + return self.connection.get_config(**kwargs) + + def save_config(self): + return self.connection.save_config() + class ConfigLine(object): - def __init__(self, raw): - self.text = str(raw).strip() - self.raw = raw - self._children = list() - self._parents = list() + def __init__(self, text): + self.text = text + self.children = list() + self.parents = list() + self.raw = None + + @property + def line(self): + line = [p.text for p in self.parents] + line.append(self.text) + return ' '.join(line) def __str__(self): return self.raw @@ -48,217 +86,295 @@ class ConfigLine(object): def __ne__(self, other): return not self.__eq__(other) - def __getitem__(self, key): - for item in self._children: - if item.text == key: - return item - raise KeyError(key) - - @property - def line(self): - line = self.parents - line.append(self.text) - return ' '.join(line) - - @property - def children(self): - return _obj_to_text(self._children) - - @property - def parents(self): - return _obj_to_text(self._parents) - - @property - def path(self): - config = _obj_to_raw(self._parents) - config.append(self.raw) - return '\n'.join(config) - - def add_child(self, obj): - assert isinstance(obj, ConfigLine), 'child must be of type `ConfigLine`' - self._children.append(obj) - def ignore_line(text, tokens=None): for item in (tokens or DEFAULT_COMMENT_TOKENS): if text.startswith(item): return True -_obj_to_text = lambda x: [o.text for o in x] -_obj_to_raw = lambda x: [o.raw for o in x] +def get_next(iterable): + item, next_item = itertools.tee(iterable, 2) + next_item = itertools.islice(next_item, 1, None) + return zip_longest(item, next_item) + +def parse(lines, indent, comment_tokens=None): + toplevel = re.compile(r'\S') + childline = re.compile(r'^\s*(.+)$') + + ancestors = list() + config = list() + + for line in str(lines).split('\n'): + text = str(re.sub(r'([{};])', '', line)).strip() + + cfg = ConfigLine(text) + cfg.raw = line + + if not text or ignore_line(text, comment_tokens): + continue + + # handle top level commands + if toplevel.match(line): + ancestors = [cfg] + + # handle sub level commands + else: + match = childline.match(line) + line_indent = match.start(1) + level = int(line_indent / indent) + parent_level = level - 1 + + cfg.parents = ancestors[:level] + + if level > len(ancestors): + config.append(cfg) + continue + + for i in range(level, len(ancestors)): + ancestors.pop() + + ancestors.append(cfg) + ancestors[parent_level].children.append(cfg) + + config.append(cfg) + + return config def dumps(objects, output='block'): if output == 'block': - item = _obj_to_raw(objects) + items = [c.raw for c in objects] elif output == 'commands': - items = _obj_to_text(objects) + items = [c.text for c in objects] + elif output == 'lines': + items = list() + for obj in objects: + line = list() + line.extend([p.text for p in obj.parents]) + line.append(obj.text) + items.append(' '.join(line)) else: raise TypeError('unknown value supplied for keyword output') return '\n'.join(items) class NetworkConfig(object): - def __init__(self, indent=1, contents=None): - self._indent = indent - self._items = list() + def __init__(self, indent=None, contents=None, device_os=None): + self.indent = indent or 1 + self._config = list() + self._device_os = device_os + self._syntax = 'block' # block, lines, junos + + if self._device_os == 'junos': + self._syntax = 'junos' if contents: self.load(contents) @property def items(self): - return self._items - - def __getitem__(self, key): - for line in self: - if line.text == key: - return line - raise KeyError(key) - - def __iter__(self): - return iter(self._items) + return self._config def __str__(self): - return '\n'.join([c.raw for c in self.items]) + if self._device_os == 'junos': + return dumps(self.expand_line(self.items), 'lines') + return dumps(self.expand_line(self.items)) - def load(self, s): - self._items = self.parse(s) + def load(self, contents): + # Going to start adding device profiles post 2.2 + tokens = list(DEFAULT_COMMENT_TOKENS) + if self._device_os == 'sros': + tokens.append('echo') + self._config = parse(contents, indent=4, comment_tokens=tokens) + else: + self._config = parse(contents, indent=self.indent) - def loadfp(self, fp): - return self.load(open(fp).read()) + def load_from_file(self, filename): + self.load(open(filename).read()) - def parse(self, lines, comment_tokens=None): - toplevel = re.compile(r'\S') - childline = re.compile(r'^\s*(.+)$') - - ancestors = list() - config = list() - - curlevel = 0 - prevlevel = 0 - - for linenum, line in enumerate(str(lines).split('\n')): - text = str(re.sub(r'([{};])', '', line)).strip() - - cfg = ConfigLine(line) - - if not text or ignore_line(text, comment_tokens): - continue - - # handle top level commands - if toplevel.match(line): - ancestors = [cfg] - prevlevel = curlevel - curlevel = 0 - - # handle sub level commands - else: - match = childline.match(line) - line_indent = match.start(1) - - prevlevel = curlevel - curlevel = int(line_indent / self._indent) - - if (curlevel - 1) > prevlevel: - curlevel = prevlevel + 1 - - parent_level = curlevel - 1 - - cfg._parents = ancestors[:curlevel] - - if curlevel > len(ancestors): - config.append(cfg) - continue - - for i in range(curlevel, len(ancestors)): - ancestors.pop() - - ancestors.append(cfg) - ancestors[parent_level].add_child(cfg) - - config.append(cfg) - - return config + def get(self, path): + if isinstance(path, string_types): + path = [path] + for item in self._config: + if item.text == path[-1]: + parents = [p.text for p in item.parents] + if parents == path[:-1]: + return item def get_object(self, path): for item in self.items: if item.text == path[-1]: - if item.parents == path[:-1]: + parents = [p.text for p in item.parents] + if parents == path[:-1]: return item - def get_section(self, path): - assert isinstance(path, list), 'path argument must be a list object' + def get_section_objects(self, path): + if not isinstance(path, list): + path = [path] obj = self.get_object(path) if not obj: raise ValueError('path does not exist in config') - return self._expand_section(obj) + return self.expand_section(obj) - def _expand_section(self, configobj, S=None): - if S is None: - S = list() - S.append(configobj) - for child in configobj._children: - if child in S: - continue - self._expand_section(child, S) - return S + def search(self, regexp, path=None): + regex = re.compile(r'^%s' % regexp, re.M) - def _diff_line(self, other): - updates = list() - for item in self.items: - if item not in other: - updates.append(item) - return updates - - def _diff_strict(self, other): - updates = list() - for index, line in enumerate(self._items): - try: - if line != other._lines[index]: - updates.append(line) - except IndexError: - updates.append(line) - return updates - - def _diff_exact(self, other): - updates = list() - if len(other) != len(self._items): - updates.extend(self._items) + if path: + parent = self.get(path) + if not parent or not parent.children: + return + children = [c.text for c in parent.children] + data = '\n'.join(children) else: - for ours, theirs in zip(self._items, other): - if ours != theirs: - updates.extend(self._items) - break - return updates + data = str(self) - def difference(self, other, match='line', path=None, replace=None): - try: - meth = getattr(self, '_diff_%s' % match) - updates = meth(other) - except AttributeError: - raise TypeError('invalid value for match keyword argument, ' - 'valid values are line, strict, or exact') + match = regex.search(data) + if match: + if match.groups(): + values = match.groupdict().values() + groups = list(set(match.groups()).difference(values)) + return (groups, match.groupdict()) + else: + return match.group() + def findall(self, regexp): + regexp = r'%s' % regexp + return re.findall(regexp, str(self)) + + def expand_line(self, objs): visited = set() expanded = list() - - for item in updates: - for p in item._parents: + for o in objs: + for p in o.parents: if p not in visited: visited.add(p) expanded.append(p) - expanded.append(item) - visited.add(item) - + expanded.append(o) + visited.add(o) return expanded + def expand_section(self, configobj, S=None): + if S is None: + S = list() + S.append(configobj) + for child in configobj.children: + if child in S: + continue + self.expand_section(child, S) + return S + + def expand_block(self, objects, visited=None): + items = list() + + if not visited: + visited = set() + + for o in objects: + items.append(o) + visited.add(o) + for child in o.children: + items.extend(self.expand_block([child], visited)) + + return items + + def diff_line(self, other, path=None): + diff = list() + for item in self.items: + if item not in other: + diff.append(item) + return diff + + def diff_strict(self, other, path=None): + diff = list() + for index, item in enumerate(self.items): + try: + if item != other[index]: + diff.append(item) + except IndexError: + diff.append(item) + return diff + + def diff_exact(self, other, path=None): + diff = list() + if len(other) != len(self.items): + diff.extend(self.items) + else: + for ours, theirs in zip(self.items, other): + if ours != theirs: + diff.extend(self.items) + break + return diff + + def difference(self, other, path=None, match='line', replace='line'): + try: + if path and match != 'line': + try: + other = other.get_section_objects(path) + except ValueError: + other = list() + else: + other = other.items + func = getattr(self, 'diff_%s' % match) + updates = func(other, path=path) + except AttributeError: + raise + raise TypeError('invalid value for match keyword') + + if self._device_os == 'junos': + return updates + + if replace == 'block': + parents = list() + for u in updates: + if u.parents is None: + if u not in parents: + parents.append(u) + else: + for p in u.parents: + if p not in parents: + parents.append(p) + + return self.expand_block(parents) + + return self.expand_line(updates) + + def replace(self, patterns, repl, parents=None, add_if_missing=False, + ignore_whitespace=True): + + match = None + + parents = to_list(parents) or list() + patterns = [re.compile(r, re.I) for r in to_list(patterns)] + + for item in self.items: + for regexp in patterns: + text = item.text + if not ignore_whitespace: + text = item.raw + if regexp.search(text): + if item.text != repl: + if parents == [p.text for p in item.parents]: + match = item + break + + if match: + match.text = repl + indent = len(match.raw) - len(match.raw.lstrip()) + match.raw = repl.rjust(len(repl) + indent) + + elif add_if_missing: + self.add(repl, parents=parents) + + def add(self, lines, parents=None): + """Adds one or lines of configuration + """ + ancestors = list() offset = 0 obj = None ## global config command if not parents: - for line in lines: + for line in to_list(lines): item = ConfigLine(line) item.raw = line if item not in self.items: @@ -268,12 +384,12 @@ class NetworkConfig(object): for index, p in enumerate(parents): try: i = index + 1 - obj = self.get_section(parents[:i])[0] + obj = self.get_section_objects(parents[:i])[0] ancestors.append(obj) except ValueError: # add parent to config - offset = index * self._indent + offset = index * self.indent obj = ConfigLine(p) obj.raw = p.rjust(len(p) + offset) if ancestors: @@ -283,15 +399,15 @@ class NetworkConfig(object): ancestors.append(obj) # add child objects - for line in lines: + for line in to_list(lines): # check if child already exists for child in ancestors[-1].children: if child.text == line: break else: - offset = len(parents) * self._indent + offset = len(parents) * self.indent item = ConfigLine(line) item.raw = line.rjust(len(line) + offset) - item._parents = ancestors + item.parents = ancestors ancestors[-1].children.append(item) self.items.append(item) diff --git a/lib/ansible/module_utils/network_common.py b/lib/ansible/module_utils/network_common.py deleted file mode 100644 index d7523cf7a13..00000000000 --- a/lib/ansible/module_utils/network_common.py +++ /dev/null @@ -1,89 +0,0 @@ -# This code is part of Ansible, but is an independent component. -# This particular file snippet, and this file snippet only, is BSD licensed. -# Modules you write using this snippet, which is embedded dynamically by Ansible -# still belong to the author of the module, and may assign their own license -# to the complete work. -# -# (c) 2016 Red Hat Inc. -# -# Redistribution and use in source and binary forms, with or without modification, -# are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. -# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, -# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -from ansible.module_utils.six import iteritems - -class ComplexDict: - - def __init__(self, attrs): - self._attributes = attrs - self.attr_names = frozenset(self._attributes.keys()) - for name, attr in iteritems(self._attributes): - if attr.get('key'): - attr['required'] = True - - def __call__(self, value): - if isinstance(value, dict): - unknown = set(value.keys()).difference(self.attr_names) - if unknown: - raise ValueError('invalid keys: %s' % ','.join(unknown)) - for name, attr in iteritems(self._attributes): - if attr.get('required') and name not in value: - raise ValueError('missing required attribute %s' % name) - if not value.get(name): - value[name] = attr.get('default') - return value - else: - obj = {} - for name, attr in iteritems(self._attributes): - if attr.get('key'): - obj[name] = value - else: - obj[name] = attr.get('default') - return obj - - -class ComplexList: - - def __init__(self, attrs): - self._attributes = attrs - self.attr_names = frozenset(self._attributes.keys()) - for name, attr in iteritems(self._attributes): - if attr.get('key'): - attr['required'] = True - - - def __call__(self, values): - objects = list() - for value in values: - if isinstance(value, dict): - for name, attr in iteritems(self._attributes): - if attr.get('required') and name not in value: - raise ValueError('missing required attr %s' % name) - if not value.get(name): - value[name] = attr.get('default') - objects.append(value) - else: - obj = {} - for name, attr in iteritems(self._attributes): - if attr.get('key'): - obj[name] = value - else: - obj[name] = attr.get('default') - objects.append(obj) - return objects -