From ec84ff6de6eca9224bf3f22b752bb8da806611ed Mon Sep 17 00:00:00 2001 From: James Cammarata Date: Tue, 13 Dec 2016 11:14:47 -0600 Subject: [PATCH] Fixing security bugs for CVE-2016-9587 (cherry picked from c8f8d0607c5c123522951835603ccb7948e663d5) --- lib/ansible/playbook/conditional.py | 12 +++- lib/ansible/plugins/action/__init__.py | 75 +++++++++++++---------- lib/ansible/plugins/action/template.py | 1 + lib/ansible/template/__init__.py | 85 +++++++++++++++++++++++--- lib/ansible/template/template.py | 2 +- lib/ansible/template/vars.py | 2 +- lib/ansible/vars/unsafe_proxy.py | 3 +- 7 files changed, 132 insertions(+), 48 deletions(-) diff --git a/lib/ansible/playbook/conditional.py b/lib/ansible/playbook/conditional.py index e0c426df3be..99e377c4db1 100644 --- a/lib/ansible/playbook/conditional.py +++ b/lib/ansible/playbook/conditional.py @@ -19,6 +19,8 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import re + from jinja2.exceptions import UndefinedError from ansible.compat.six import text_type @@ -26,6 +28,9 @@ from ansible.errors import AnsibleError, AnsibleUndefinedVariable from ansible.playbook.attribute import FieldAttribute from ansible.template import Templar from ansible.module_utils._text import to_native +from ansible.vars.unsafe_proxy import wrap_var + +LOOKUP_REGEX = re.compile(r'lookup\s*\(') class Conditional: @@ -111,9 +116,12 @@ class Conditional: return conditional # a Jinja2 evaluation that results in something Python can eval! + if hasattr(conditional, '__UNSAFE__') and LOOKUP_REGEX.match(conditional): + raise AnsibleError("The conditional '%s' contains variables which came from an unsafe " \ + "source and also contains a lookup() call, failing conditional check" % conditional) + presented = "{%% if %s %%} True {%% else %%} False {%% endif %%}" % conditional - conditional = templar.template(presented) - val = conditional.strip() + val = templar.template(presented).strip() if val == "True": return True elif val == "False": diff --git a/lib/ansible/plugins/action/__init__.py b/lib/ansible/plugins/action/__init__.py index 0c6a5c31a45..60795b60ee3 100644 --- a/lib/ansible/plugins/action/__init__.py +++ b/lib/ansible/plugins/action/__init__.py @@ -30,9 +30,8 @@ import tempfile import time from abc import ABCMeta, abstractmethod -from ansible.compat.six import binary_type, text_type, iteritems, with_metaclass - from ansible import constants as C +from ansible.compat.six import binary_type, string_types, text_type, iteritems, with_metaclass from ansible.errors import AnsibleError, AnsibleConnectionFailure from ansible.executor.module_common import modify_module from ansible.module_utils._text import to_bytes, to_native, to_text @@ -40,6 +39,7 @@ from ansible.module_utils.json_utils import _filter_non_json_lines from ansible.parsing.utils.jsonify import jsonify from ansible.playbook.play_context import MAGIC_VARIABLE_MAPPING from ansible.release import __version__ +from ansible.vars.unsafe_proxy import wrap_var try: @@ -454,6 +454,8 @@ class ActionBase(with_metaclass(ABCMeta, object)): # happens sometimes when it is a dir and not on bsd if 'checksum' not in mystat['stat']: mystat['stat']['checksum'] = '' + elif not isinstance(mystat['stat']['checksum'], string_types): + raise AnsibleError("Invalid checksum returned by stat: expected a string type but got %s" % type(mystat['stat']['checksum'])) return mystat['stat'] @@ -669,6 +671,39 @@ class ActionBase(with_metaclass(ABCMeta, object)): display.debug("done with _execute_module (%s, %s)" % (module_name, module_args)) return data + def _clean_returned_data(self, data): + remove_keys = set() + fact_keys = set(data.keys()) + # first we add all of our magic variable names to the set of + # keys we want to remove from facts + for magic_var in MAGIC_VARIABLE_MAPPING: + remove_keys.update(fact_keys.intersection(MAGIC_VARIABLE_MAPPING[magic_var])) + # next we remove any connection plugin specific vars + for conn_path in self._shared_loader_obj.connection_loader.all(path_only=True): + try: + conn_name = os.path.splitext(os.path.basename(conn_path))[0] + re_key = re.compile('^ansible_%s_' % conn_name) + for fact_key in fact_keys: + if re_key.match(fact_key): + remove_keys.add(fact_key) + except AttributeError: + pass + + # remove some KNOWN keys + for hard in ['ansible_rsync_path', 'ansible_playbook_python']: + if hard in fact_keys: + remove_keys.add(hard) + + # finally, we search for interpreter keys to remove + re_interp = re.compile('^ansible_.*_interpreter$') + for fact_key in fact_keys: + if re_interp.match(fact_key): + remove_keys.add(fact_key) + # then we remove them (except for ssh host keys) + for r_key in remove_keys: + if not r_key.startswith('ansible_ssh_host_key_'): + del data[r_key] + def _parse_returned_data(self, res): try: filtered_output, warnings = _filter_non_json_lines(res.get('stdout', u'')) @@ -677,37 +712,11 @@ class ActionBase(with_metaclass(ABCMeta, object)): data = json.loads(filtered_output) data['_ansible_parsed'] = True if 'ansible_facts' in data and isinstance(data['ansible_facts'], dict): - remove_keys = set() - fact_keys = set(data['ansible_facts'].keys()) - # first we add all of our magic variable names to the set of - # keys we want to remove from facts - for magic_var in MAGIC_VARIABLE_MAPPING: - remove_keys.update(fact_keys.intersection(MAGIC_VARIABLE_MAPPING[magic_var])) - # next we remove any connection plugin specific vars - for conn_path in self._shared_loader_obj.connection_loader.all(path_only=True): - try: - conn_name = os.path.splitext(os.path.basename(conn_path))[0] - re_key = re.compile('^ansible_%s_' % conn_name) - for fact_key in fact_keys: - if re_key.match(fact_key): - remove_keys.add(fact_key) - except AttributeError: - pass - - # remove some KNOWN keys - for hard in ['ansible_rsync_path']: - if hard in fact_keys: - remove_keys.add(hard) - - # finally, we search for interpreter keys to remove - re_interp = re.compile('^ansible_.*_interpreter$') - for fact_key in fact_keys: - if re_interp.match(fact_key): - remove_keys.add(fact_key) - # then we remove them (except for ssh host keys) - for r_key in remove_keys: - if not r_key.startswith('ansible_ssh_host_key_'): - del data['ansible_facts'][r_key] + self._clean_returned_data(data['ansible_facts']) + data['ansible_facts'] = wrap_var(data['ansible_facts']) + if 'add_host' in data and isinstance(data['add_host'].get('host_vars', None), dict): + self._clean_returned_data(data['add_host']['host_vars']) + data['add_host'] = wrap_var(data['add_host']) except ValueError: # not valid json, lets try to capture error data = dict(failed=True, _ansible_parsed=False) diff --git a/lib/ansible/plugins/action/template.py b/lib/ansible/plugins/action/template.py index c7ca48e2a32..f12141f1940 100644 --- a/lib/ansible/plugins/action/template.py +++ b/lib/ansible/plugins/action/template.py @@ -23,6 +23,7 @@ import pwd import time from ansible import constants as C +from ansible.compat.six import string_types from ansible.errors import AnsibleError from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.plugins.action import ActionBase diff --git a/lib/ansible/template/__init__.py b/lib/ansible/template/__init__.py index 422e7e68fce..4e24fbebfef 100644 --- a/lib/ansible/template/__init__.py +++ b/lib/ansible/template/__init__.py @@ -30,8 +30,9 @@ from numbers import Number from jinja2 import Environment from jinja2.loaders import FileSystemLoader from jinja2.exceptions import TemplateSyntaxError, UndefinedError +from jinja2.nodes import EvalContext from jinja2.utils import concat as j2_concat -from jinja2.runtime import StrictUndefined +from jinja2.runtime import Context, StrictUndefined from ansible import constants as C from ansible.compat.six import string_types, text_type @@ -41,7 +42,7 @@ from ansible.template.safe_eval import safe_eval from ansible.template.template import AnsibleJ2Template from ansible.template.vars import AnsibleJ2Vars from ansible.module_utils._text import to_native, to_text - +from ansible.vars.unsafe_proxy import UnsafeProxy, wrap_var try: from hashlib import sha1 @@ -126,6 +127,62 @@ def _count_newlines_from_end(in_str): # Uncommon cases: zero length string and string containing only newlines return i +class AnsibleEvalContext(EvalContext): + ''' + A custom jinja2 EvalContext, which is currently unused and saved + here for possible future use. + ''' + pass + +class AnsibleContext(Context): + ''' + A custom context, which intercepts resolve() calls and sets a flag + internally if any variable lookup returns an AnsibleUnsafe value. This + flag is checked post-templating, and (when set) will result in the + final templated result being wrapped via UnsafeProxy. + ''' + def __init__(self, *args, **kwargs): + super(AnsibleContext, self).__init__(*args, **kwargs) + self.eval_ctx = AnsibleEvalContext(self.environment, self.name) + self.unsafe = False + + def _is_unsafe(self, val): + ''' + Our helper function, which will also recursively check dict and + list entries due to the fact that they may be repr'd and contain + a key or value which contains jinja2 syntax and would otherwise + lose the AnsibleUnsafe value. + ''' + if isinstance(val, dict): + for key in val.keys(): + if self._is_unsafe(val[key]): + return True + elif isinstance(val, list): + for item in val: + if self._is_unsafe(item): + return True + elif isinstance(val, string_types) and hasattr(val, '__UNSAFE__'): + return True + return False + + def resolve(self, key): + ''' + The intercepted resolve(), which uses the helper above to set the + internal flag whenever an unsafe variable value is returned. + ''' + val = super(AnsibleContext, self).resolve(key) + if val is not None and not self.unsafe: + if self._is_unsafe(val): + self.unsafe = True + return val + +class AnsibleEnvironment(Environment): + ''' + Our custom environment, which simply allows us to override the class-level + values for the Template and Context classes used by jinja2 internally. + ''' + context_class = AnsibleContext + template_class = AnsibleJ2Template class Templar: ''' @@ -159,14 +216,13 @@ class Templar: self._fail_on_filter_errors = True self._fail_on_undefined_errors = C.DEFAULT_UNDEFINED_VAR_BEHAVIOR - self.environment = Environment( + self.environment = AnsibleEnvironment( trim_blocks=True, undefined=StrictUndefined, extensions=self._get_extensions(), finalize=self._finalize, loader=FileSystemLoader(self._basedir), ) - self.environment.template_class = AnsibleJ2Template self.SINGLE_VAR = re.compile(r"^%s\s*(\w*)\s*%s$" % (self.environment.variable_start_string, self.environment.variable_end_string)) @@ -229,7 +285,7 @@ class Templar: def _clean_data(self, orig_data): ''' remove jinja2 template tags from a string ''' - if not isinstance(orig_data, string_types) or hasattr(orig_data, '__ENCRYPTED__'): + if not isinstance(orig_data, string_types) or hasattr(orig_data, '__ENCRYPTED__') or hasattr(orig_data, '__UNSAFE__'): return orig_data with contextlib.closing(StringIO(orig_data)) as data: @@ -292,11 +348,12 @@ class Templar: # Don't template unsafe variables, instead drop them back down to their constituent type. if hasattr(variable, '__UNSAFE__'): if isinstance(variable, text_type): - return self._clean_data(variable) + rval = self._clean_data(variable) else: # Do we need to convert these into text_type as well? # return self._clean_data(to_text(variable._obj, nonstring='passthru')) - return self._clean_data(variable._obj) + rval = self._clean_data(variable._obj) + return rval try: if convert_bare: @@ -328,14 +385,23 @@ class Templar: if cache and sha1_hash in self._cached_result: result = self._cached_result[sha1_hash] else: - result = self.do_template(variable, preserve_trailing_newlines=preserve_trailing_newlines, escape_backslashes=escape_backslashes, fail_on_undefined=fail_on_undefined, overrides=overrides) + result = self.do_template( + variable, + preserve_trailing_newlines=preserve_trailing_newlines, + escape_backslashes=escape_backslashes, + fail_on_undefined=fail_on_undefined, + overrides=overrides, + ) if convert_data and not self._no_type_regex.match(variable): # if this looks like a dictionary or list, convert it to such using the safe_eval method if (result.startswith("{") and not result.startswith(self.environment.variable_start_string)) or \ result.startswith("[") or result in ("True", "False"): + unsafe = hasattr(result, '__UNSAFE__') eval_results = safe_eval(result, locals=self._available_variables, include_exceptions=True) if eval_results[1] is None: result = eval_results[0] + if unsafe: + result = wrap_var(result) else: # FIXME: if the safe_eval raised an error, should we do something with it? pass @@ -435,7 +501,6 @@ class Templar: ran = None if ran: - from ansible.vars.unsafe_proxy import UnsafeProxy, wrap_var if wantlist: ran = wrap_var(ran) else: @@ -505,6 +570,8 @@ class Templar: try: res = j2_concat(rf) + if new_context.unsafe: + res = wrap_var(res) except TypeError as te: if 'StrictUndefined' in to_native(te): errmsg = "Unable to look up a name or access an attribute in template string (%s).\n" % to_native(data) diff --git a/lib/ansible/template/template.py b/lib/ansible/template/template.py index a111bec0a5a..55936f42f71 100644 --- a/lib/ansible/template/template.py +++ b/lib/ansible/template/template.py @@ -33,5 +33,5 @@ class AnsibleJ2Template(jinja2.environment.Template): ''' def new_context(self, vars=None, shared=False, locals=None): - return jinja2.runtime.Context(self.environment, vars.add_locals(locals), self.name, self.blocks) + return self.environment.context_class(self.environment, vars.add_locals(locals), self.name, self.blocks) diff --git a/lib/ansible/template/vars.py b/lib/ansible/template/vars.py index efb1ce734ff..0e9b99e884a 100644 --- a/lib/ansible/template/vars.py +++ b/lib/ansible/template/vars.py @@ -82,7 +82,7 @@ class AnsibleJ2Vars: # HostVars is special, return it as-is, as is the special variable # 'vars', which contains the vars structure from ansible.vars.hostvars import HostVars - if isinstance(variable, dict) and varname == "vars" or isinstance(variable, HostVars): + if isinstance(variable, dict) and varname == "vars" or isinstance(variable, HostVars) or hasattr(variable, '__UNSAFE__'): return variable else: value = None diff --git a/lib/ansible/vars/unsafe_proxy.py b/lib/ansible/vars/unsafe_proxy.py index 4238d4d8e54..426410ab611 100644 --- a/lib/ansible/vars/unsafe_proxy.py +++ b/lib/ansible/vars/unsafe_proxy.py @@ -64,7 +64,6 @@ __all__ = ['UnsafeProxy', 'AnsibleUnsafe', 'AnsibleJSONUnsafeEncoder', 'AnsibleJ class AnsibleUnsafe(object): __UNSAFE__ = True - class AnsibleUnsafeText(text_type, AnsibleUnsafe): pass @@ -101,7 +100,7 @@ class AnsibleJSONUnsafeDecoder(json.JSONDecoder): def _wrap_dict(v): for k in v.keys(): if v[k] is not None: - v[k] = wrap_var(v[k]) + v[wrap_var(k)] = wrap_var(v[k]) return v