From d7dd41146a5615f2e6a48c716b511311a0f03a2c Mon Sep 17 00:00:00 2001
From: James Cammarata <jimi@sngx.net>
Date: Tue, 13 Dec 2016 11:14:47 -0600
Subject: [PATCH] Fixing security bugs CVE-2016-9587 (cherry picked from
 c8f8d0607c5c123522951835603ccb7948e663d5)

---
 lib/ansible/playbook/conditional.py    | 11 +++-
 lib/ansible/plugins/action/__init__.py | 76 ++++++++++++----------
 lib/ansible/template/__init__.py       | 88 +++++++++++++++++++++++---
 lib/ansible/template/template.py       |  2 +-
 lib/ansible/template/vars.py           |  2 +-
 lib/ansible/vars/unsafe_proxy.py       |  2 +-
 6 files changed, 133 insertions(+), 48 deletions(-)

diff --git a/lib/ansible/playbook/conditional.py b/lib/ansible/playbook/conditional.py
index 5615a252b8c..bf0da78c8a1 100644
--- a/lib/ansible/playbook/conditional.py
+++ b/lib/ansible/playbook/conditional.py
@@ -19,6 +19,8 @@
 from __future__ import (absolute_import, division, print_function)
 __metaclass__ = type
 
+import re
+
 from jinja2.exceptions import UndefinedError
 
 from ansible.compat.six import text_type
@@ -26,6 +28,8 @@ from ansible.errors import AnsibleError, AnsibleUndefinedVariable
 from ansible.playbook.attribute import FieldAttribute
 from ansible.template import Templar
 
+LOOKUP_REGEX = re.compile(r'lookup\s*\(')
+
 class Conditional:
 
     '''
@@ -95,9 +99,12 @@ class Conditional:
                 return conditional
 
             # a Jinja2 evaluation that results in something Python can eval!
+            if hasattr(conditional, '__UNSAFE__') and LOOKUP_REGEX.match(conditional):
+                raise AnsibleError("The conditional '%s' contains variables which came from an unsafe " \
+                                   "source and also contains a lookup() call, failing conditional check" % conditional)
+
             presented = "{%% if %s %%} True {%% else %%} False {%% endif %%}" % conditional
-            conditional = templar.template(presented)
-            val = conditional.strip()
+            val = templar.template(presented).strip()
             if val == "True":
                 return True
             elif val == "False":
diff --git a/lib/ansible/plugins/action/__init__.py b/lib/ansible/plugins/action/__init__.py
index 12b397893fc..cde61326551 100644
--- a/lib/ansible/plugins/action/__init__.py
+++ b/lib/ansible/plugins/action/__init__.py
@@ -30,15 +30,16 @@ import tempfile
 import time
 from abc import ABCMeta, abstractmethod
 
-from ansible.compat.six import binary_type, text_type, iteritems, with_metaclass
-
 from ansible import constants as C
+from ansible.compat.six import binary_type, string_types, text_type, iteritems, with_metaclass
 from ansible.errors import AnsibleError, AnsibleConnectionFailure
 from ansible.executor.module_common import modify_module
 from ansible.playbook.play_context import MAGIC_VARIABLE_MAPPING
 from ansible.release import __version__
 from ansible.parsing.utils.jsonify import jsonify
 from ansible.utils.unicode import to_bytes, to_unicode
+from ansible.vars.unsafe_proxy import wrap_var
+
 
 try:
     from __main__ import display
@@ -436,6 +437,8 @@ class ActionBase(with_metaclass(ABCMeta, object)):
         # happens sometimes when it is a dir and not on bsd
         if not 'checksum' in mystat['stat']:
             mystat['stat']['checksum'] = ''
+        elif not isinstance(mystat['stat']['checksum'], string_types):
+            raise AnsibleError("Invalid checksum returned by stat: expected a string type but got %s" % type(mystat['stat']['checksum']))
 
         return mystat['stat']
 
@@ -683,42 +686,49 @@ class ActionBase(with_metaclass(ABCMeta, object)):
         display.debug("done with _execute_module (%s, %s)" % (module_name, module_args))
         return data
 
+    def _clean_returned_data(self, data):
+        remove_keys = set()
+        fact_keys = set(data.keys())
+        # first we add all of our magic variable names to the set of
+        # keys we want to remove from facts
+        for magic_var in MAGIC_VARIABLE_MAPPING:
+            remove_keys.update(fact_keys.intersection(MAGIC_VARIABLE_MAPPING[magic_var]))
+        # next we remove any connection plugin specific vars
+        for conn_path in self._shared_loader_obj.connection_loader.all(path_only=True):
+            try:
+                conn_name = os.path.splitext(os.path.basename(conn_path))[0]
+                re_key = re.compile('^ansible_%s_' % conn_name)
+                for fact_key in fact_keys:
+                    if re_key.match(fact_key):
+                        remove_keys.add(fact_key)
+            except AttributeError:
+                pass
+
+        # remove some KNOWN keys
+        for hard in ['ansible_rsync_path', 'ansible_playbook_python']:
+            if hard in fact_keys:
+                remove_keys.add(hard)
+
+        # finally, we search for interpreter keys to remove
+        re_interp = re.compile('^ansible_.*_interpreter$')
+        for fact_key in fact_keys:
+            if re_interp.match(fact_key):
+                remove_keys.add(fact_key)
+        # then we remove them (except for ssh host keys)
+        for r_key in remove_keys:
+            if not r_key.startswith('ansible_ssh_host_key_'):
+                del data[r_key]
+
     def _parse_returned_data(self, res):
         try:
             data = json.loads(self._filter_non_json_lines(res.get('stdout', u'')))
             data['_ansible_parsed'] = True
             if 'ansible_facts' in data and isinstance(data['ansible_facts'], dict):
-                remove_keys = set()
-                fact_keys = set(data['ansible_facts'].keys())
-                # first we add all of our magic variable names to the set of
-                # keys we want to remove from facts
-                for magic_var in MAGIC_VARIABLE_MAPPING:
-                    remove_keys.update(fact_keys.intersection(MAGIC_VARIABLE_MAPPING[magic_var]))
-                # next we remove any connection plugin specific vars
-                for conn_path in self._shared_loader_obj.connection_loader.all(path_only=True):
-                    try:
-                        conn_name = os.path.splitext(os.path.basename(conn_path))[0]
-                        re_key = re.compile('^ansible_%s_' % conn_name)
-                        for fact_key in fact_keys:
-                            if re_key.match(fact_key):
-                                remove_keys.add(fact_key)
-                    except AttributeError:
-                        pass
-
-                # remove some KNOWN keys
-                for hard in ['ansible_rsync_path']:
-                    if hard in fact_keys:
-                        remove_keys.add(hard)
-
-                # finally, we search for interpreter keys to remove
-                re_interp = re.compile('^ansible_.*_interpreter$')
-                for fact_key in fact_keys:
-                    if re_interp.match(fact_key):
-                        remove_keys.add(fact_key)
-                # then we remove them (except for ssh host keys)
-                for r_key in remove_keys:
-                    if not r_key.startswith('ansible_ssh_host_key_'):
-                        del data['ansible_facts'][r_key]
+                self._clean_returned_data(data['ansible_facts'])
+                data['ansible_facts'] = wrap_var(data['ansible_facts'])
+            if 'add_host' in data and isinstance(data['add_host'].get('host_vars', None), dict):
+                self._clean_returned_data(data['add_host']['host_vars'])
+                data['add_host'] = wrap_var(data['add_host'])
         except ValueError:
             # not valid json, lets try to capture error
             data = dict(failed=True, _ansible_parsed=False)
diff --git a/lib/ansible/template/__init__.py b/lib/ansible/template/__init__.py
index a662364565f..eddef1659e3 100644
--- a/lib/ansible/template/__init__.py
+++ b/lib/ansible/template/__init__.py
@@ -30,8 +30,9 @@ from ansible.compat.six import string_types, text_type, binary_type
 from jinja2 import Environment
 from jinja2.loaders import FileSystemLoader
 from jinja2.exceptions import TemplateSyntaxError, UndefinedError
+from jinja2.nodes import EvalContext
 from jinja2.utils import concat as j2_concat
-from jinja2.runtime import StrictUndefined
+from jinja2.runtime import Context, StrictUndefined
 
 from ansible import constants as C
 from ansible.errors import AnsibleError, AnsibleFilterError, AnsibleUndefinedVariable
@@ -40,6 +41,7 @@ from ansible.template.safe_eval import safe_eval
 from ansible.template.template import AnsibleJ2Template
 from ansible.template.vars import AnsibleJ2Vars
 from ansible.utils.unicode import to_unicode, to_str
+from ansible.vars.unsafe_proxy import UnsafeProxy, wrap_var
 
 try:
     from hashlib import sha1
@@ -124,6 +126,62 @@ def _count_newlines_from_end(in_str):
         # Uncommon cases: zero length string and string containing only newlines
         return i
 
+class AnsibleEvalContext(EvalContext):
+    '''
+    A custom jinja2 EvalContext, which is currently unused and saved
+    here for possible future use.
+    '''
+    pass
+
+class AnsibleContext(Context):
+    '''
+    A custom context, which intercepts resolve() calls and sets a flag
+    internally if any variable lookup returns an AnsibleUnsafe value. This
+    flag is checked post-templating, and (when set) will result in the
+    final templated result being wrapped via UnsafeProxy.
+    '''
+    def __init__(self, *args, **kwargs):
+        super(AnsibleContext, self).__init__(*args, **kwargs)
+        self.eval_ctx = AnsibleEvalContext(self.environment, self.name)
+        self.unsafe = False
+
+    def _is_unsafe(self, val):
+        '''
+        Our helper function, which will also recursively check dict and
+        list entries due to the fact that they may be repr'd and contain
+        a key or value which contains jinja2 syntax and would otherwise
+        lose the AnsibleUnsafe value.
+        '''
+        if isinstance(val, dict):
+            for key in val.keys():
+                if self._is_unsafe(val[key]):
+                    return True
+        elif isinstance(val, list):
+            for item in val:
+                if self._is_unsafe(item):
+                    return True
+        elif isinstance(val, string_types) and hasattr(val, '__UNSAFE__'):
+            return True
+        return False
+
+    def resolve(self, key):
+        '''
+        The intercepted resolve(), which uses the helper above to set the
+        internal flag whenever an unsafe variable value is returned.
+        '''
+        val = super(AnsibleContext, self).resolve(key)
+        if val is not None and not self.unsafe:
+            if self._is_unsafe(val):
+                self.unsafe = True
+        return val
+
+class AnsibleEnvironment(Environment):
+    '''
+    Our custom environment, which simply allows us to override the class-level
+    values for the Template and Context classes used by jinja2 internally.
+    '''
+    context_class = AnsibleContext
+    template_class = AnsibleJ2Template
 
 class Templar:
     '''
@@ -157,14 +215,13 @@ class Templar:
         self._fail_on_filter_errors    = True
         self._fail_on_undefined_errors = C.DEFAULT_UNDEFINED_VAR_BEHAVIOR
 
-        self.environment = Environment(
+        self.environment = AnsibleEnvironment(
             trim_blocks=True,
             undefined=StrictUndefined,
             extensions=self._get_extensions(),
             finalize=self._finalize,
             loader=FileSystemLoader(self._basedir),
         )
-        self.environment.template_class = AnsibleJ2Template
 
         self.SINGLE_VAR = re.compile(r"^%s\s*(\w*)\s*%s$" % (self.environment.variable_start_string, self.environment.variable_end_string))
 
@@ -227,7 +284,7 @@ class Templar:
     def _clean_data(self, orig_data):
         ''' remove jinja2 template tags from a string '''
 
-        if not isinstance(orig_data, string_types):
+        if not isinstance(orig_data, string_types) or hasattr(orig_data, '__UNSAFE__'):
             return orig_data
 
         with contextlib.closing(StringIO(orig_data)) as data:
@@ -290,11 +347,12 @@ class Templar:
         # Don't template unsafe variables, instead drop them back down to their constituent type.
         if hasattr(variable, '__UNSAFE__'):
             if isinstance(variable, text_type):
-                return self._clean_data(variable)
+                rval = self._clean_data(variable)
             else:
                 # Do we need to convert these into text_type as well?
-                # return self._clean_data(to_unicode(variable._obj, nonstring='passthru'))
-                return self._clean_data(variable._obj)
+                # return self._clean_data(to_text(variable._obj, nonstring='passthru'))
+                rval = self._clean_data(variable._obj)
+            return rval
 
         try:
             if convert_bare:
@@ -327,14 +385,23 @@ class Templar:
                     if cache and sha1_hash in self._cached_result:
                         result = self._cached_result[sha1_hash]
                     else:
-                        result = self._do_template(variable, preserve_trailing_newlines=preserve_trailing_newlines, escape_backslashes=escape_backslashes, fail_on_undefined=fail_on_undefined, overrides=overrides)
+                        result = self._do_template(
+                            variable,
+                            preserve_trailing_newlines=preserve_trailing_newlines,
+                            escape_backslashes=escape_backslashes,
+                            fail_on_undefined=fail_on_undefined,
+                            overrides=overrides,
+                        )
                         if convert_data and not self._no_type_regex.match(variable):
                             # if this looks like a dictionary or list, convert it to such using the safe_eval method
                             if (result.startswith("{") and not result.startswith(self.environment.variable_start_string)) or \
-                              result.startswith("[") or result in ("True", "False"):
+                                result.startswith("[") or result in ("True", "False"):
+                                unsafe = hasattr(result, '__UNSAFE__')
                                 eval_results = safe_eval(result, locals=self._available_variables, include_exceptions=True)
                                 if eval_results[1] is None:
                                     result = eval_results[0]
+                                    if unsafe:
+                                        result = wrap_var(result)
                                 else:
                                     # FIXME: if the safe_eval raised an error, should we do something with it?
                                     pass
@@ -421,7 +488,6 @@ class Templar:
                 ran = None
 
             if ran:
-                from ansible.vars.unsafe_proxy import UnsafeProxy, wrap_var
                 if wantlist:
                     ran = wrap_var(ran)
                 else:
@@ -492,6 +558,8 @@ class Templar:
 
             try:
                 res = j2_concat(rf)
+                if new_context.unsafe:
+                    res = wrap_var(res)
             except TypeError as te:
                 if 'StrictUndefined' in to_str(te):
                     errmsg  = "Unable to look up a name or access an attribute in template string (%s).\n" % to_str(data)
diff --git a/lib/ansible/template/template.py b/lib/ansible/template/template.py
index a111bec0a5a..55936f42f71 100644
--- a/lib/ansible/template/template.py
+++ b/lib/ansible/template/template.py
@@ -33,5 +33,5 @@ class AnsibleJ2Template(jinja2.environment.Template):
     '''
 
     def new_context(self, vars=None, shared=False, locals=None):
-        return jinja2.runtime.Context(self.environment, vars.add_locals(locals), self.name, self.blocks)
+        return self.environment.context_class(self.environment, vars.add_locals(locals), self.name, self.blocks)
 
diff --git a/lib/ansible/template/vars.py b/lib/ansible/template/vars.py
index badf93b1e86..d5a5bcfc6f7 100644
--- a/lib/ansible/template/vars.py
+++ b/lib/ansible/template/vars.py
@@ -81,7 +81,7 @@ class AnsibleJ2Vars:
         # HostVars is special, return it as-is, as is the special variable
         # 'vars', which contains the vars structure
         from ansible.vars.hostvars import HostVars
-        if isinstance(variable, dict) and varname == "vars" or isinstance(variable, HostVars):
+        if isinstance(variable, dict) and varname == "vars" or isinstance(variable, HostVars) or hasattr(variable, '__UNSAFE__'):
             return variable
         else:
             value = None
diff --git a/lib/ansible/vars/unsafe_proxy.py b/lib/ansible/vars/unsafe_proxy.py
index ac5cce24af8..211220d8a72 100644
--- a/lib/ansible/vars/unsafe_proxy.py
+++ b/lib/ansible/vars/unsafe_proxy.py
@@ -95,7 +95,7 @@ class AnsibleJSONUnsafeDecoder(json.JSONDecoder):
 def _wrap_dict(v):
     for k in v.keys():
         if v[k] is not None:
-            v[k] = wrap_var(v[k])
+            v[wrap_var(k)] = wrap_var(v[k])
     return v