From 400a3b984e834a3d357f1c695d18d8b24d2e392a Mon Sep 17 00:00:00 2001
From: James Cammarata <jimi@sngx.net>
Date: Thu, 12 Jan 2017 09:40:19 -0600
Subject: [PATCH] Additional security fixes for CVE-2016-9587

(cherry picked from commit b7cdc21aee7584bd33c3b3d7856397bc927e88b5)
---
 lib/ansible/playbook/conditional.py | 50 +++++++++++++++++++++++++----
 lib/ansible/template/__init__.py    | 18 +++++++++--
 2 files changed, 59 insertions(+), 9 deletions(-)

diff --git a/lib/ansible/playbook/conditional.py b/lib/ansible/playbook/conditional.py
index e42036bebb6..844f8ee3867 100644
--- a/lib/ansible/playbook/conditional.py
+++ b/lib/ansible/playbook/conditional.py
@@ -19,6 +19,7 @@
 from __future__ import (absolute_import, division, print_function)
 __metaclass__ = type
 
+import ast
 import re
 
 from jinja2.exceptions import UndefinedError
@@ -29,6 +30,7 @@ from ansible.playbook.attribute import FieldAttribute
 from ansible.template import Templar
 
 LOOKUP_REGEX = re.compile(r'lookup\s*\(')
+VALID_VAR_REGEX = re.compile("^[_A-Za-z][_a-zA-Z0-9]*$")
 
 class Conditional:
 
@@ -87,23 +89,59 @@ class Conditional:
         if conditional is None or conditional == '':
             return True
 
-        if conditional in all_vars and re.match("^[_A-Za-z][_a-zA-Z0-9]*$", conditional):
+        # pull the "bare" var out, which allows for nested conditionals
+        # and things like:
+        # - assert:
+        #     that:
+        #     - item
+        #   with_items:
+        #   - 1 == 1
+        if conditional in all_vars and VALID_VAR_REGEX.match(conditional):
             conditional = all_vars[conditional]
 
         # make sure the templar is using the variables specified with this method
         templar.set_available_variables(variables=all_vars)
 
         try:
-            conditional = templar.template(conditional)
+            # if the conditional is "unsafe", disable lookups
+            disable_lookups = hasattr(conditional, '__UNSAFE__')
+            conditional = templar.template(conditional, disable_lookups=disable_lookups)
             if not isinstance(conditional, text_type) or conditional == "":
                 return conditional
 
-            # a Jinja2 evaluation that results in something Python can eval!
-            disable_lookups = False
-            if hasattr(conditional, '__UNSAFE__'):
-                disable_lookups = True
+            # update the lookups flag, as the string returned above may now be unsafe
+            # and we don't want future templating calls to do unsafe things
+            disable_lookups |= hasattr(conditional, '__UNSAFE__')
+
+            # now we generated the "presented" string, which is a jinja2 if/else block
+            # used to evaluate the conditional. First, we do some low-level jinja2 parsing
+            # involving the AST format of the statement to ensure we don't do anything
+            # unsafe (using the disable_lookup flag above)
+            e = templar.environment.overlay()
+            e.filters.update(templar._get_filters())
+            e.tests.update(templar._get_tests())
 
             presented = "{%% if %s %%} True {%% else %%} False {%% endif %%}" % conditional
+            res = e._parse(presented, None, None)
+            res = e._generate(res, None, None, defer_init=True)
+            parsed = ast.parse(res, mode='exec')
+
+            class CleansingNodeVisitor(ast.NodeVisitor):
+                def generic_visit(self, node, inside_call=False):
+                    if isinstance(node, ast.Call):
+                        inside_call = True
+                    elif isinstance(node, ast.Str):
+                        # calling things with a dunder is generally bad at this point...
+                        if inside_call and disable_lookups and node.s.startswith("__"):
+                            raise AnsibleError("Invalid access found in the presented conditional: '%s'" % conditional)
+                    # iterate over all child nodes
+                    for child_node in ast.iter_child_nodes(node):
+                        self.generic_visit(child_node, inside_call=inside_call)
+
+            cnv = CleansingNodeVisitor()
+            cnv.visit(parsed)
+
+            # and finally we templated the presented string and look at the resulting string
             val = templar.template(presented, disable_lookups=disable_lookups).strip()
             if val == "True":
                 return True
diff --git a/lib/ansible/template/__init__.py b/lib/ansible/template/__init__.py
index 609dc81aba4..747029a4551 100644
--- a/lib/ansible/template/__init__.py
+++ b/lib/ansible/template/__init__.py
@@ -324,7 +324,7 @@ class Templar:
         self._available_variables = variables
         self._cached_result       = {}
 
-    def template(self, variable, convert_bare=False, preserve_trailing_newlines=True, escape_backslashes=True, fail_on_undefined=None, overrides=None, convert_data=True, static_vars = [''], cache = True, bare_deprecated=True, disable_lookups=False):
+    def template(self, variable, convert_bare=False, preserve_trailing_newlines=True, escape_backslashes=True, fail_on_undefined=None, overrides=None, convert_data=True, static_vars=[''], cache=True, bare_deprecated=True, disable_lookups=False):
         '''
         Templates (possibly recursively) any given data as input. If convert_bare is
         set to True, the given data will be wrapped as a jinja2 variable ('{{foo}}')
@@ -407,14 +407,26 @@ class Templar:
                 return result
 
             elif isinstance(variable, (list, tuple)):
-                return [self.template(v, preserve_trailing_newlines=preserve_trailing_newlines, fail_on_undefined=fail_on_undefined, overrides=overrides) for v in variable]
+                return [self.template(
+                            v,
+                            preserve_trailing_newlines=preserve_trailing_newlines,
+                            fail_on_undefined=fail_on_undefined,
+                            overrides=overrides,
+                            disable_lookups=disable_lookups,
+                        ) for v in variable]
             elif isinstance(variable, dict):
                 d = {}
                 # we don't use iteritems() here to avoid problems if the underlying dict
                 # changes sizes due to the templating, which can happen with hostvars
                 for k in variable.keys():
                     if k not in static_vars:
-                        d[k] = self.template(variable[k], preserve_trailing_newlines=preserve_trailing_newlines, fail_on_undefined=fail_on_undefined, overrides=overrides)
+                        d[k] = self.template(
+                                   variable[k],
+                                   preserve_trailing_newlines=preserve_trailing_newlines,
+                                   fail_on_undefined=fail_on_undefined,
+                                   overrides=overrides,
+                                   disable_lookups=disable_lookups,
+                               )
                     else:
                         d[k] = variable[k]
                 return d