Add module common code to allow it to be easier to indicate whether arguments are mutually exclusive, required in conjunction, or whether one of a list of arguments is required. This simplifies writing Python modules.

This commit is contained in:
Michael DeHaan 2012-08-11 18:13:29 -04:00
parent 98c350a6ac
commit 1e4d45af1e
4 changed files with 61 additions and 52 deletions

View file

@ -55,11 +55,14 @@ except ImportError:
class AnsibleModule(object):
def __init__(self, argument_spec, bypass_checks=False, no_log=False, check_invalid_arguments=True):
def __init__(self, argument_spec, bypass_checks=False, no_log=False,
check_invalid_arguments=True, mutually_exclusive=None, required_together=None,
required_one_of=None):
'''
common code for quickly building an ansible module in Python
(although you can write modules in anything that can return JSON)
see library/slurp and others for examples
see library/* for examples
'''
self.argument_spec = argument_spec
@ -77,12 +80,14 @@ class AnsibleModule(object):
if not bypass_checks:
self._check_required_arguments()
self._check_argument_types()
self._check_mutually_exclusive(mutually_exclusive)
self._check_required_together(required_together)
self._check_required_one_of(required_one_of)
self._set_defaults(pre=False)
if not no_log:
self._log_invocation()
def _handle_aliases(self):
for (k,v) in self.argument_spec.iteritems():
self._legal_inputs.append(k)
@ -106,6 +111,39 @@ class AnsibleModule(object):
if k not in self._legal_inputs:
self.fail_json(msg="unsupported parameter for module: %s" % k)
def _count_terms(self, check):
count = 0
for term in check:
if term in self.params:
count += 1
return count
def _check_mutually_exclusive(self, spec):
if spec is None:
return
for check in spec:
count = self._count_terms(check)
if count > 1:
self.fail_json(msg="parameters are mutually exclusive: %s" % check)
def _check_required_one_of(self, spec):
if spec is None:
return
for check in spec:
count = self._count_terms(check)
if count == 0:
self.fail_json(msg="one of the following is required: %s" % check)
def _check_required_together(self, spec):
if spec is None:
return
for check in spec:
counts = [ self.count_terms([field]) for field in check ]
non_zero = [ c for c in counts if c > 0 ]
if len(non_zero) > 0:
if 0 in counts:
self.fail_json(msg="parameters are required together: %s" % check)
def _check_required_arguments(self):
''' ensure all required arguments are present '''
missing = []

View file

@ -90,7 +90,11 @@ def main():
virtualenv=dict(default=None, required=False)
)
module = AnsibleModule(argument_spec=arg_spec)
module = AnsibleModule(
argument_spec=arg_spec,
required_one_of=[['name','requirements']],
mutually_exclusive=[['name','requirements']],
)
rc = 0
err = ''
@ -115,38 +119,19 @@ def main():
command_map = dict(present='install', absent='uninstall', latest='install')
if state == 'latest' and version is not None:
module.fail_json(msg='If `state` is set to `latest` the `version` '
'parameter must not be specified.')
module.fail_json(msg='version is incompatible with state=latest')
if state == 'latest' and requirements is not None:
module.fail_json(msg='If `state` is set to `latest` the `requirements` '
'parameter must not be specified.')
module.fail_json(msg='requirements is incompatible with state=latest')
if name is not None and '==' in name:
module.fail_json(msg='It looks like you specified the version number '
'in the library name. Use the `version` parameter '
'to specify version instead')
if version is not None and name is None:
module.fail_json(msg='The `version` parameter must be used with the '
'`name` parameter and not with the `requirements` '
'paramter')
if name is None and requirements is None:
module.fail_json(msg='You must specify a python library name via '
'the `name` parameter or a requirements file via '
'the `requirements` paramter')
if name and requirements:
module.fail_json(msg='Both `name` and `requirements` were specified. '
'Specify only the python library name via the '
'`name` parameter or a requirements file via the '
'`requirements` parameter')
if name is not None and '=' in name:
module.fail_json(msg='versions must be specified in the version= parameter')
cmd = None
installed = None
if requirements:
cmd = '%s %s -r %s --use-mirrors' % (pip, command_map[state], requirements)
rc_pip, out_pip, err_pip = _run(cmd)
@ -158,8 +143,8 @@ def main():
(not _did_install(out) and state == 'absent'))
if name and state == 'latest':
cmd = '%s %s %s --upgrade' % (pip, command_map[state], name)
cmd = '%s %s %s --upgrade' % (pip, command_map[state], name)
rc_pip, out_pip, err_pip = _run(cmd)
rc += rc_pip
@ -169,8 +154,8 @@ def main():
changed = 'Successfully installed' in out_pip
elif name:
installed = _is_package_installed(name, pip, version)
changed = ((installed and state == 'absent') or
(not installed and state == 'present'))
@ -188,8 +173,7 @@ def main():
cmd = cmd + ' --use-mirrors'
rc_pip, out_pip, err_pip = _run(cmd)
rc += rc_pip
rc += rc_pip
out += out_pip
err += err_pip

0
library/shell Normal file → Executable file
View file

View file

@ -35,7 +35,6 @@ def is_installed(repoq, pkgspec, qf=def_qf):
rc,out,err = run(cmd)
if rc == 0:
return [ p for p in out.split('\n') if p.strip() ]
return []
def is_available(repoq, pkgspec, qf=def_qf):
@ -43,10 +42,8 @@ def is_available(repoq, pkgspec, qf=def_qf):
rc,out,err = run(cmd)
if rc == 0:
return [ p for p in out.split('\n') if p.strip() ]
return []
def is_update(repoq, pkgspec, qf=def_qf):
cmd = repoq + ["--pkgnarrow=updates", "--qf", qf, pkgspec]
rc,out,err = run(cmd)
@ -55,17 +52,14 @@ def is_update(repoq, pkgspec, qf=def_qf):
return []
def what_provides(repoq, req_spec, qf=def_qf):
cmd = repoq + ["--qf", qf, "--whatprovides", req_spec]
rc,out,err = run(cmd)
ret = []
if rc == 0:
ret = set([ p for p in out.split('\n') if p.strip() ])
return ret
def pkg_to_dict(pkgstr):
if pkgstr.strip():
n,e,v,r,a,repo = pkgstr.split('|')
@ -80,7 +74,7 @@ def pkg_to_dict(pkgstr):
'version':v,
'repo':repo,
'nevra': '%s:%s-%s-%s.%s' % (e,n,v,r,a)
}
}
if repo == 'installed':
d['yumstate'] = 'installed'
@ -95,7 +89,6 @@ def repolist(repoq, qf="%{repoid}"):
ret = []
if rc == 0:
ret = set([ p for p in out.split('\n') if p.strip() ])
return ret
def list_stuff(conf_file, stuff):
@ -128,7 +121,6 @@ def run(command):
rc = 1
err = traceback.format_exc()
out = ''
if out is None:
out = ''
if err is None:
@ -429,15 +421,13 @@ def main():
state=dict(default='installed', choices=['absent','present','installed','removed','latest']),
list=dict(),
conf_file=dict(default=None),
)
),
required_one_of = [['pkg','list']],
mutually_exclusive = [['pkg','list']]
)
params = module.params
if params['list'] and params['pkg']:
module.fail_json(msg="expected 'list=' or 'name=', but not both")
if params['list']:
if not os.path.exists(repoquery):
module.fail_json(msg="%s is required to use list= with this module. Please install the yum-utils package." % repoquery)
@ -446,12 +436,9 @@ def main():
else:
pkg = params['pkg']
if pkg is None:
module.fail_json(msg="expected 'list=' or 'name='")
else:
state = params['state']
res = ensure(module, state, pkg, params['conf_file'])
module.fail_json(msg="we should never get here unless this all failed", **res)
state = params['state']
res = ensure(module, state, pkg, params['conf_file'])
module.fail_json(msg="we should never get here unless this all failed", **res)
# this is magic, see lib/ansible/module_common.py
#<<INCLUDE_ANSIBLE_MODULE_COMMON>>