From 44ab948e5d3c33b1f24c368da9aaee776d98f512 Mon Sep 17 00:00:00 2001 From: Jordan Borean Date: Tue, 15 May 2018 09:31:21 +1000 Subject: [PATCH] create module tmpdir based on remote_tmp (#39833) * create module tmpdir based on remote_tmp * Source remote_tmp from controller if possible * Fixed sanity test and not use lambda * Added expansion of env vars to the remote tmp * Fixed sanity issues * Added note around shell remote_tmp option * Changed fallback tmp dir to ~/.ansible/tmp to make shell defaults --- hacking/test-module | 3 +- lib/ansible/module_utils/basic.py | 20 ++++++- lib/ansible/plugins/action/__init__.py | 27 ++++++++- test/units/module_utils/basic/test_tmpdir.py | 62 ++++++++++++++++++++ test/units/plugins/action/test_action.py | 4 ++ 5 files changed, 111 insertions(+), 5 deletions(-) create mode 100644 test/units/module_utils/basic/test_tmpdir.py diff --git a/hacking/test-module b/hacking/test-module index e2e824e9e05..7928ca9b4bf 100755 --- a/hacking/test-module +++ b/hacking/test-module @@ -125,7 +125,8 @@ def boilerplate_module(modfile, args, interpreters, check, destfile): # default selinux fs list is pass in as _ansible_selinux_special_fs arg complex_args['_ansible_selinux_special_fs'] = C.DEFAULT_SELINUX_SPECIAL_FS - complex_args['_ansible_tmpdir'] = C.DEFAULT_LOCAL_TMP + complex_args['_ansible_tmp'] = C.DEFAULT_LOCAL_TMP + comlpex_args['_ansible_keep_remote_files'] = C.DEFAULT_KEEP_REMOTE_FILES if args.startswith("@"): # Argument is a YAML file (JSON is a subset of YAML) diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index ddd308bbca8..f05dc367e89 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -41,13 +41,15 @@ PASS_VARS = { 'check_mode': 'check_mode', 'debug': '_debug', 'diff': '_diff', + 'keep_remote_files': '_keep_remote_files', 'module_name': '_name', 'no_log': 'no_log', + 'remote_tmp': '_remote_tmp', 'selinux_special_fs': '_selinux_special_fs', 'shell_executable': '_shell', 'socket': '_socket_path', 'syslog_facility': '_syslog_facility', - 'tmpdir': 'tmpdir', + 'tmpdir': '_tmpdir', 'verbosity': '_verbosity', 'version': 'ansible_version', } @@ -58,6 +60,7 @@ PASS_BOOLS = ('no_log', 'debug', 'diff') # The functions available here can be used to do many common tasks, # to simplify development of Python modules. +import atexit import locale import os import re @@ -853,6 +856,7 @@ class AnsibleModule(object): self.aliases = {} self._legal_inputs = ['_ansible_%s' % k for k in PASS_VARS] self._options_context = list() + self._tmpdir = None if add_file_common_args: for k, v in FILE_COMMON_ARGUMENTS.items(): @@ -928,6 +932,20 @@ class AnsibleModule(object): ' Update the code for this module In the future, AnsibleModule will' ' always check for invalid arguments.', version='2.9') + @property + def tmpdir(self): + # if _ansible_tmpdir was not set, the module needs to create it and + # clean it up once finished. + if self._tmpdir is None: + basedir = os.path.expanduser(os.path.expandvars(self._remote_tmp)) + basefile = "ansible-moduletmp-%s-" % time.time() + tmpdir = tempfile.mkdtemp(prefix=basefile, dir=basedir) + if not self._keep_remote_files: + atexit.register(shutil.rmtree, tmpdir) + self._tmpdir = tmpdir + + return self._tmpdir + def warn(self, warning): if isinstance(warning, string_types): diff --git a/lib/ansible/plugins/action/__init__.py b/lib/ansible/plugins/action/__init__.py index 89471cd1391..f18d911f7e7 100644 --- a/lib/ansible/plugins/action/__init__.py +++ b/lib/ansible/plugins/action/__init__.py @@ -241,7 +241,7 @@ class ActionBase(with_metaclass(ABCMeta, object)): try: remote_tmp = self._connection._shell.get_option('remote_tmp') except AnsibleError: - remote_tmp = '~/ansible' + remote_tmp = '~/.ansible/tmp' # deal with tmpdir creation basefile = 'ansible-tmp-%s-%s' % (time.time(), random.randint(0, 2**48)) @@ -650,9 +650,19 @@ class ActionBase(with_metaclass(ABCMeta, object)): # make sure all commands use the designated shell executable module_args['_ansible_shell_executable'] = self._play_context.executable - # make sure all commands use the designated temporary directory + # make sure modules are aware if they need to keep the remote files + module_args['_ansible_keep_remote_files'] = C.DEFAULT_KEEP_REMOTE_FILES + + # make sure all commands use the designated temporary directory if created module_args['_ansible_tmpdir'] = self._connection._shell.tmpdir + # make sure the remote_tmp value is sent through in case modules needs to create their own + try: + module_args['_ansible_remote_tmp'] = self._connection._shell.get_option('remote_tmp') + except KeyError: + # here for 3rd party shell plugin compatibility in case they do not define the remote_tmp option + module_args['_ansible_remote_tmp'] = '~/.ansible/tmp' + def _update_connection_options(self, options, variables=None): ''' ensures connections have the appropriate information ''' update = {} @@ -683,6 +693,18 @@ class ActionBase(with_metaclass(ABCMeta, object)): ' if they are responsible for removing it.') del delete_remote_tmp # No longer used + tmpdir = self._connection._shell.tmpdir + + # We set the module_style to new here so the remote_tmp is created + # before the module args are built if remote_tmp is needed (async). + # If the module_style turns out to not be new and we didn't create the + # remote tmp here, it will still be created. This must be done before + # calling self._update_module_args() so the module wrapper has the + # correct remote_tmp value set + if not self._is_pipelining_enabled("new", wrap_async) and tmpdir is None: + self._make_tmp_path() + tmpdir = self._connection._shell.tmpdir + if task_vars is None: task_vars = dict() @@ -700,7 +722,6 @@ class ActionBase(with_metaclass(ABCMeta, object)): if not shebang and module_style != 'binary': raise AnsibleError("module (%s) is missing interpreter line" % module_name) - tmpdir = self._connection._shell.tmpdir remote_module_path = None if not self._is_pipelining_enabled(module_style, wrap_async): diff --git a/test/units/module_utils/basic/test_tmpdir.py b/test/units/module_utils/basic/test_tmpdir.py new file mode 100644 index 00000000000..0d7996719f7 --- /dev/null +++ b/test/units/module_utils/basic/test_tmpdir.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2018 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division) +__metaclass__ = type + +import os +import shutil +import tempfile + +import pytest + +from ansible.compat.tests.mock import patch + + +class TestAnsibleModuleTmpDir: + + DATA = ( + ( + { + "_ansible_tmpdir": "/path/to/dir", + "_ansible_remote_tmp": "/path/tmpdir", + "_ansible_keep_remote_files": False, + }, + "/path/to/dir" + ), + ( + { + "_ansible_tmpdir": None, + "_ansible_remote_tmp": "/path/tmpdir", + "_ansible_keep_remote_files": False + }, + "/path/tmpdir/ansible-moduletmp-42-" + ), + ( + { + "_ansible_tmpdir": None, + "_ansible_remote_tmp": "$HOME/.test", + "_ansible_keep_remote_files": False + }, + os.path.join(os.environ['HOME'], ".test/ansible-moduletmp-42-") + ), + ) + + # pylint bug: https://github.com/PyCQA/pylint/issues/511 + # pylint: disable=undefined-variable + @pytest.mark.parametrize('stdin, expected', ((s, e) for s, e in DATA), + indirect=['stdin']) + def test_tmpdir_property(self, am, monkeypatch, expected): + def mock_mkdtemp(prefix, dir): + return os.path.join(dir, prefix) + monkeypatch.setattr(tempfile, 'mkdtemp', mock_mkdtemp) + monkeypatch.setattr(shutil, 'rmtree', lambda x: None) + + with patch('time.time', return_value=42): + actual_tmpdir = am.tmpdir + assert actual_tmpdir == expected + + # verify subsequent calls always produces the same tmpdir + assert am.tmpdir == actual_tmpdir diff --git a/test/units/plugins/action/test_action.py b/test/units/plugins/action/test_action.py index 52fa3bea1be..df90c06bf65 100644 --- a/test/units/plugins/action/test_action.py +++ b/test/units/plugins/action/test_action.py @@ -414,12 +414,16 @@ class TestActionBase(unittest.TestCase): to_run.append(arg_path) return " ".join(to_run) + def get_option(option): + return {}.get(option) + mock_connection = MagicMock() mock_connection.build_module_command.side_effect = build_module_command mock_connection.socket_path = None mock_connection._shell.get_remote_filename.return_value = 'copy.py' mock_connection._shell.join_path.side_effect = os.path.join mock_connection._shell.tmpdir = '/var/tmp/mytempdir' + mock_connection._shell.get_option = get_option # we're using a real play context here play_context = PlayContext()