From 7cb489eca3bb167ac9e22b310075e944b8254a27 Mon Sep 17 00:00:00 2001 From: James Cammarata Date: Thu, 23 Oct 2014 15:39:27 -0500 Subject: [PATCH] Adding a data parsing class for v2 --- v2/ansible/errors/__init__.py | 101 ++++++++++++++++--- v2/ansible/parsing/__init__.py | 22 ----- v2/ansible/parsing/yaml/__init__.py | 115 +++++++++++++++++++++- v2/ansible/parsing/yaml/objects.py | 5 + v2/ansible/parsing/yaml/strings.py | 118 +++++++++++++++++++++++ v2/ansible/playbook/base.py | 9 +- v2/ansible/playbook/role.py | 11 +-- v2/ansible/playbook/task.py | 5 +- v2/test/errors/test_errors.py | 14 +-- v2/test/parsing/test_general.py | 104 -------------------- v2/test/parsing/yaml/test_data_loader.py | 64 ++++++++++++ v2/test/parsing/yaml/test_yaml.py | 100 ------------------- 12 files changed, 406 insertions(+), 262 deletions(-) create mode 100644 v2/ansible/parsing/yaml/strings.py delete mode 100644 v2/test/parsing/test_general.py create mode 100644 v2/test/parsing/yaml/test_data_loader.py delete mode 100644 v2/test/parsing/yaml/test_yaml.py diff --git a/v2/ansible/errors/__init__.py b/v2/ansible/errors/__init__.py index 67f4d0a78b9..e0c21d195bd 100644 --- a/v2/ansible/errors/__init__.py +++ b/v2/ansible/errors/__init__.py @@ -21,11 +21,30 @@ __metaclass__ = type import os +from ansible.parsing.yaml.strings import * + class AnsibleError(Exception): - def __init__(self, message, obj=None): - # we import this here to prevent an import loop with errors + ''' + This is the base class for all errors raised from Ansible code, + and can be instantiated with two optional parameters beyond the + error message to control whether detailed information is displayed + when the error occurred while parsing a data file of some kind. + + Usage: + + raise AnsibleError('some message here', obj=obj, show_content=True) + + Where "obj" is some subclass of ansible.parsing.yaml.objects.AnsibleBaseYAMLObject, + which should be returned by the DataLoader() class. + ''' + + def __init__(self, message, obj=None, show_content=True): + # we import this here to prevent an import loop problem, + # since the objects code also imports ansible.errors from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject - self._obj = obj + + self._obj = obj + self._show_content = show_content if isinstance(self._obj, AnsibleBaseYAMLObject): extended_error = self._get_extended_error() if extended_error: @@ -36,22 +55,80 @@ class AnsibleError(Exception): def __repr__(self): return self.message - def _get_line_from_file(self, filename, line_number): - with open(filename, 'r') as f: + def _get_error_lines_from_file(self, file_name, line_number): + ''' + Returns the line in the file which coresponds to the reported error + location, as well as the line preceeding it (if the error did not + occur on the first line), to provide context to the error. + ''' + + target_line = '' + prev_line = '' + + with open(file_name, 'r') as f: lines = f.readlines() - return lines[line_number] + + target_line = lines[line_number] + if line_number > 0: + prev_line = lines[line_number - 1] + + return (target_line, prev_line) def _get_extended_error(self): + ''' + Given an object reporting the location of the exception in a file, return + detailed information regarding it including: + + * the line which caused the error as well as the one preceeding it + * causes and suggested remedies for common syntax errors + + If this error was created with show_content=False, the reporting of content + is suppressed, as the file contents may be sensitive (ie. vault data). + ''' + error_message = '' try: (src_file, line_number, col_number) = self._obj.get_position_info() - error_message += 'The error occurred on line %d of the file %s:\n' % (line_number, src_file) - if src_file not in ('', ''): - responsible_line = self._get_line_from_file(src_file, line_number - 1) - if responsible_line: - error_message += responsible_line - error_message += (' ' * (col_number-1)) + '^' + error_message += YAML_POSITION_DETAILS % (src_file, line_number, col_number) + if src_file not in ('', '') and self._show_content: + (target_line, prev_line) = self._get_error_lines_from_file(src_file, line_number - 1) + if target_line: + stripped_line = target_line.replace(" ","") + arrow_line = (" " * (col_number-1)) + "^" + error_message += "%s\n%s\n%s\n" % (prev_line.rstrip(), target_line.rstrip(), arrow_line) + + # common error/remediation checking here: + # check for unquoted vars starting lines + if ('{{' in target_line and '}}' in target_line) and ('"{{' not in target_line or "'{{" not in target_line): + error_message += YAML_COMMON_UNQUOTED_VARIABLE_ERROR + # check for common dictionary mistakes + elif ":{{" in stripped_line and "}}" in stripped_line: + error_message += YAML_COMMON_DICT_ERROR + # check for common unquoted colon mistakes + elif len(target_line) and len(target_line) > 1 and len(target_line) > col_number and target_line[col_number] == ":" and target_line.count(':') > 1: + error_message += YAML_COMMON_UNQUOTED_COLON_ERROR + # otherwise, check for some common quoting mistakes + else: + parts = target_line.split(":") + if len(parts) > 1: + middle = parts[1].strip() + match = False + unbalanced = False + + if middle.startswith("'") and not middle.endswith("'"): + match = True + elif middle.startswith('"') and not middle.endswith('"'): + match = True + + if len(middle) > 0 and middle[0] in [ '"', "'" ] and middle[-1] in [ '"', "'" ] and target_line.count("'") > 2 or target_line.count('"') > 2: + unbalanced = True + + if match: + error_message += YAML_COMMON_PARTIALLY_QUOTED_LINE_ERROR + if unbalanced: + error_message += YAML_COMMON_UNBALANCED_QUOTES_ERROR + except IOError: error_message += '\n(could not open file to display line)' except IndexError: diff --git a/v2/ansible/parsing/__init__.py b/v2/ansible/parsing/__init__.py index 5f922a120f3..785fc459921 100644 --- a/v2/ansible/parsing/__init__.py +++ b/v2/ansible/parsing/__init__.py @@ -19,25 +19,3 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import json - -from ansible.errors import AnsibleParserError, AnsibleInternalError -from ansible.parsing.vault import VaultLib -from ansible.parsing.yaml import safe_load - -def load(data): - - if hasattr(data, 'read') and hasattr(data.read, '__call__'): - data = data.read() - - if isinstance(data, basestring): - try: - try: - return json.loads(data) - except: - return safe_load(data) - except: - raise AnsibleParserError("data was not valid yaml") - - raise AnsibleInternalError("expected file or string, got %s" % type(data)) - diff --git a/v2/ansible/parsing/yaml/__init__.py b/v2/ansible/parsing/yaml/__init__.py index 6cc55bfc846..6d121d991ec 100644 --- a/v2/ansible/parsing/yaml/__init__.py +++ b/v2/ansible/parsing/yaml/__init__.py @@ -19,9 +19,114 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -from yaml import load -from ansible.parsing.yaml.loader import AnsibleLoader +import json +import os + +from yaml import load, YAMLError + +from ansible.errors import AnsibleParserError + +from ansible.parsing.vault import VaultLib +from ansible.parsing.yaml.loader import AnsibleLoader +from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject +from ansible.parsing.yaml.strings import YAML_SYNTAX_ERROR + +class DataLoader(): + + ''' + The DataLoader class is used to load and parse YAML or JSON content, + either from a given file name or from a string that was previously + read in through other means. A Vault password can be specified, and + any vault-encrypted files will be decrypted. + + Data read from files will also be cached, so the file will never be + read from disk more than once. + + Usage: + + dl = DataLoader() + (or) + dl = DataLoader(vault_password='foo') + + ds = dl.load('...') + ds = dl.load_from_file('/path/to/file') + ''' + + _FILE_CACHE = dict() + + def __init__(self, vault_password=None): + self._vault = VaultLib(password=vault_password) + + def load(self, data, file_name='', show_content=True): + ''' + Creates a python datastructure from the given data, which can be either + a JSON or YAML string. + ''' + + try: + # we first try to load this data as JSON + return json.loads(data) + except: + try: + # if loading JSON failed for any reason, we go ahead + # and try to parse it as YAML instead + return self._safe_load(data) + except YAMLError, yaml_exc: + self._handle_error(yaml_exc, file_name, show_content) + + def load_from_file(self, file_name): + ''' Loads data from a file, which can contain either JSON or YAML. ''' + + # if the file has already been read in and cached, we'll + # return those results to avoid more file/vault operations + if file_name in self._FILE_CACHE: + return self._FILE_CACHE + + # read the file contents and load the data structure from them + (file_data, show_content) = self._get_file_contents(file_name) + parsed_data = self.load(data=file_data, file_name=file_name, show_content=show_content) + + # cache the file contents for next time + self._FILE_CACHE[file_name] = parsed_data + + return parsed_data + + def _safe_load(self, stream): + ''' Implements yaml.safe_load(), except using our custom loader class. ''' + return load(stream, AnsibleLoader) + + def _get_file_contents(self, file_name): + ''' + Reads the file contents from the given file name, and will decrypt them + if they are found to be vault-encrypted. + ''' + if not os.path.exists(file_name) or not os.path.isfile(file_name): + raise AnsibleParserError("the file_name '%s' does not exist, or is not readable" % file_name) + + show_content = True + try: + with open(file_name, 'r') as f: + data = f.read() + if self._vault.is_encrypted(data): + data = self._vault.decrypt(data) + show_content = False + return (data, show_content) + except (IOError, OSError) as e: + raise AnsibleParserError("an error occured while trying to read the file '%s': %s" % (file_name, str(e))) + + def _handle_error(self, yaml_exc, file_name, show_content): + ''' + Optionally constructs an object (AnsibleBaseYAMLObject) to encapsulate the + file name/position where a YAML exception occured, and raises an AnsibleParserError + to display the syntax exception information. + ''' + + # if the YAML exception contains a problem mark, use it to construct + # an object the error class can use to display the faulty line + err_obj = None + if hasattr(yaml_exc, 'problem_mark'): + err_obj = AnsibleBaseYAMLObject() + err_obj.set_position_info(file_name, yaml_exc.problem_mark.line + 1, yaml_exc.problem_mark.column + 1) + + raise AnsibleParserError(YAML_SYNTAX_ERROR, obj=err_obj, show_content=show_content) -def safe_load(stream): - ''' implements yaml.safe_load(), except using our custom loader class ''' - return load(stream, AnsibleLoader) diff --git a/v2/ansible/parsing/yaml/objects.py b/v2/ansible/parsing/yaml/objects.py index be687d1e148..ba89accd73a 100644 --- a/v2/ansible/parsing/yaml/objects.py +++ b/v2/ansible/parsing/yaml/objects.py @@ -32,6 +32,11 @@ class AnsibleBaseYAMLObject: def get_position_info(self): return (self._data_source, self._line_number, self._column_number) + def set_position_info(self, src, line, col): + self._data_source = src + self._line_number = line + self._column_number = col + def copy_position_info(obj): ''' copies the position info from another object ''' assert isinstance(obj, AnsibleBaseYAMLObject) diff --git a/v2/ansible/parsing/yaml/strings.py b/v2/ansible/parsing/yaml/strings.py new file mode 100644 index 00000000000..b7e304194fc --- /dev/null +++ b/v2/ansible/parsing/yaml/strings.py @@ -0,0 +1,118 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +__all__ = [ + 'YAML_SYNTAX_ERROR', + 'YAML_POSITION_DETAILS', + 'YAML_COMMON_DICT_ERROR', + 'YAML_COMMON_UNQUOTED_VARIABLE_ERROR', + 'YAML_COMMON_UNQUOTED_COLON_ERROR', + 'YAML_COMMON_PARTIALLY_QUOTED_LINE_ERROR', + 'YAML_COMMON_UNBALANCED_QUOTES_ERROR', +] + +YAML_SYNTAX_ERROR = """\ +Syntax Error while loading YAML. +""" + +YAML_POSITION_DETAILS = """\ +The error appears to have been in '%s': line %s, column %s, +but may actually be before there depending on the exact syntax problem. +""" + +YAML_COMMON_DICT_ERROR = """\ +This one looks easy to fix. YAML thought it was looking for the start of a +hash/dictionary and was confused to see a second "{". Most likely this was +meant to be an ansible template evaluation instead, so we have to give the +parser a small hint that we wanted a string instead. The solution here is to +just quote the entire value. + +For instance, if the original line was: + + app_path: {{ base_path }}/foo + +It should be written as: + + app_path: "{{ base_path }}/foo" +""" + +YAML_COMMON_UNQUOTED_VARIABLE_ERROR = """\ +We could be wrong, but this one looks like it might be an issue with +missing quotes. Always quote template expression brackets when they +start a value. For instance: + + with_items: + - {{ foo }} + +Should be written as: + + with_items: + - "{{ foo }}" +""" + +YAML_COMMON_UNQUOTED_COLON_ERROR = """\ +This one looks easy to fix. There seems to be an extra unquoted colon in the line +and this is confusing the parser. It was only expecting to find one free +colon. The solution is just add some quotes around the colon, or quote the +entire line after the first colon. + +For instance, if the original line was: + + copy: src=file.txt dest=/path/filename:with_colon.txt + +It can be written as: + + copy: src=file.txt dest='/path/filename:with_colon.txt' + +Or: + + copy: 'src=file.txt dest=/path/filename:with_colon.txt' +""" + +YAML_COMMON_PARTIALLY_QUOTED_LINE_ERROR = """\ +This one looks easy to fix. It seems that there is a value started +with a quote, and the YAML parser is expecting to see the line ended +with the same kind of quote. For instance: + + when: "ok" in result.stdout + +Could be written as: + + when: '"ok" in result.stdout' + +Or equivalently: + + when: "'ok' in result.stdout" +""" + +YAML_COMMON_UNBALANCED_QUOTES_ERROR = """\ +We could be wrong, but this one looks like it might be an issue with +unbalanced quotes. If starting a value with a quote, make sure the +line ends with the same set of quotes. For instance this arbitrary +example: + + foo: "bad" "wolf" + +Could be written as: + + foo: '"bad" "wolf"' +""" + diff --git a/v2/ansible/playbook/base.py b/v2/ansible/playbook/base.py index 59c329d453a..577a5dae22c 100644 --- a/v2/ansible/playbook/base.py +++ b/v2/ansible/playbook/base.py @@ -25,14 +25,17 @@ from io import FileIO from six import iteritems, string_types from ansible.playbook.attribute import Attribute, FieldAttribute -from ansible.parsing import load +from ansible.parsing.yaml import DataLoader class Base: _tags = FieldAttribute(isa='list') _when = FieldAttribute(isa='list') - def __init__(self): + def __init__(self, loader=DataLoader): + + # the data loader class is used to parse data from strings and files + self._loader = loader # each class knows attributes set upon it, see Task.py for example self._attributes = dict() @@ -64,7 +67,7 @@ class Base: assert ds is not None if isinstance(ds, string_types) or isinstance(ds, FileIO): - ds = load(ds) + ds = self._loader.load(ds) # we currently don't do anything with private attributes but may # later decide to filter them out of 'ds' here. diff --git a/v2/ansible/playbook/role.py b/v2/ansible/playbook/role.py index 88aecab9852..b68ce515835 100644 --- a/v2/ansible/playbook/role.py +++ b/v2/ansible/playbook/role.py @@ -23,14 +23,11 @@ from six import iteritems, string_types import os +from ansible.errors import AnsibleError +from ansible.parsing.yaml import DataLoader from ansible.playbook.attribute import FieldAttribute from ansible.playbook.base import Base from ansible.playbook.block import Block -from ansible.errors import AnsibleError - -# FIXME: this def was cruft from the old utils code, so we'll need -# to relocate it somewhere before we can use it -#from ansible.parsing import load_data_from_file from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping @@ -48,10 +45,10 @@ class Role(Base): _default_vars = FieldAttribute(isa='dict', default=dict()) _role_vars = FieldAttribute(isa='dict', default=dict()) - def __init__(self, vault_password=None): + def __init__(self, vault_password=None, loader=DataLoader): self._role_path = None self._vault_password = vault_password - super(Role, self).__init__() + super(Role, self).__init__(loader=loader) def __repr__(self): return self.get_name() diff --git a/v2/ansible/playbook/task.py b/v2/ansible/playbook/task.py index 91ca7558d6a..aa79d494104 100644 --- a/v2/ansible/playbook/task.py +++ b/v2/ansible/playbook/task.py @@ -26,6 +26,7 @@ from ansible.errors import AnsibleError from ansible.parsing.splitter import parse_kv from ansible.parsing.mod_args import ModuleArgsParser +from ansible.parsing.yaml import DataLoader from ansible.plugins import module_finder, lookup_finder class Task(Base): @@ -85,11 +86,11 @@ class Task(Base): _transport = FieldAttribute(isa='string') _until = FieldAttribute(isa='list') # ? - def __init__(self, block=None, role=None): + def __init__(self, block=None, role=None, loader=DataLoader): ''' constructors a task, without the Task.load classmethod, it will be pretty blank ''' self._block = block self._role = role - super(Task, self).__init__() + super(Task, self).__init__(loader) def get_name(self): ''' return the name of the task ''' diff --git a/v2/test/errors/test_errors.py b/v2/test/errors/test_errors.py index 5d1868a5a4a..5b24dc4345d 100644 --- a/v2/test/errors/test_errors.py +++ b/v2/test/errors/test_errors.py @@ -30,7 +30,7 @@ from ansible.compat.tests.mock import mock_open, patch class TestErrors(unittest.TestCase): def setUp(self): - self.message = 'this is the error message' + self.message = 'This is the error message' self.obj = AnsibleBaseYAMLObject() @@ -42,18 +42,18 @@ class TestErrors(unittest.TestCase): self.assertEqual(e.message, self.message) self.assertEqual(e.__repr__(), self.message) - @patch.object(AnsibleError, '_get_line_from_file') + @patch.object(AnsibleError, '_get_error_lines_from_file') def test_error_with_object(self, mock_method): self.obj._data_source = 'foo.yml' self.obj._line_number = 1 self.obj._column_number = 1 - mock_method.return_value = 'this is line 1\n' + mock_method.return_value = ('this is line 1\n', '') e = AnsibleError(self.message, self.obj) - self.assertEqual(e.message, 'this is the error message\nThe error occurred on line 1 of the file foo.yml:\nthis is line 1\n^') + self.assertEqual(e.message, "This is the error message\nThe error appears to have been in 'foo.yml': line 1, column 1,\nbut may actually be before there depending on the exact syntax problem.\n\nthis is line 1\n^\n") - def test_error_get_line_from_file(self): + def test_get_error_lines_from_file(self): m = mock_open() m.return_value.readlines.return_value = ['this is line 1\n'] @@ -63,12 +63,12 @@ class TestErrors(unittest.TestCase): self.obj._line_number = 1 self.obj._column_number = 1 e = AnsibleError(self.message, self.obj) - self.assertEqual(e.message, 'this is the error message\nThe error occurred on line 1 of the file foo.yml:\nthis is line 1\n^') + self.assertEqual(e.message, "This is the error message\nThe error appears to have been in 'foo.yml': line 1, column 1,\nbut may actually be before there depending on the exact syntax problem.\n\nthis is line 1\n^\n") # this line will not be found, as it is out of the index range self.obj._data_source = 'foo.yml' self.obj._line_number = 2 self.obj._column_number = 1 e = AnsibleError(self.message, self.obj) - self.assertEqual(e.message, 'this is the error message\nThe error occurred on line 2 of the file foo.yml:\n\n(specified line no longer in file, maybe it changed?)') + self.assertEqual(e.message, "This is the error message\nThe error appears to have been in 'foo.yml': line 2, column 1,\nbut may actually be before there depending on the exact syntax problem.\n\n(specified line no longer in file, maybe it changed?)") diff --git a/v2/test/parsing/test_general.py b/v2/test/parsing/test_general.py deleted file mode 100644 index b06038a5884..00000000000 --- a/v2/test/parsing/test_general.py +++ /dev/null @@ -1,104 +0,0 @@ -# (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -# Make coding more python3-ish -from __future__ import (absolute_import, division, print_function) -__metaclass__ = type - -from ansible.compat.tests import unittest -from ansible.errors import AnsibleInternalError, AnsibleParserError -from ansible.parsing import load - -import json -import yaml - -from io import FileIO - -class MockFile(FileIO): - - def __init__(self, ds, method='json'): - self.ds = ds - self.method = method - - def read(self): - if self.method == 'json': - return json.dumps(self.ds) - elif self.method == 'yaml': - return yaml.dump(self.ds) - elif self.method == 'fail': - return """ - AAARGGGGH: - ***** - THIS WON'T PARSE !!! - NOOOOOOOOOOOOOOOOOO - """ - else: - raise Exception("untestable serializer") - - def close(self): - pass - -class TestGeneralParsing(unittest.TestCase): - def setUp(self): - pass - - def tearDown(self): - pass - - def test_parse_json_from_string(self): - data = """ - { - "asdf" : "1234", - "jkl" : 5678 - } - """ - output = load(data) - self.assertEqual(output['asdf'], '1234') - self.assertEqual(output['jkl'], 5678) - - def test_parse_json_from_file(self): - output = load(MockFile(dict(a=1,b=2,c=3), 'json')) - self.assertEqual(output, dict(a=1,b=2,c=3)) - - def test_parse_yaml_from_dict(self): - data = """ - asdf: '1234' - jkl: 5678 - """ - output = load(data) - self.assertEqual(output['asdf'], '1234') - self.assertEqual(output['jkl'], 5678) - - def test_parse_yaml_from_file(self): - output = load(MockFile(dict(a=1,b=2,c=3),'yaml')) - self.assertEqual(output, dict(a=1,b=2,c=3)) - - def test_parse_fail(self): - data = """ - TEXT: - *** - NOT VALID - """ - self.assertRaises(AnsibleParserError, load, data) - - def test_parse_fail_from_file(self): - self.assertRaises(AnsibleParserError, load, MockFile(None,'fail')) - - def test_parse_fail_invalid_type(self): - self.assertRaises(AnsibleInternalError, load, 3000) - self.assertRaises(AnsibleInternalError, load, dict(a=1,b=2,c=3)) - diff --git a/v2/test/parsing/yaml/test_data_loader.py b/v2/test/parsing/yaml/test_data_loader.py new file mode 100644 index 00000000000..166a60ee5e2 --- /dev/null +++ b/v2/test/parsing/yaml/test_data_loader.py @@ -0,0 +1,64 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from yaml.scanner import ScannerError + +from ansible.compat.tests import unittest +from ansible.compat.tests.mock import patch +from ansible.errors import AnsibleParserError + +from ansible.parsing.yaml import DataLoader +from ansible.parsing.yaml.objects import AnsibleMapping + +class TestDataLoader(unittest.TestCase): + + def setUp(self): + # FIXME: need to add tests that utilize vault_password + self._loader = DataLoader() + + def tearDown(self): + pass + + @patch.object(DataLoader, '_get_file_contents') + def test_parse_json_from_file(self, mock_def): + mock_def.return_value = ("""{"a": 1, "b": 2, "c": 3}""", True) + output = self._loader.load_from_file('dummy_json.txt') + self.assertEqual(output, dict(a=1,b=2,c=3)) + + @patch.object(DataLoader, '_get_file_contents') + def test_parse_yaml_from_file(self, mock_def): + mock_def.return_value = (""" + a: 1 + b: 2 + c: 3 + """, True) + output = self._loader.load_from_file('dummy_yaml.txt') + self.assertEqual(output, dict(a=1,b=2,c=3)) + + @patch.object(DataLoader, '_get_file_contents') + def test_parse_fail_from_file(self, mock_def): + mock_def.return_value = (""" + TEXT: + *** + NOT VALID + """, True) + self.assertRaises(AnsibleParserError, self._loader.load_from_file, 'dummy_yaml_bad.txt') + diff --git a/v2/test/parsing/yaml/test_yaml.py b/v2/test/parsing/yaml/test_yaml.py deleted file mode 100644 index c468ef6d6fa..00000000000 --- a/v2/test/parsing/yaml/test_yaml.py +++ /dev/null @@ -1,100 +0,0 @@ -# (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -# Make coding more python3-ish -from __future__ import (absolute_import, division, print_function) -__metaclass__ = type - -from ansible.compat.tests import unittest - -from yaml.scanner import ScannerError - -from ansible.parsing.yaml import safe_load -from ansible.parsing.yaml.objects import AnsibleMapping - -# a single dictionary instance -data1 = '''--- -key: value -''' - -# multiple dictionary instances -data2 = '''--- -- key1: value1 -- key2: value2 - -- key3: value3 - - -- key4: value4 -''' - -# multiple dictionary instances with other nested -# dictionaries contained within those -data3 = '''--- -- key1: - subkey1: subvalue1 - subkey2: subvalue2 - subkey3: - subsubkey1: subsubvalue1 -- key2: - subkey4: subvalue4 -- list1: - - list1key1: list1value1 - list1key2: list1value2 - list1key3: list1value3 -''' - -bad_data1 = '''--- -foo: bar - bam: baz -''' - -class TestSafeLoad(unittest.TestCase): - - def setUp(self): - pass - - def tearDown(self): - pass - - def test_safe_load_bad(self): - # test the loading of bad yaml data - self.assertRaises(ScannerError, safe_load, bad_data1) - - def test_safe_load(self): - # test basic dictionary - res = safe_load(data1) - self.assertEqual(type(res), AnsibleMapping) - self.assertEqual(res._line_number, 2) - - # test data with multiple dictionaries - res = safe_load(data2) - self.assertEqual(len(res), 4) - self.assertEqual(res[0]._line_number, 2) - self.assertEqual(res[1]._line_number, 3) - self.assertEqual(res[2]._line_number, 5) - self.assertEqual(res[3]._line_number, 8) - - # test data with multiple sub-dictionaries - res = safe_load(data3) - self.assertEqual(len(res), 3) - self.assertEqual(res[0]._line_number, 2) - self.assertEqual(res[1]._line_number, 7) - self.assertEqual(res[2]._line_number, 9) - self.assertEqual(res[0]['key1']._line_number, 3) - self.assertEqual(res[1]['key2']._line_number, 8) - self.assertEqual(res[2]['list1'][0]._line_number, 10)