Merge pull request #15344 from abadger/ziploader

Ziploader "recursive imports" and caching
This commit is contained in:
Toshio Kuratomi 2016-04-13 10:27:01 -07:00
commit 208ad36ce4
16 changed files with 579 additions and 247 deletions

View file

@ -33,6 +33,7 @@ except Exception:
pass pass
import os import os
import shutil
import sys import sys
import traceback import traceback
@ -40,6 +41,7 @@ import traceback
from multiprocessing import Lock from multiprocessing import Lock
debug_lock = Lock() debug_lock = Lock()
import ansible.constants as C
from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleParserError from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleParserError
from ansible.utils.display import Display from ansible.utils.display import Display
from ansible.utils.unicode import to_unicode from ansible.utils.unicode import to_unicode
@ -87,28 +89,28 @@ if __name__ == '__main__':
cli = mycli(sys.argv) cli = mycli(sys.argv)
cli.parse() cli.parse()
sys.exit(cli.run()) exit_code = cli.run()
except AnsibleOptionsError as e: except AnsibleOptionsError as e:
cli.parser.print_help() cli.parser.print_help()
display.error(to_unicode(e), wrap_text=False) display.error(to_unicode(e), wrap_text=False)
sys.exit(5) exit_code = 5
except AnsibleParserError as e: except AnsibleParserError as e:
display.error(to_unicode(e), wrap_text=False) display.error(to_unicode(e), wrap_text=False)
sys.exit(4) exit_code = 4
# TQM takes care of these, but leaving comment to reserve the exit codes # TQM takes care of these, but leaving comment to reserve the exit codes
# except AnsibleHostUnreachable as e: # except AnsibleHostUnreachable as e:
# display.error(str(e)) # display.error(str(e))
# sys.exit(3) # exit_code = 3
# except AnsibleHostFailed as e: # except AnsibleHostFailed as e:
# display.error(str(e)) # display.error(str(e))
# sys.exit(2) # exit_code = 2
except AnsibleError as e: except AnsibleError as e:
display.error(to_unicode(e), wrap_text=False) display.error(to_unicode(e), wrap_text=False)
sys.exit(1) exit_code = 1
except KeyboardInterrupt: except KeyboardInterrupt:
display.error("User interrupted execution") display.error("User interrupted execution")
sys.exit(99) exit_code = 99
except Exception as e: except Exception as e:
have_cli_options = cli is not None and cli.options is not None have_cli_options = cli is not None and cli.options is not None
display.error("Unexpected Exception: %s" % to_unicode(e), wrap_text=False) display.error("Unexpected Exception: %s" % to_unicode(e), wrap_text=False)
@ -116,4 +118,9 @@ if __name__ == '__main__':
display.display(u"the full traceback was:\n\n%s" % to_unicode(traceback.format_exc())) display.display(u"the full traceback was:\n\n%s" % to_unicode(traceback.format_exc()))
else: else:
display.display("to see the full traceback, use -vvv") display.display("to see the full traceback, use -vvv")
sys.exit(250) exit_code = 250
finally:
# Remove ansible tempdir
shutil.rmtree(C.DEFAULT_LOCAL_TMP, True)
sys.exit(exit_code)

View file

@ -452,6 +452,22 @@ This is the default location Ansible looks to find modules::
Ansible knows how to look in multiple locations if you feed it a colon separated path, and it also will look for modules in the Ansible knows how to look in multiple locations if you feed it a colon separated path, and it also will look for modules in the
"./library" directory alongside a playbook. "./library" directory alongside a playbook.
.. _local_tmp:
local_tmp
=========
When Ansible gets ready to send a module to a remote machine it usually has to
add a few things to the module: Some boilerplate code, the module's
parameters, and a few constants from the config file. This combination of
things gets stored in a temporary file until ansible exits and cleans up after
itself. The default location is a subdirectory of the user's home directory.
If you'd like to change that, you can do so by altering this setting::
local_tmp = $HOME/.ansible/tmp
Ansible will then choose a random directory name inside this location.
.. _log_path: .. _log_path:
log_path log_path

View file

@ -14,6 +14,7 @@
#inventory = /etc/ansible/hosts #inventory = /etc/ansible/hosts
#library = /usr/share/my_modules/ #library = /usr/share/my_modules/
#remote_tmp = $HOME/.ansible/tmp #remote_tmp = $HOME/.ansible/tmp
#local_tmp = $HOME/.ansible/tmp
#forks = 5 #forks = 5
#poll_interval = 15 #poll_interval = 15
#sudo_user = root #sudo_user = root

View file

@ -29,12 +29,14 @@
# test-module -m ../library/file/lineinfile -a "dest=/etc/exports line='/srv/home hostname1(rw,sync)'" --check # test-module -m ../library/file/lineinfile -a "dest=/etc/exports line='/srv/home hostname1(rw,sync)'" --check
# test-module -m ../library/commands/command -a "echo hello" -n -o "test_hello" # test-module -m ../library/commands/command -a "echo hello" -n -o "test_hello"
import sys
import base64 import base64
from multiprocessing import Lock
import optparse
import os import os
import subprocess import subprocess
import sys
import traceback import traceback
import optparse
import ansible.utils.vars as utils_vars import ansible.utils.vars as utils_vars
from ansible.parsing.dataloader import DataLoader from ansible.parsing.dataloader import DataLoader
from ansible.parsing.utils.jsonify import jsonify from ansible.parsing.utils.jsonify import jsonify
@ -133,10 +135,12 @@ def boilerplate_module(modfile, args, interpreter, check, destfile):
modname = os.path.basename(modfile) modname = os.path.basename(modfile)
modname = os.path.splitext(modname)[0] modname = os.path.splitext(modname)[0]
action_write_lock = Lock()
(module_data, module_style, shebang) = module_common.modify_module( (module_data, module_style, shebang) = module_common.modify_module(
modname, modname,
modfile, modfile,
complex_args, complex_args,
action_write_lock,
task_vars=task_vars task_vars=task_vars
) )

View file

@ -20,6 +20,7 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import os import os
import tempfile
from string import ascii_letters, digits from string import ascii_letters, digits
from ansible.compat.six import string_types from ansible.compat.six import string_types
@ -47,7 +48,7 @@ def shell_expand(path):
path = os.path.expanduser(os.path.expandvars(path)) path = os.path.expanduser(os.path.expandvars(path))
return path return path
def get_config(p, section, key, env_var, default, boolean=False, integer=False, floating=False, islist=False, isnone=False, ispath=False, ispathlist=False): def get_config(p, section, key, env_var, default, boolean=False, integer=False, floating=False, islist=False, isnone=False, ispath=False, ispathlist=False, istmppath=False):
''' return a configuration variable with casting ''' ''' return a configuration variable with casting '''
value = _get_config(p, section, key, env_var, default) value = _get_config(p, section, key, env_var, default)
if boolean: if boolean:
@ -65,6 +66,11 @@ def get_config(p, section, key, env_var, default, boolean=False, integer=False,
value = None value = None
elif ispath: elif ispath:
value = shell_expand(value) value = shell_expand(value)
elif istmppath:
value = shell_expand(value)
if not os.path.exists(value):
os.makedirs(value, 0o700)
value = tempfile.mkdtemp(prefix='ansible-local-tmp', dir=value)
elif ispathlist: elif ispathlist:
if isinstance(value, string_types): if isinstance(value, string_types):
value = [shell_expand(x) for x in value.split(os.pathsep)] value = [shell_expand(x) for x in value.split(os.pathsep)]
@ -136,6 +142,7 @@ DEFAULT_HOST_LIST = get_config(p, DEFAULTS,'inventory', 'ANSIBLE_INVENTO
DEFAULT_MODULE_PATH = get_config(p, DEFAULTS, 'library', 'ANSIBLE_LIBRARY', None, ispathlist=True) DEFAULT_MODULE_PATH = get_config(p, DEFAULTS, 'library', 'ANSIBLE_LIBRARY', None, ispathlist=True)
DEFAULT_ROLES_PATH = get_config(p, DEFAULTS, 'roles_path', 'ANSIBLE_ROLES_PATH', '/etc/ansible/roles', ispathlist=True) DEFAULT_ROLES_PATH = get_config(p, DEFAULTS, 'roles_path', 'ANSIBLE_ROLES_PATH', '/etc/ansible/roles', ispathlist=True)
DEFAULT_REMOTE_TMP = get_config(p, DEFAULTS, 'remote_tmp', 'ANSIBLE_REMOTE_TEMP', '$HOME/.ansible/tmp') DEFAULT_REMOTE_TMP = get_config(p, DEFAULTS, 'remote_tmp', 'ANSIBLE_REMOTE_TEMP', '$HOME/.ansible/tmp')
DEFAULT_LOCAL_TMP = get_config(p, DEFAULTS, 'local_tmp', 'ANSIBLE_LOCAL_TEMP', '$HOME/.ansible/tmp', istmppath=True)
DEFAULT_MODULE_NAME = get_config(p, DEFAULTS, 'module_name', None, 'command') DEFAULT_MODULE_NAME = get_config(p, DEFAULTS, 'module_name', None, 'command')
DEFAULT_FORKS = get_config(p, DEFAULTS, 'forks', 'ANSIBLE_FORKS', 5, integer=True) DEFAULT_FORKS = get_config(p, DEFAULTS, 'forks', 'ANSIBLE_FORKS', 5, integer=True)
DEFAULT_MODULE_ARGS = get_config(p, DEFAULTS, 'module_args', 'ANSIBLE_MODULE_ARGS', '') DEFAULT_MODULE_ARGS = get_config(p, DEFAULTS, 'module_args', 'ANSIBLE_MODULE_ARGS', '')

View file

@ -20,6 +20,7 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import ast
import base64 import base64
import json import json
import os import os
@ -32,6 +33,7 @@ from ansible import __version__
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.utils.unicode import to_bytes, to_unicode from ansible.utils.unicode import to_bytes, to_unicode
from ansible.plugins.strategy import action_write_locks
try: try:
from __main__ import display from __main__ import display
@ -48,7 +50,7 @@ REPLACER_SELINUX = b"<<SELINUX_SPECIAL_FILESYSTEMS>>"
# We could end up writing out parameters with unicode characters so we need to # We could end up writing out parameters with unicode characters so we need to
# specify an encoding for the python source file # specify an encoding for the python source file
ENCODING_STRING = b'# -*- coding: utf-8 -*-' ENCODING_STRING = u'# -*- coding: utf-8 -*-'
# we've moved the module_common relative to the snippets, so fix the path # we've moved the module_common relative to the snippets, so fix the path
_SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils') _SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils')
@ -56,7 +58,7 @@ _SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils')
# ****************************************************************************** # ******************************************************************************
ZIPLOADER_TEMPLATE = u'''%(shebang)s ZIPLOADER_TEMPLATE = u'''%(shebang)s
# -*- coding: utf-8 -*-' %(coding)s
# This code is part of Ansible, but is an independent component. # This code is part of Ansible, but is an independent component.
# The code in this particular templatable string, and this templatable string # The code in this particular templatable string, and this templatable string
# only, is BSD licensed. Modules which end up using this snippet, which is # only, is BSD licensed. Modules which end up using this snippet, which is
@ -87,17 +89,49 @@ ZIPLOADER_TEMPLATE = u'''%(shebang)s
import os import os
import sys import sys
import base64 import base64
import shutil
import zipfile
import tempfile import tempfile
import subprocess import subprocess
if sys.version_info < (3,): if sys.version_info < (3,):
bytes = str bytes = str
PY3 = False
else: else:
unicode = str unicode = str
PY3 = True
try:
# Python-2.6+
from io import BytesIO as IOStream
except ImportError:
# Python < 2.6
from StringIO import StringIO as IOStream
ZIPDATA = """%(zipdata)s""" ZIPDATA = """%(zipdata)s"""
def debug(command, zipped_mod): def invoke_module(module, modlib_path, json_params):
pythonpath = os.environ.get('PYTHONPATH')
if pythonpath:
os.environ['PYTHONPATH'] = ':'.join((modlib_path, pythonpath))
else:
os.environ['PYTHONPATH'] = modlib_path
p = subprocess.Popen(['%(interpreter)s', module], env=os.environ, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
(stdout, stderr) = p.communicate(json_params)
if not isinstance(stderr, (bytes, unicode)):
stderr = stderr.read()
if not isinstance(stdout, (bytes, unicode)):
stdout = stdout.read()
if PY3:
sys.stderr.buffer.write(stderr)
sys.stdout.buffer.write(stdout)
else:
sys.stderr.write(stderr)
sys.stdout.write(stdout)
return p.returncode
def debug(command, zipped_mod, json_params):
# The code here normally doesn't run. It's only used for debugging on the # The code here normally doesn't run. It's only used for debugging on the
# remote machine. Run with ANSIBLE_KEEP_REMOTE_FILES=1 envvar and -vvv # remote machine. Run with ANSIBLE_KEEP_REMOTE_FILES=1 envvar and -vvv
# to save the module file remotely. Login to the remote machine and use # to save the module file remotely. Login to the remote machine and use
@ -105,7 +139,7 @@ def debug(command, zipped_mod):
# files. Edit the source files to instrument the code or experiment with # files. Edit the source files to instrument the code or experiment with
# different values. Then use /path/to/module execute to run the extracted # different values. Then use /path/to/module execute to run the extracted
# files you've edited instead of the actual zipped module. # files you've edited instead of the actual zipped module.
#
# Okay to use __file__ here because we're running from a kept file # Okay to use __file__ here because we're running from a kept file
basedir = os.path.dirname(__file__) basedir = os.path.dirname(__file__)
if command == 'explode': if command == 'explode':
@ -113,11 +147,11 @@ def debug(command, zipped_mod):
# print the path to the code. This is an easy way for people to look # print the path to the code. This is an easy way for people to look
# at the code on the remote machine for debugging it in that # at the code on the remote machine for debugging it in that
# environment # environment
import zipfile
z = zipfile.ZipFile(zipped_mod) z = zipfile.ZipFile(zipped_mod)
for filename in z.namelist(): for filename in z.namelist():
if filename.startswith('/'): if filename.startswith('/'):
raise Exception('Something wrong with this module zip file: should not contain absolute paths') raise Exception('Something wrong with this module zip file: should not contain absolute paths')
dest_filename = os.path.join(basedir, filename) dest_filename = os.path.join(basedir, filename)
if dest_filename.endswith(os.path.sep) and not os.path.exists(dest_filename): if dest_filename.endswith(os.path.sep) and not os.path.exists(dest_filename):
os.makedirs(dest_filename) os.makedirs(dest_filename)
@ -128,26 +162,17 @@ def debug(command, zipped_mod):
f = open(dest_filename, 'w') f = open(dest_filename, 'w')
f.write(z.read(filename)) f.write(z.read(filename))
f.close() f.close()
print('Module expanded into:') print('Module expanded into:')
print('%%s' %% os.path.join(basedir, 'ansible')) print('%%s' %% os.path.join(basedir, 'ansible'))
exitcode = 0
elif command == 'execute': elif command == 'execute':
# Execute the exploded code instead of executing the module from the # Execute the exploded code instead of executing the module from the
# embedded ZIPDATA. This allows people to easily run their modified # embedded ZIPDATA. This allows people to easily run their modified
# code on the remote machine to see how changes will affect it. # code on the remote machine to see how changes will affect it.
pythonpath = os.environ.get('PYTHONPATH') exitcode = invoke_module(os.path.join(basedir, 'ansible_module_%(ansible_module)s.py'), basedir, json_params)
if pythonpath:
os.environ['PYTHONPATH'] = ':'.join((basedir, pythonpath))
else:
os.environ['PYTHONPATH'] = basedir
p = subprocess.Popen(['%(interpreter)s', '-m', 'ansible.module_exec.%(ansible_module)s.__main__'], env=os.environ, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
(stdout, stderr) = p.communicate()
if not isinstance(stderr, (bytes, unicode)):
stderr = stderr.read()
if not isinstance(stdout, (bytes, unicode)):
stdout = stdout.read()
sys.stderr.write(stderr)
sys.stdout.write(stdout)
sys.exit(p.returncode)
elif command == 'excommunicate': elif command == 'excommunicate':
# This attempts to run the module in-process (by importing a main # This attempts to run the module in-process (by importing a main
# function and then calling it). It is not the way ansible generally # function and then calling it). It is not the way ansible generally
@ -157,43 +182,78 @@ def debug(command, zipped_mod):
# when using this that are only artifacts of how we're invoking here, # when using this that are only artifacts of how we're invoking here,
# not actual bugs (as they don't affect the real way that we invoke # not actual bugs (as they don't affect the real way that we invoke
# ansible modules) # ansible modules)
sys.stdin = IOStream(json_params)
sys.path.insert(0, basedir) sys.path.insert(0, basedir)
from ansible.module_exec.%(ansible_module)s.__main__ import main from ansible_module_%(ansible_module)s import main
main() main()
print('WARNING: Module returned to wrapper instead of exiting')
os.environ['ANSIBLE_MODULE_ARGS'] = %(args)s sys.exit(1)
os.environ['ANSIBLE_MODULE_CONSTANTS'] = %(constants)s
try:
temp_fd, temp_path = tempfile.mkstemp(prefix='ansible_')
os.write(temp_fd, base64.b64decode(ZIPDATA))
if len(sys.argv) == 2:
debug(sys.argv[1], temp_path)
else: else:
pythonpath = os.environ.get('PYTHONPATH') print('WARNING: Unknown debug command. Doing nothing.')
if pythonpath: exitcode = 0
os.environ['PYTHONPATH'] = ':'.join((temp_path, pythonpath))
else:
os.environ['PYTHONPATH'] = temp_path
p = subprocess.Popen(['%(interpreter)s', '-m', 'ansible.module_exec.%(ansible_module)s.__main__'], env=os.environ, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
(stdout, stderr) = p.communicate()
if not isinstance(stderr, (bytes, unicode)):
stderr = stderr.read()
if not isinstance(stdout, (bytes, unicode)):
stdout = stdout.read()
sys.stderr.write(stderr)
sys.stdout.write(stdout)
sys.exit(p.returncode)
finally: return exitcode
if __name__ == '__main__':
ZIPLOADER_PARAMS = %(params)s
if PY3:
ZIPLOADER_PARAMS = ZIPLOADER_PARAMS.encode('utf-8')
try: try:
os.close(temp_fd) temp_path = tempfile.mkdtemp(prefix='ansible_')
os.remove(temp_path) zipped_mod = os.path.join(temp_path, 'ansible_modlib.zip')
except NameError: modlib = open(zipped_mod, 'wb')
# mkstemp failed modlib.write(base64.b64decode(ZIPDATA))
pass modlib.close()
if len(sys.argv) == 2:
exitcode = debug(sys.argv[1], zipped_mod, ZIPLOADER_PARAMS)
else:
z = zipfile.ZipFile(zipped_mod)
module = os.path.join(temp_path, 'ansible_module_%(ansible_module)s.py')
f = open(module, 'wb')
f.write(z.read('ansible_module_%(ansible_module)s.py'))
f.close()
exitcode = invoke_module(module, zipped_mod, ZIPLOADER_PARAMS)
finally:
try:
shutil.rmtree(temp_path)
except OSError:
# tempdir creation probably failed
pass
sys.exit(exitcode)
''' '''
class ModuleDepFinder(ast.NodeVisitor):
# Caveats:
# This code currently does not handle:
# * relative imports from py2.6+ from . import urls
# * python packages (directories with __init__.py in them)
IMPORT_PREFIX_SIZE = len('ansible.module_utils.')
def __init__(self, *args, **kwargs):
super(ModuleDepFinder, self).__init__(*args, **kwargs)
self.module_files = set()
def visit_Import(self, node):
# import ansible.module_utils.MODLIB[.other]
for alias in (a for a in node.names if a.name.startswith('ansible.module_utils.')):
py_mod = alias.name[self.IMPORT_PREFIX_SIZE:].split('.', 1)[0]
self.module_files.add(py_mod)
self.generic_visit(node)
def visit_ImportFrom(self, node):
if node.module.startswith('ansible.module_utils'):
where_from = node.module[self.IMPORT_PREFIX_SIZE:]
# from ansible.module_utils.MODLIB[.other] import foo
if where_from:
py_mod = where_from.split('.', 1)[0]
self.module_files.add(py_mod)
else:
# from ansible.module_utils import MODLIB
for alias in node.names:
self.module_files.add(alias.name)
self.generic_visit(node)
def _strip_comments(source): def _strip_comments(source):
# Strip comments and blank lines from the wrapper # Strip comments and blank lines from the wrapper
buf = [] buf = []
@ -242,6 +302,28 @@ def _get_facility(task_vars):
facility = task_vars['ansible_syslog_facility'] facility = task_vars['ansible_syslog_facility']
return facility return facility
def recursive_finder(data, snippet_names, snippet_data, zf):
"""
Using ModuleDepFinder, make sure we have all of the module_utils files that
the module its module_utils files needs.
"""
tree = ast.parse(data)
finder = ModuleDepFinder()
finder.visit(tree)
new_snippets = set()
for snippet_name in finder.module_files.difference(snippet_names):
fname = '%s.py' % snippet_name
new_snippets.add(snippet_name)
if snippet_name not in snippet_data:
snippet_data[snippet_name] = _slurp(os.path.join(_SNIPPET_PATH, fname))
zf.writestr(os.path.join("ansible/module_utils", fname), snippet_data[snippet_name])
snippet_names.update(new_snippets)
for snippet_name in tuple(new_snippets):
recursive_finder(snippet_data[snippet_name], snippet_names, snippet_data, zf)
del snippet_data[snippet_name]
def _find_snippet_imports(module_name, module_data, module_path, module_args, task_vars, module_compression): def _find_snippet_imports(module_name, module_data, module_path, module_args, task_vars, module_compression):
""" """
Given the source of the module, convert it to a Jinja2 template to insert Given the source of the module, convert it to a Jinja2 template to insert
@ -280,59 +362,87 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
if module_style in ('old', 'non_native_want_json'): if module_style in ('old', 'non_native_want_json'):
return module_data, module_style, shebang return module_data, module_style, shebang
module_args_json = to_bytes(json.dumps(module_args))
output = BytesIO() output = BytesIO()
lines = module_data.split(b'\n')
snippet_names = set() snippet_names = set()
if module_substyle == 'python': if module_substyle == 'python':
# ziploader for new-style python classes # ziploader for new-style python classes
python_repred_args = to_bytes(repr(module_args_json))
constants = dict( constants = dict(
SELINUX_SPECIAL_FS=C.DEFAULT_SELINUX_SPECIAL_FS, SELINUX_SPECIAL_FS=C.DEFAULT_SELINUX_SPECIAL_FS,
SYSLOG_FACILITY=_get_facility(task_vars), SYSLOG_FACILITY=_get_facility(task_vars),
) )
python_repred_constants = to_bytes(repr(json.dumps(constants)), errors='strict') params = dict(ANSIBLE_MODULE_ARGS=module_args,
ANSIBLE_MODULE_CONSTANTS=constants,
)
#python_repred_args = to_bytes(repr(module_args_json))
#python_repred_constants = to_bytes(repr(json.dumps(constants)), errors='strict')
python_repred_params = to_bytes(repr(json.dumps(params)), errors='strict')
try: try:
compression_method = getattr(zipfile, module_compression) compression_method = getattr(zipfile, module_compression)
except AttributeError: except AttributeError:
display.warning(u'Bad module compression string specified: %s. Using ZIP_STORED (no compression)' % module_compression) display.warning(u'Bad module compression string specified: %s. Using ZIP_STORED (no compression)' % module_compression)
compression_method = zipfile.ZIP_STORED compression_method = zipfile.ZIP_STORED
zipoutput = BytesIO()
zf = zipfile.ZipFile(zipoutput, mode='w', compression=compression_method)
zf.writestr('ansible/__init__.py', b''.join((b"__version__ = '", to_bytes(__version__), b"'\n")))
zf.writestr('ansible/module_utils/__init__.py', b'')
zf.writestr('ansible/module_exec/__init__.py', b'')
zf.writestr('ansible/module_exec/%s/__init__.py' % module_name, b"") lookup_path = os.path.join(C.DEFAULT_LOCAL_TMP, 'ziploader_cache')
final_data = [] if not os.path.exists(lookup_path):
os.mkdir(lookup_path)
cached_module_filename = os.path.join(lookup_path, "%s-%s" % (module_name, module_compression))
for line in lines: zipdata = None
if line.startswith(b'from ansible.module_utils.'): # Optimization -- don't lock if the module has already been cached
tokens=line.split(b".") if os.path.exists(cached_module_filename):
snippet_name = tokens[2].split()[0] zipdata = open(cached_module_filename, 'rb').read()
snippet_names.add(snippet_name) # Fool the check later... I think we should just remove the check
fname = to_unicode(snippet_name + b".py") snippet_names.add('basic')
zf.writestr(os.path.join("ansible/module_utils", fname), _slurp(os.path.join(_SNIPPET_PATH, fname))) else:
final_data.append(line) with action_write_locks[module_name]:
else: # Check that no other process has created this while we were
final_data.append(line) # waiting for the lock
if not os.path.exists(cached_module_filename):
# Create the module zip data
zipoutput = BytesIO()
zf = zipfile.ZipFile(zipoutput, mode='w', compression=compression_method)
zf.writestr('ansible/__init__.py', b''.join((b"__version__ = '", to_bytes(__version__), b"'\n")))
zf.writestr('ansible/module_utils/__init__.py', b'')
zf.writestr('ansible/module_exec/%s/__main__.py' % module_name, b"\n".join(final_data)) zf.writestr('ansible_module_%s.py' % module_name, module_data)
zf.close()
snippet_data = dict()
recursive_finder(module_data, snippet_names, snippet_data, zf)
zf.close()
zipdata = base64.b64encode(zipoutput.getvalue())
# Write the assembled module to a temp file (write to temp
# so that no one looking for the file reads a partially
# written file)
with open(cached_module_filename + '-part', 'w') as f:
f.write(zipdata)
# Rename the file into its final position in the cache so
# future users of this module can read it off the
# filesystem instead of constructing from scratch.
os.rename(cached_module_filename + '-part', cached_module_filename)
if zipdata is None:
# Another process wrote the file while we were waiting for
# the write lock. Go ahead and read the data from disk
# instead of re-creating it.
zipdata = open(cached_module_filename, 'rb').read()
# Fool the check later... I think we should just remove the check
snippet_names.add('basic')
shebang, interpreter = _get_shebang(u'/usr/bin/python', task_vars) shebang, interpreter = _get_shebang(u'/usr/bin/python', task_vars)
if shebang is None: if shebang is None:
shebang = u'#!/usr/bin/python' shebang = u'#!/usr/bin/python'
output.write(to_bytes(STRIPPED_ZIPLOADER_TEMPLATE % dict( output.write(to_bytes(STRIPPED_ZIPLOADER_TEMPLATE % dict(
zipdata=base64.b64encode(zipoutput.getvalue()), zipdata=zipdata,
ansible_module=module_name, ansible_module=module_name,
args=python_repred_args, #args=python_repred_args,
constants=python_repred_constants, #constants=python_repred_constants,
params=python_repred_params,
shebang=shebang, shebang=shebang,
interpreter=interpreter, interpreter=interpreter,
coding=ENCODING_STRING,
))) )))
module_data = output.getvalue() module_data = output.getvalue()
@ -340,11 +450,12 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
# modules that use ziploader may implement their own helpers and not # modules that use ziploader may implement their own helpers and not
# need basic.py. All the constants that we substituted into basic.py # need basic.py. All the constants that we substituted into basic.py
# for module_replacer are now available in other, better ways. # for module_replacer are now available in other, better ways.
if b'basic' not in snippet_names: if 'basic' not in snippet_names:
raise AnsibleError("missing required import in %s: Did not import ansible.module_utils.basic for boilerplate helper code" % module_path) raise AnsibleError("missing required import in %s: Did not import ansible.module_utils.basic for boilerplate helper code" % module_path)
elif module_substyle == 'powershell': elif module_substyle == 'powershell':
# Module replacer for jsonargs and windows # Module replacer for jsonargs and windows
lines = module_data.split(b'\n')
for line in lines: for line in lines:
if REPLACER_WINDOWS in line: if REPLACER_WINDOWS in line:
ps_data = _slurp(os.path.join(_SNIPPET_PATH, "powershell.ps1")) ps_data = _slurp(os.path.join(_SNIPPET_PATH, "powershell.ps1"))
@ -353,6 +464,8 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
continue continue
output.write(line + b'\n') output.write(line + b'\n')
module_data = output.getvalue() module_data = output.getvalue()
module_args_json = to_bytes(json.dumps(module_args))
module_data = module_data.replace(REPLACER_JSONARGS, module_args_json) module_data = module_data.replace(REPLACER_JSONARGS, module_args_json)
# Sanity check from 1.x days. This is currently useless as we only # Sanity check from 1.x days. This is currently useless as we only
@ -363,11 +476,14 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path) raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path)
elif module_substyle == 'jsonargs': elif module_substyle == 'jsonargs':
module_args_json = to_bytes(json.dumps(module_args))
# these strings could be included in a third-party module but # these strings could be included in a third-party module but
# officially they were included in the 'basic' snippet for new-style # officially they were included in the 'basic' snippet for new-style
# python modules (which has been replaced with something else in # python modules (which has been replaced with something else in
# ziploader) If we remove them from jsonargs-style module replacer # ziploader) If we remove them from jsonargs-style module replacer
# then we can remove them everywhere. # then we can remove them everywhere.
python_repred_args = to_bytes(repr(module_args_json))
module_data = module_data.replace(REPLACER_VERSION, to_bytes(repr(__version__))) module_data = module_data.replace(REPLACER_VERSION, to_bytes(repr(__version__)))
module_data = module_data.replace(REPLACER_COMPLEX, python_repred_args) module_data = module_data.replace(REPLACER_COMPLEX, python_repred_args)
module_data = module_data.replace(REPLACER_SELINUX, to_bytes(','.join(C.DEFAULT_SELINUX_SPECIAL_FS))) module_data = module_data.replace(REPLACER_SELINUX, to_bytes(','.join(C.DEFAULT_SELINUX_SPECIAL_FS)))
@ -409,17 +525,6 @@ def modify_module(module_name, module_path, module_args, task_vars=dict(), modul
which results in the inclusion of the common code from powershell.ps1 which results in the inclusion of the common code from powershell.ps1
""" """
### TODO: Optimization ideas if this code is actually a source of slowness:
# * Fix comment stripping: Currently doesn't preserve shebangs and encoding info (but we unconditionally add encoding info)
# * Use pyminifier if installed
# * comment stripping/pyminifier needs to have config setting to turn it
# off for debugging purposes (goes along with keep remote but should be
# separate otherwise users wouldn't be able to get info on what the
# minifier output)
# * Only split into lines and recombine into strings once
# * Cache the modified module? If only the args are different and we do
# that as the last step we could cache all the work up to that point.
with open(module_path, 'rb') as f: with open(module_path, 'rb') as f:
# read in the module source # read in the module source
@ -440,7 +545,7 @@ def modify_module(module_name, module_path, module_args, task_vars=dict(), modul
lines[0] = shebang = new_shebang lines[0] = shebang = new_shebang
if os.path.basename(interpreter).startswith(b'python'): if os.path.basename(interpreter).startswith(b'python'):
lines.insert(1, ENCODING_STRING) lines.insert(1, to_bytes(ENCODING_STRING))
else: else:
# No shebang, assume a binary module? # No shebang, assume a binary module?
pass pass

View file

@ -95,12 +95,9 @@ class WorkerProcess(multiprocessing.Process):
def run(self): def run(self):
''' '''
Called when the process is started, and loops indefinitely Called when the process is started. Pushes the result onto the
until an error is encountered (typically an IOerror from the results queue. We also remove the host from the blocked hosts list, to
queue pipe being disconnected). During the loop, we attempt signify that they are ready for their next task.
to pull tasks off the job queue and run them, pushing the result
onto the results queue. We also remove the host from the blocked
hosts list, to signify that they are ready for their next task.
''' '''
if HAS_ATFORK: if HAS_ATFORK:

View file

@ -223,23 +223,6 @@ from ansible import __version__
# Backwards compat. New code should just import and use __version__ # Backwards compat. New code should just import and use __version__
ANSIBLE_VERSION = __version__ ANSIBLE_VERSION = __version__
try:
# MODULE_COMPLEX_ARGS is an old name kept for backwards compat
MODULE_COMPLEX_ARGS = os.environ.pop('ANSIBLE_MODULE_ARGS')
except KeyError:
# This file might be used for its utility functions. So don't fail if
# running outside of a module environment (will fail in _load_params()
# instead)
MODULE_COMPLEX_ARGS = None
try:
# ARGS are for parameters given in the playbook. Constants are for things
# that ansible needs to configure controller side but are passed to all
# modules.
MODULE_CONSTANTS = os.environ.pop('ANSIBLE_MODULE_CONSTANTS')
except KeyError:
MODULE_CONSTANTS = None
FILE_COMMON_ARGUMENTS=dict( FILE_COMMON_ARGUMENTS=dict(
src = dict(), src = dict(),
mode = dict(type='raw'), mode = dict(type='raw'),
@ -560,7 +543,6 @@ class AnsibleModule(object):
if k not in self.argument_spec: if k not in self.argument_spec:
self.argument_spec[k] = v self.argument_spec[k] = v
self._load_constants()
self._load_params() self._load_params()
self._set_fallbacks() self._set_fallbacks()
@ -1452,32 +1434,29 @@ class AnsibleModule(object):
continue continue
def _load_params(self): def _load_params(self):
''' read the input and set the params attribute''' ''' read the input and set the params attribute. Sets the constants as well.'''
if MODULE_COMPLEX_ARGS is None: # Avoid tracebacks when locale is non-utf8
# This helper used too early for fail_json to work. if sys.version_info < (3,):
print('{"msg": "Error: ANSIBLE_MODULE_ARGS not found in environment. Unable to figure out what parameters were passed", "failed": true}') buffer = sys.stdin.read()
sys.exit(1)
params = json_dict_unicode_to_bytes(json.loads(MODULE_COMPLEX_ARGS))
if params is None:
params = dict()
self.params = params
def _load_constants(self):
''' read the input and set the constants attribute'''
if MODULE_CONSTANTS is None:
# This helper used too early for fail_json to work.
print('{"msg": "Error: ANSIBLE_MODULE_CONSTANTS not found in environment. Unable to figure out what constants were passed", "failed": true}')
sys.exit(1)
# Make constants into "native string"
if sys.version_info >= (3,):
constants = json_dict_bytes_to_unicode(json.loads(MODULE_CONSTANTS))
else: else:
constants = json_dict_unicode_to_bytes(json.loads(MODULE_CONSTANTS)) buffer = sys.stdin.buffer.read()
if constants is None: try:
constants = dict() params = json.loads(buffer.decode('utf-8'))
self.constants = constants except ValueError:
# This helper used too early for fail_json to work.
print('{"msg": "Error: Module unable to decode valid JSON on stdin. Unable to figure out what parameters were passed", "failed": true}')
sys.exit(1)
if sys.version_info < (3,):
params = json_dict_unicode_to_bytes(params)
try:
self.params = params['ANSIBLE_MODULE_ARGS']
self.constants = params['ANSIBLE_MODULE_CONSTANTS']
except KeyError:
# This helper used too early for fail_json to work.
print('{"msg": "Error: Module unable to locate ANSIBLE_MODULE_ARGS and ANSIBLE_MODULE_CONSTANTS in json data from stdin. Unable to figure out what parameters were passed", "failed": true}')
sys.exit(1)
def _log_to_syslog(self, msg): def _log_to_syslog(self, msg):
if HAS_SYSLOG: if HAS_SYSLOG:

View file

@ -19,15 +19,16 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from ansible.compat.six.moves import queue as Queue
from ansible.compat.six import iteritems, text_type, string_types
import json import json
import time import time
import zlib import zlib
from collections import defaultdict
from multiprocessing import Lock
from jinja2.exceptions import UndefinedError from jinja2.exceptions import UndefinedError
from ansible.compat.six.moves import queue as Queue
from ansible.compat.six import iteritems, text_type, string_types
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable
from ansible.executor.play_iterator import PlayIterator from ansible.executor.play_iterator import PlayIterator
@ -51,6 +52,8 @@ except ImportError:
__all__ = ['StrategyBase'] __all__ = ['StrategyBase']
action_write_locks = defaultdict(Lock)
# TODO: this should probably be in the plugins/__init__.py, with # TODO: this should probably be in the plugins/__init__.py, with
# a smarter mechanism to set all of the attributes based on # a smarter mechanism to set all of the attributes based on
@ -141,6 +144,20 @@ class StrategyBase:
display.debug("entering _queue_task() for %s/%s" % (host, task)) display.debug("entering _queue_task() for %s/%s" % (host, task))
# Add a write lock for tasks.
# Maybe this should be added somewhere further up the call stack but
# this is the earliest in the code where we have task (1) extracted
# into its own variable and (2) there's only a single code path
# leading to the module being run. This is called by three
# functions: __init__.py::_do_handler_run(), linear.py::run(), and
# free.py::run() so we'd have to add to all three to do it there.
# The next common higher level is __init__.py::run() and that has
# tasks inside of play_iterator so we'd have to extract them to do it
# there.
if not action_write_locks[task.action]:
display.warning('Python defaultdict did not create the Lock for us. Creating manually')
action_write_locks[task.action] = Lock()
# and then queue the new task # and then queue the new task
display.debug("%s - putting task (%s) in queue" % (host, task)) display.debug("%s - putting task (%s) in queue" % (host, task))
try: try:

View file

@ -22,18 +22,38 @@ __metaclass__ = type
import sys import sys
import json import json
from io import BytesIO, StringIO
from ansible.compat.six import PY3
from ansible.utils.unicode import to_bytes
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.compat.tests.mock import MagicMock from ansible.compat.tests.mock import MagicMock
class TestModuleUtilsBasic(unittest.TestCase): class TestModuleUtilsBasic(unittest.TestCase):
def setUp(self):
self.real_stdin = sys.stdin
args = json.dumps(
dict(
ANSIBLE_MODULE_ARGS=dict(
foo=False, bar=[1,2,3], bam="bam", baz=u'baz'),
ANSIBLE_MODULE_CONSTANTS=dict()
)
)
if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
def tearDown(self):
sys.stdin = self.real_stdin
@unittest.skipIf(sys.version_info[0] >= 3, "Python 3 is not supported on targets (yet)") @unittest.skipIf(sys.version_info[0] >= 3, "Python 3 is not supported on targets (yet)")
def test_module_utils_basic__log_invocation(self): def test_module_utils_basic__log_invocation(self):
from ansible.module_utils import basic from ansible.module_utils import basic
# test basic log invocation # test basic log invocation
basic.MODULE_COMPLEX_ARGS = json.dumps(dict(foo=False, bar=[1,2,3], bam="bam", baz=u'baz'))
basic.MODULE_CONSTANTS = '{}'
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec=dict( argument_spec=dict(
foo = dict(default=True, type='bool'), foo = dict(default=True, type='bool'),

View file

@ -23,8 +23,10 @@ __metaclass__ = type
import copy import copy
import json import json
import sys import sys
from io import BytesIO from io import BytesIO, StringIO
from ansible.compat.six import PY3
from ansible.utils.unicode import to_bytes
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.module_utils import basic from ansible.module_utils import basic
@ -37,9 +39,13 @@ empty_invocation = {u'module_args': {}}
class TestAnsibleModuleExitJson(unittest.TestCase): class TestAnsibleModuleExitJson(unittest.TestCase):
def setUp(self): def setUp(self):
self.COMPLEX_ARGS = basic.MODULE_COMPLEX_ARGS self.old_stdin = sys.stdin
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
self.old_stdout = sys.stdout self.old_stdout = sys.stdout
self.fake_stream = BytesIO() self.fake_stream = BytesIO()
@ -48,8 +54,8 @@ class TestAnsibleModuleExitJson(unittest.TestCase):
self.module = basic.AnsibleModule(argument_spec=dict()) self.module = basic.AnsibleModule(argument_spec=dict())
def tearDown(self): def tearDown(self):
basic.MODULE_COMPLEX_ARGS = self.COMPLEX_ARGS
sys.stdout = self.old_stdout sys.stdout = self.old_stdout
sys.stdin = self.old_stdin
def test_exit_json_no_args_exits(self): def test_exit_json_no_args_exits(self):
with self.assertRaises(SystemExit) as ctx: with self.assertRaises(SystemExit) as ctx:
@ -118,19 +124,31 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase):
) )
def setUp(self): def setUp(self):
self.COMPLEX_ARGS = basic.MODULE_COMPLEX_ARGS self.old_stdin = sys.stdin
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
self.old_stdout = sys.stdout self.old_stdout = sys.stdout
def tearDown(self): def tearDown(self):
basic.MODULE_COMPLEX_ARGS = self.COMPLEX_ARGS sys.stdin = self.old_stdin
sys.stdout = self.old_stdout sys.stdout = self.old_stdout
def test_exit_json_removes_values(self): def test_exit_json_removes_values(self):
self.maxDiff = None self.maxDiff = None
for args, return_val, expected in self.dataset: for args, return_val, expected in self.dataset:
sys.stdout = BytesIO() sys.stdout = BytesIO()
basic.MODULE_COMPLEX_ARGS = json.dumps(args) params = dict(ANSIBLE_MODULE_ARGS=args, ANSIBLE_MODULE_CONSTANTS={})
basic.MODULE_CONSTANTS = '{}' params = json.dumps(params)
if PY3:
sys.stdin = StringIO(params)
sys.stdin.buffer = BytesIO(to_bytes(params))
else:
sys.stdin = BytesIO(to_bytes(params))
module = basic.AnsibleModule( module = basic.AnsibleModule(
argument_spec = dict( argument_spec = dict(
username=dict(), username=dict(),
@ -149,8 +167,13 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase):
del expected['changed'] del expected['changed']
expected['failed'] = True expected['failed'] = True
sys.stdout = BytesIO() sys.stdout = BytesIO()
basic.MODULE_COMPLEX_ARGS = json.dumps(args) params = dict(ANSIBLE_MODULE_ARGS=args, ANSIBLE_MODULE_CONSTANTS={})
basic.MODULE_CONSTANTS = '{}' params = json.dumps(params)
if PY3:
sys.stdin = StringIO(params)
sys.stdin.buffer = BytesIO(to_bytes(params))
else:
sys.stdin = BytesIO(to_bytes(params))
module = basic.AnsibleModule( module = basic.AnsibleModule(
argument_spec = dict( argument_spec = dict(
username=dict(), username=dict(),

View file

@ -21,7 +21,12 @@ from __future__ import (absolute_import, division)
__metaclass__ = type __metaclass__ = type
import sys import sys
import json
import syslog import syslog
from io import BytesIO, StringIO
from ansible.compat.six import PY3
from ansible.utils.unicode import to_bytes
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock from ansible.compat.tests.mock import patch, MagicMock
@ -41,10 +46,14 @@ except ImportError:
class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase): class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.complex_args_token = basic.MODULE_COMPLEX_ARGS args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
self.constants_sentinel = basic.MODULE_CONSTANTS self.real_stdin = sys.stdin
basic.MODULE_COMPLEX_ARGS = '{}' if PY3:
basic.MODULE_CONSTANTS = '{}' sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
self.am = basic.AnsibleModule( self.am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -55,8 +64,7 @@ class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase):
basic.has_journal = False basic.has_journal = False
def tearDown(self): def tearDown(self):
basic.MODULE_COMPLEX_ARGS = self.complex_args_token sys.stdin = self.real_stdin
basic.MODULE_CONSTANTS = self.constants_sentinel
basic.has_journal = self.has_journal basic.has_journal = self.has_journal
def test_smoketest_syslog(self): def test_smoketest_syslog(self):
@ -75,17 +83,21 @@ class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase):
class TestAnsibleModuleJournaldSmokeTest(unittest.TestCase): class TestAnsibleModuleJournaldSmokeTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.complex_args_token = basic.MODULE_COMPLEX_ARGS args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
self.constants_sentinel = basic.MODULE_CONSTANTS self.real_stdin = sys.stdin
basic.MODULE_COMPLEX_ARGS = '{}' if PY3:
basic.MODULE_CONSTANTS = '{}' sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
self.am = basic.AnsibleModule( self.am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
def tearDown(self): def tearDown(self):
basic.MODULE_COMPLEX_ARGS = self.complex_args_token sys.stdin = self.real_stdin
basic.MODULE_CONSTANTS = self.constants_sentinel
@unittest.skipUnless(basic.has_journal, 'python systemd bindings not installed') @unittest.skipUnless(basic.has_journal, 'python systemd bindings not installed')
def test_smoketest_journal(self): def test_smoketest_journal(self):
@ -121,10 +133,15 @@ class TestAnsibleModuleLogSyslog(unittest.TestCase):
} }
def setUp(self): def setUp(self):
self.complex_args_token = basic.MODULE_COMPLEX_ARGS args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
self.constants_sentinel = basic.MODULE_CONSTANTS self.real_stdin = sys.stdin
basic.MODULE_COMPLEX_ARGS = '{}' if PY3:
basic.MODULE_CONSTANTS = '{}' sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
self.am = basic.AnsibleModule( self.am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -134,8 +151,7 @@ class TestAnsibleModuleLogSyslog(unittest.TestCase):
basic.has_journal = False basic.has_journal = False
def tearDown(self): def tearDown(self):
basic.MODULE_COMPLEX_ARGS = self.complex_args_token sys.stdin = self.real_stdin
basic.MODULE_CONSTANTS = self.constants_sentinel
basic.has_journal = self.has_journal basic.has_journal = self.has_journal
@patch('syslog.syslog', autospec=True) @patch('syslog.syslog', autospec=True)
@ -176,10 +192,14 @@ class TestAnsibleModuleLogJournal(unittest.TestCase):
} }
def setUp(self): def setUp(self):
self.complex_args_token = basic.MODULE_COMPLEX_ARGS args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
self.constants_sentinel = basic.MODULE_CONSTANTS self.real_stdin = sys.stdin
basic.MODULE_COMPLEX_ARGS = '{}' if PY3:
basic.MODULE_CONSTANTS = '{}' sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
self.am = basic.AnsibleModule( self.am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -198,8 +218,8 @@ class TestAnsibleModuleLogJournal(unittest.TestCase):
self._fake_out_reload(basic) self._fake_out_reload(basic)
def tearDown(self): def tearDown(self):
basic.MODULE_COMPLEX_ARGS = self.complex_args_token sys.stdin = self.real_stdin
basic.MODULE_CONSTANTS = self.constants_sentinel
basic.has_journal = self.has_journal basic.has_journal = self.has_journal
if self.module_patcher: if self.module_patcher:
self.module_patcher.stop() self.module_patcher.stop()

View file

@ -20,9 +20,13 @@ from __future__ import (absolute_import, division)
__metaclass__ = type __metaclass__ = type
import errno import errno
import json
import sys import sys
import time import time
from io import BytesIO from io import BytesIO, StringIO
from ansible.compat.six import PY3
from ansible.utils.unicode import to_bytes
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.compat.tests.mock import call, MagicMock, Mock, patch, sentinel from ansible.compat.tests.mock import call, MagicMock, Mock, patch, sentinel
@ -61,8 +65,12 @@ class TestAnsibleModuleRunCommand(unittest.TestCase):
if path == '/inaccessible': if path == '/inaccessible':
raise OSError(errno.EPERM, "Permission denied: '/inaccessible'") raise OSError(errno.EPERM, "Permission denied: '/inaccessible'")
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
self.module = AnsibleModule(argument_spec=dict()) self.module = AnsibleModule(argument_spec=dict())
self.module.fail_json = MagicMock(side_effect=SystemExit) self.module.fail_json = MagicMock(side_effect=SystemExit)

View file

@ -20,16 +20,32 @@
from __future__ import (absolute_import, division) from __future__ import (absolute_import, division)
__metaclass__ = type __metaclass__ = type
from ansible.compat.tests import unittest import sys
import json
from io import BytesIO, StringIO
from ansible.compat.tests import unittest
from ansible.compat.six import PY3
from ansible.utils.unicode import to_bytes
class TestAnsibleModuleExitJson(unittest.TestCase): class TestAnsibleModuleExitJson(unittest.TestCase):
def setUp(self):
self.real_stdin = sys.stdin
def tearDown(self):
sys.stdin = self.real_stdin
def test_module_utils_basic_safe_eval(self): def test_module_utils_basic_safe_eval(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec=dict(), argument_spec=dict(),
) )

View file

@ -21,14 +21,19 @@ from __future__ import (absolute_import, division)
__metaclass__ = type __metaclass__ = type
import errno import errno
import json
import os import os
import sys import sys
from io import BytesIO, StringIO
try: try:
import builtins import builtins
except ImportError: except ImportError:
import __builtin__ as builtins import __builtin__ as builtins
from ansible.compat.six import PY3
from ansible.utils.unicode import to_bytes
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock, mock_open, Mock, call from ansible.compat.tests.mock import patch, MagicMock, mock_open, Mock, call
@ -37,10 +42,10 @@ realimport = builtins.__import__
class TestModuleUtilsBasic(unittest.TestCase): class TestModuleUtilsBasic(unittest.TestCase):
def setUp(self): def setUp(self):
pass self.real_stdin = sys.stdin
def tearDown(self): def tearDown(self):
pass sys.stdin = self.real_stdin
def clear_modules(self, mods): def clear_modules(self, mods):
for mod in mods: for mod in mods:
@ -266,8 +271,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_creation(self): def test_module_utils_basic_ansible_module_creation(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec=dict(), argument_spec=dict(),
) )
@ -282,8 +292,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
req_to = (('bam', 'baz'),) req_to = (('bam', 'baz'),)
# should test ok # should test ok
basic.MODULE_COMPLEX_ARGS = '{"foo":"hello"}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo": "hello"}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = arg_spec, argument_spec = arg_spec,
mutually_exclusive = mut_ex, mutually_exclusive = mut_ex,
@ -297,8 +312,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
# FIXME: add asserts here to verify the basic config # FIXME: add asserts here to verify the basic config
# fail, because a required param was not specified # fail, because a required param was not specified
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
self.assertRaises( self.assertRaises(
SystemExit, SystemExit,
basic.AnsibleModule, basic.AnsibleModule,
@ -312,8 +332,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
) )
# fail because of mutually exclusive parameters # fail because of mutually exclusive parameters
basic.MODULE_COMPLEX_ARGS = '{"foo":"hello", "bar": "bad", "bam": "bad"}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo":"hello", "bar": "bad", "bam": "bad"}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
self.assertRaises( self.assertRaises(
SystemExit, SystemExit,
basic.AnsibleModule, basic.AnsibleModule,
@ -327,8 +352,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
) )
# fail because a param required due to another param was not specified # fail because a param required due to another param was not specified
basic.MODULE_COMPLEX_ARGS = '{"bam":"bad"}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"bam": "bad"}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
self.assertRaises( self.assertRaises(
SystemExit, SystemExit,
basic.AnsibleModule, basic.AnsibleModule,
@ -344,8 +374,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_load_file_common_arguments(self): def test_module_utils_basic_ansible_module_load_file_common_arguments(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -394,8 +429,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_mls_enabled(self): def test_module_utils_basic_ansible_module_selinux_mls_enabled(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -415,8 +455,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_initial_context(self): def test_module_utils_basic_ansible_module_selinux_initial_context(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -430,8 +475,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_enabled(self): def test_module_utils_basic_ansible_module_selinux_enabled(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -463,8 +513,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_default_context(self): def test_module_utils_basic_ansible_module_selinux_default_context(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -500,8 +555,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_context(self): def test_module_utils_basic_ansible_module_selinux_context(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -543,8 +603,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_is_special_selinux_path(self): def test_module_utils_basic_ansible_module_is_special_selinux_path(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={"SELINUX_SPECIAL_FS": "nfs,nfsd,foos"}))
basic.MODULE_CONSTANTS = '{"SELINUX_SPECIAL_FS": "nfs,nfsd,foos"}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -585,8 +650,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_to_filesystem_str(self): def test_module_utils_basic_ansible_module_to_filesystem_str(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -597,8 +667,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_user_and_group(self): def test_module_utils_basic_ansible_module_user_and_group(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -613,8 +688,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_find_mount_point(self): def test_module_utils_basic_ansible_module_find_mount_point(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -638,8 +718,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_context_if_different(self): def test_module_utils_basic_ansible_module_set_context_if_different(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -684,8 +769,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_owner_if_different(self): def test_module_utils_basic_ansible_module_set_owner_if_different(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -724,7 +814,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_group_if_different(self): def test_module_utils_basic_ansible_module_set_group_if_different(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -763,8 +859,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_mode_if_different(self): def test_module_utils_basic_ansible_module_set_mode_if_different(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -852,8 +953,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )
@ -1031,8 +1137,13 @@ class TestModuleUtilsBasic(unittest.TestCase):
from ansible.module_utils import basic from ansible.module_utils import basic
basic.MODULE_COMPLEX_ARGS = '{}' args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
basic.MODULE_CONSTANTS = '{}' if PY3:
sys.stdin = StringIO(args)
sys.stdin.buffer = BytesIO(to_bytes(args))
else:
sys.stdin = BytesIO(to_bytes(args))
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )

View file

@ -212,14 +212,15 @@ class TestActionBase(unittest.TestCase):
# test python module formatting # test python module formatting
with patch.object(builtins, 'open', mock_open(read_data=to_bytes(python_module_replacers.strip(), encoding='utf-8'))) as m: with patch.object(builtins, 'open', mock_open(read_data=to_bytes(python_module_replacers.strip(), encoding='utf-8'))) as m:
mock_task.args = dict(a=1, foo='fö〩') with patch.object(os, 'rename') as m:
mock_connection.module_implementation_preferences = ('',) mock_task.args = dict(a=1, foo='fö〩')
(style, shebang, data) = action_base._configure_module(mock_task.action, mock_task.args) mock_connection.module_implementation_preferences = ('',)
self.assertEqual(style, "new") (style, shebang, data) = action_base._configure_module(mock_task.action, mock_task.args)
self.assertEqual(shebang, b"#!/usr/bin/python") self.assertEqual(style, "new")
self.assertEqual(shebang, b"#!/usr/bin/python")
# test module not found # test module not found
self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args) self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args)
# test powershell module formatting # test powershell module formatting
with patch.object(builtins, 'open', mock_open(read_data=to_bytes(powershell_module_replacers.strip(), encoding='utf-8'))) as m: with patch.object(builtins, 'open', mock_open(read_data=to_bytes(powershell_module_replacers.strip(), encoding='utf-8'))) as m: