Teaching objects to load themselves, making the JSON/YAML parsing ambidexterous.

This commit is contained in:
Michael DeHaan 2014-10-08 15:59:24 -04:00
parent c75aeca435
commit 56b6cb5328
12 changed files with 180 additions and 59 deletions

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,85 @@
# TODO: header
import unittest
from ansible.parsing import load
from ansible.errors import AnsibleParserError
import json
class MockFile(file):
def __init__(self, ds, method='json'):
self.ds = ds
self.method = method
def read(self):
if method == 'json':
return json.dumps(ds)
elif method == 'yaml':
return yaml.dumps(ds)
elif method == 'fail':
return """
AAARGGGGH
THIS WON'T PARSE !!!
NOOOOOOOOOOOOOOOOOO
"""
else:
raise Exception("untestable serializer")
def close(self):
pass
class TestGeneralParsing(unittest.TestCase):
def __init__(self):
pass
def setUp(self):
pass
def tearDown(self):
pass
def parse_json_from_string(self):
input = """
{
"asdf" : "1234",
"jkl" : 5678
}
"""
output = load(input)
assert output['asdf'] == '1234'
assert output['jkl'] == 5678
def parse_json_from_file(self):
output = load(MockFile(dict(a=1,b=2,c=3)),'json')
assert ouput == dict(a=1,b=2,c=3)
def parse_yaml_from_dict(self):
input = """
asdf: '1234'
jkl: 5678
"""
output = load(input)
assert output['asdf'] == '1234'
assert output['jkl'] == 5678
def parse_yaml_from_file(self):
output = load(MockFile(dict(a=1,b=2,c=3),'yaml'))
assert output == dict(a=1,b=2,c=3)
def parse_fail(self):
input = """
TEXT
***
NOT VALID
"""
self.failUnlessRaises(load(input), AnsibleParserError)
def parse_fail_from_file(self):
self.failUnlessRaises(load(MockFile(None,'fail')), AnsibleParserError)
def parse_fail_invalid_type(self):
self.failUnlessRaises(3000, AnsibleParsingError)
self.failUnlessRaises(dict(a=1,b=2,c=3), AnsibleParserError)

View file

@ -5,10 +5,14 @@ import unittest
class TestModArgsDwim(unittest.TestCase):
# TODO: add tests that construct ModuleArgsParser with a task reference
# TODO: verify the AnsibleError raised on failure knows the task
# and the task knows the line numbers
def setUp(self):
self.m = ModuleArgsParser()
pass
def tearDown(self):
pass
@ -77,5 +81,4 @@ class TestModArgsDwim(unittest.TestCase):
mod, args, to = self.m.parse(dict(local_action='copy src=a dest=b'))
assert mod == 'copy'
assert args == dict(src='a', dest='b')
assert to is 'localhost'
assert to is 'localhost'

View file

@ -1,2 +1 @@
# TODO: header

View file

@ -16,13 +16,13 @@ class TestTask(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def test_construct_empty_task(self):
t = Task()
def test_construct_task_with_role(self):
pass
@ -57,15 +57,13 @@ class TestTask(unittest.TestCase):
pass
def test_can_load_module_complex_form(self):
pass
pass
def test_local_action_implies_delegate(self):
pass
pass
def test_local_action_conflicts_with_delegate(self):
pass
pass
def test_delegate_to_parses(self):
pass

View file

@ -16,4 +16,31 @@
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
class AnsibleError(Exception):
pass
def __init__(self, message, object=None):
self.message = message
self.object = object
# TODO: nice __repr__ message that includes the line number if the object
# it was constructed with had the line number
# TODO: tests for the line number functionality
class AnsibleParserError(AnsibleError):
''' something was detected early that is wrong about a playbook or data file '''
pass
class AnsibleInternalError(AnsibleError):
''' internal safeguards tripped, something happened in the code that should never happen '''
pass
class AnsibleRuntimeError(AnsibleError):
''' ansible had a problem while running a playbook '''
pass
class AnsibleModuleError(AnsibleRuntimeError):
''' a module failed somehow '''
pass
class AnsibleConnectionFailure(AnsibleRuntimeError):
''' the transport / connection_plugin had a fatal error '''
pass

View file

@ -1 +1,18 @@
# TODO: header
from ansible.errors import AnsibleError, AnsibleInternalError
def load(self, data):
if instanceof(data, file):
fd = open(f)
data = fd.read()
fd.close()
if instanceof(data, basestring):
try:
return json.loads(data)
except:
return safe_load(data)
raise AnsibleInternalError("expected file or string, got %s" % type(data))

View file

@ -55,15 +55,16 @@ class ModuleArgsParser(object):
will tell you about the modules in a predictable way.
"""
def __init__(self):
def __init__(self, task=None):
self._ds = None
self._task = task
def _get_delegate_to(self):
'''
Returns the value of the delegate_to key from the task datastructure,
or None if the value was not directly specified
'''
return self._ds.get('delegate_to')
return self._ds.get('delegate_to', None)
def _get_old_style_action(self):
'''
@ -108,29 +109,24 @@ class ModuleArgsParser(object):
if 'module' in other_args:
del other_args['module']
args.update(other_args)
elif isinstance(action_data, basestring):
action_data = action_data.strip()
if not action_data:
# TODO: change to an AnsibleParsingError so that the
# filename/line number can be reported in the error
raise AnsibleError("when using 'action:' or 'local_action:', the module name must be specified")
raise AnsibleError("when using 'action:' or 'local_action:', the module name must be specified", object=self._task)
else:
# split up the string based on spaces, where the first
# item specified must be a valid module name
parts = action_data.split(' ', 1)
action = parts[0]
if action not in module_finder:
# TODO: change to an AnsibleParsingError so that the
# filename/line number can be reported in the error
raise AnsibleError("the module '%s' was not found in the list of loaded modules")
raise AnsibleError("the module '%s' was not found in the list of loaded modules" % action, object=self._task)
if len(parts) > 1:
args = self._get_args_from_action(action, ' '.join(parts[1:]))
else:
args = {}
else:
# TODO: change to an AnsibleParsingError so that the
# filename/line number can be reported in the error
raise AnsibleError('module args must be specified as a dictionary or string')
raise AnsibleError('module args must be specified as a dictionary or string', object=self._task)
return dict(action=action, args=args, delegate_to=delegate_to)
@ -277,7 +273,7 @@ class ModuleArgsParser(object):
assert type(ds) == dict
self._ds = ds
# first we try to get the module action/args based on the
# new-style format, where the module name is the key
result = self._get_new_style_action()
@ -286,9 +282,7 @@ class ModuleArgsParser(object):
# where 'action' or 'local_action' is the key
result = self._get_old_style_action()
if result is None:
# TODO: change to an AnsibleParsingError so that the
# filename/line number can be reported in the error
raise AnsibleError('no action specified for this task')
raise AnsibleError('no action specified for this task', object=self._task)
# if the action is set to 'shell', we switch that to 'command' and
# set the special parameter '_uses_shell' to true in the args dict
@ -302,11 +296,8 @@ class ModuleArgsParser(object):
specified_delegate_to = self._get_delegate_to()
if specified_delegate_to is not None:
if result['delegate_to'] is not None:
# TODO: change to an AnsibleParsingError so that the
# filename/line number can be reported in the error
raise AnsibleError('delegate_to cannot be used with local_action')
else:
result['delegate_to'] = specified_delegate_to
return (result['action'], result['args'], result['delegate_to'])

View file

@ -4,4 +4,3 @@ from ansible.parsing.yaml.loader import AnsibleLoader
def safe_load(stream):
''' implements yaml.safe_load(), except using our custom loader class '''
return load(stream, AnsibleLoader)

View file

@ -31,4 +31,3 @@ class Attribute(object):
class FieldAttribute(Attribute):
pass

View file

@ -16,12 +16,13 @@
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from ansible.playbook.attribute import Attribute, FieldAttribute
from ansible.parsing import load as ds_load
class Base(object):
def __init__(self):
# each class knows attributes set upon it, see Task.py for example
# each class knows attributes set upon it, see Task.py for example
self._attributes = dict()
for (name, value) in self.__class__.__dict__.iteritems():
@ -39,6 +40,9 @@ class Base(object):
assert ds is not None
if isinstance(ds, basestring) or isinstance(ds, file):
ds = ds_load(ds)
# we currently don't do anything with private attributes but may
# later decide to filter them out of 'ds' here.
@ -59,14 +63,14 @@ class Base(object):
else:
if aname in ds:
self._attributes[aname] = ds[aname]
# return the constructed object
self.validate()
return self
def validate(self):
''' validation that is done at parse time, not load time '''
''' validation that is done at parse time, not load time '''
# walk all fields in the object
for (name, attribute) in self.__dict__.iteritems():
@ -76,9 +80,9 @@ class Base(object):
if not name.startswith("_"):
raise AnsibleError("FieldAttribute %s must start with _" % name)
aname = name[1:]
# run validator only if present
method = getattr(self, '_validate_%s' % (prefix, aname), None)
if method:
@ -87,9 +91,9 @@ class Base(object):
def post_validate(self, runner_context):
'''
we can't tell that everything is of the right type until we have
all the variables. Run basic types (from isa) as well as
all the variables. Run basic types (from isa) as well as
any _post_validate_<foo> functions.
'''
'''
raise exception.NotImplementedError
@ -107,4 +111,3 @@ class Base(object):
return self._attributes[needle]
raise AttributeError("attribute not found: %s" % needle)

View file

@ -27,7 +27,7 @@ from ansible.plugins import module_finder, lookup_finder
class Task(Base):
"""
A task is a language feature that represents a call to a module, with given arguments and other parameters.
A task is a language feature that represents a call to a module, with given arguments and other parameters.
A handler is a subclass of a task.
Usage:
@ -41,14 +41,14 @@ class Task(Base):
# load_<attribute_name> and
# validate_<attribute_name>
# will be used if defined
# might be possible to define others
# might be possible to define others
_args = FieldAttribute(isa='dict')
_action = FieldAttribute(isa='string')
_always_run = FieldAttribute(isa='bool')
_any_errors_fatal = FieldAttribute(isa='bool')
_async = FieldAttribute(isa='int')
_async = FieldAttribute(isa='int')
_connection = FieldAttribute(isa='string')
_delay = FieldAttribute(isa='int')
_delegate_to = FieldAttribute(isa='string')
@ -59,9 +59,9 @@ class Task(Base):
_loop = FieldAttribute(isa='string', private=True)
_loop_args = FieldAttribute(isa='list', private=True)
_local_action = FieldAttribute(isa='string')
# FIXME: this should not be a Task
_meta = FieldAttribute(isa='string')
_meta = FieldAttribute(isa='string')
_name = FieldAttribute(isa='string')
@ -120,7 +120,7 @@ class Task(Base):
def __repr__(self):
''' returns a human readable representation of the task '''
return "TASK: %s" % self.get_name()
def _munge_loop(self, ds, new_ds, k, v):
''' take a lookup plugin name and store it correctly '''
@ -128,9 +128,9 @@ class Task(Base):
raise AnsibleError("duplicate loop in task: %s" % k)
new_ds['loop'] = k
new_ds['loop_args'] = v
def munge(self, ds):
'''
'''
tasks are especially complex arguments so need pre-processing.
keep it short.
'''
@ -202,7 +202,7 @@ LEGACY = """
results['_module_name'] = k
if isinstance(v, dict) and 'args' in ds:
raise AnsibleError("can't combine args: and a dict for %s: in task %s" % (k, ds.get('name', "%s: %s" % (k, v))))
results['_parameters'] = self._load_parameters(v)
results['_parameters'] = self._load_parameters(v)
return results
def _load_loop(self, ds, k, v):
@ -264,7 +264,7 @@ LEGACY = """
def _load_invalid_key(self, ds, k, v):
''' handle any key we do not recognize '''
raise AnsibleError("%s is not a legal parameter in an Ansible task or handler" % k)
def _load_other_valid_key(self, ds, k, v):
@ -296,7 +296,7 @@ LEGACY = """
return self._load_invalid_key
else:
return self._load_other_valid_key
# ==================================================================================
# PRE-VALIDATION - expected to be uncommonly used, this checks for arguments that
# are aliases of each other. Most everything else should be in the LOAD block
@ -311,7 +311,7 @@ LEGACY = """
# =================================================================================
# POST-VALIDATION: checks for internal inconsistency between fields
# validation can result in an error but also corrections
def _post_validate(self):
''' is the loaded datastructure sane? '''
@ -321,13 +321,13 @@ LEGACY = """
# incompatible items
self._validate_conflicting_su_and_sudo()
self._validate_conflicting_first_available_file_and_loookup()
def _post_validate_fixed_name(self):
'' construct a name for the task if no name was specified '''
flat_params = " ".join(["%s=%s" % (k,v) for k,v in self._parameters.iteritems()])
return = "%s %s" % (self._module_name, flat_params)
def _post_validate_conflicting_su_and_sudo(self):
''' make sure su/sudo usage doesn't conflict '''
@ -342,4 +342,3 @@ LEGACY = """
raise AnsibleError("with_(plugin), and first_available_file are mutually incompatible in a single task")
"""