Fixing security bugs for CVE-2016-9587

(cherry picked from c8f8d0607c5c123522951835603ccb7948e663d5)
This commit is contained in:
James Cammarata 2016-12-13 11:14:47 -06:00
parent 56de9d8ae7
commit ec84ff6de6
7 changed files with 132 additions and 48 deletions

View file

@ -19,6 +19,8 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import re
from jinja2.exceptions import UndefinedError from jinja2.exceptions import UndefinedError
from ansible.compat.six import text_type from ansible.compat.six import text_type
@ -26,6 +28,9 @@ from ansible.errors import AnsibleError, AnsibleUndefinedVariable
from ansible.playbook.attribute import FieldAttribute from ansible.playbook.attribute import FieldAttribute
from ansible.template import Templar from ansible.template import Templar
from ansible.module_utils._text import to_native from ansible.module_utils._text import to_native
from ansible.vars.unsafe_proxy import wrap_var
LOOKUP_REGEX = re.compile(r'lookup\s*\(')
class Conditional: class Conditional:
@ -111,9 +116,12 @@ class Conditional:
return conditional return conditional
# a Jinja2 evaluation that results in something Python can eval! # a Jinja2 evaluation that results in something Python can eval!
if hasattr(conditional, '__UNSAFE__') and LOOKUP_REGEX.match(conditional):
raise AnsibleError("The conditional '%s' contains variables which came from an unsafe " \
"source and also contains a lookup() call, failing conditional check" % conditional)
presented = "{%% if %s %%} True {%% else %%} False {%% endif %%}" % conditional presented = "{%% if %s %%} True {%% else %%} False {%% endif %%}" % conditional
conditional = templar.template(presented) val = templar.template(presented).strip()
val = conditional.strip()
if val == "True": if val == "True":
return True return True
elif val == "False": elif val == "False":

View file

@ -30,9 +30,8 @@ import tempfile
import time import time
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from ansible.compat.six import binary_type, text_type, iteritems, with_metaclass
from ansible import constants as C from ansible import constants as C
from ansible.compat.six import binary_type, string_types, text_type, iteritems, with_metaclass
from ansible.errors import AnsibleError, AnsibleConnectionFailure from ansible.errors import AnsibleError, AnsibleConnectionFailure
from ansible.executor.module_common import modify_module from ansible.executor.module_common import modify_module
from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils._text import to_bytes, to_native, to_text
@ -40,6 +39,7 @@ from ansible.module_utils.json_utils import _filter_non_json_lines
from ansible.parsing.utils.jsonify import jsonify from ansible.parsing.utils.jsonify import jsonify
from ansible.playbook.play_context import MAGIC_VARIABLE_MAPPING from ansible.playbook.play_context import MAGIC_VARIABLE_MAPPING
from ansible.release import __version__ from ansible.release import __version__
from ansible.vars.unsafe_proxy import wrap_var
try: try:
@ -454,6 +454,8 @@ class ActionBase(with_metaclass(ABCMeta, object)):
# happens sometimes when it is a dir and not on bsd # happens sometimes when it is a dir and not on bsd
if 'checksum' not in mystat['stat']: if 'checksum' not in mystat['stat']:
mystat['stat']['checksum'] = '' mystat['stat']['checksum'] = ''
elif not isinstance(mystat['stat']['checksum'], string_types):
raise AnsibleError("Invalid checksum returned by stat: expected a string type but got %s" % type(mystat['stat']['checksum']))
return mystat['stat'] return mystat['stat']
@ -669,6 +671,39 @@ class ActionBase(with_metaclass(ABCMeta, object)):
display.debug("done with _execute_module (%s, %s)" % (module_name, module_args)) display.debug("done with _execute_module (%s, %s)" % (module_name, module_args))
return data return data
def _clean_returned_data(self, data):
remove_keys = set()
fact_keys = set(data.keys())
# first we add all of our magic variable names to the set of
# keys we want to remove from facts
for magic_var in MAGIC_VARIABLE_MAPPING:
remove_keys.update(fact_keys.intersection(MAGIC_VARIABLE_MAPPING[magic_var]))
# next we remove any connection plugin specific vars
for conn_path in self._shared_loader_obj.connection_loader.all(path_only=True):
try:
conn_name = os.path.splitext(os.path.basename(conn_path))[0]
re_key = re.compile('^ansible_%s_' % conn_name)
for fact_key in fact_keys:
if re_key.match(fact_key):
remove_keys.add(fact_key)
except AttributeError:
pass
# remove some KNOWN keys
for hard in ['ansible_rsync_path', 'ansible_playbook_python']:
if hard in fact_keys:
remove_keys.add(hard)
# finally, we search for interpreter keys to remove
re_interp = re.compile('^ansible_.*_interpreter$')
for fact_key in fact_keys:
if re_interp.match(fact_key):
remove_keys.add(fact_key)
# then we remove them (except for ssh host keys)
for r_key in remove_keys:
if not r_key.startswith('ansible_ssh_host_key_'):
del data[r_key]
def _parse_returned_data(self, res): def _parse_returned_data(self, res):
try: try:
filtered_output, warnings = _filter_non_json_lines(res.get('stdout', u'')) filtered_output, warnings = _filter_non_json_lines(res.get('stdout', u''))
@ -677,37 +712,11 @@ class ActionBase(with_metaclass(ABCMeta, object)):
data = json.loads(filtered_output) data = json.loads(filtered_output)
data['_ansible_parsed'] = True data['_ansible_parsed'] = True
if 'ansible_facts' in data and isinstance(data['ansible_facts'], dict): if 'ansible_facts' in data and isinstance(data['ansible_facts'], dict):
remove_keys = set() self._clean_returned_data(data['ansible_facts'])
fact_keys = set(data['ansible_facts'].keys()) data['ansible_facts'] = wrap_var(data['ansible_facts'])
# first we add all of our magic variable names to the set of if 'add_host' in data and isinstance(data['add_host'].get('host_vars', None), dict):
# keys we want to remove from facts self._clean_returned_data(data['add_host']['host_vars'])
for magic_var in MAGIC_VARIABLE_MAPPING: data['add_host'] = wrap_var(data['add_host'])
remove_keys.update(fact_keys.intersection(MAGIC_VARIABLE_MAPPING[magic_var]))
# next we remove any connection plugin specific vars
for conn_path in self._shared_loader_obj.connection_loader.all(path_only=True):
try:
conn_name = os.path.splitext(os.path.basename(conn_path))[0]
re_key = re.compile('^ansible_%s_' % conn_name)
for fact_key in fact_keys:
if re_key.match(fact_key):
remove_keys.add(fact_key)
except AttributeError:
pass
# remove some KNOWN keys
for hard in ['ansible_rsync_path']:
if hard in fact_keys:
remove_keys.add(hard)
# finally, we search for interpreter keys to remove
re_interp = re.compile('^ansible_.*_interpreter$')
for fact_key in fact_keys:
if re_interp.match(fact_key):
remove_keys.add(fact_key)
# then we remove them (except for ssh host keys)
for r_key in remove_keys:
if not r_key.startswith('ansible_ssh_host_key_'):
del data['ansible_facts'][r_key]
except ValueError: except ValueError:
# not valid json, lets try to capture error # not valid json, lets try to capture error
data = dict(failed=True, _ansible_parsed=False) data = dict(failed=True, _ansible_parsed=False)

View file

@ -23,6 +23,7 @@ import pwd
import time import time
from ansible import constants as C from ansible import constants as C
from ansible.compat.six import string_types
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.plugins.action import ActionBase from ansible.plugins.action import ActionBase

View file

@ -30,8 +30,9 @@ from numbers import Number
from jinja2 import Environment from jinja2 import Environment
from jinja2.loaders import FileSystemLoader from jinja2.loaders import FileSystemLoader
from jinja2.exceptions import TemplateSyntaxError, UndefinedError from jinja2.exceptions import TemplateSyntaxError, UndefinedError
from jinja2.nodes import EvalContext
from jinja2.utils import concat as j2_concat from jinja2.utils import concat as j2_concat
from jinja2.runtime import StrictUndefined from jinja2.runtime import Context, StrictUndefined
from ansible import constants as C from ansible import constants as C
from ansible.compat.six import string_types, text_type from ansible.compat.six import string_types, text_type
@ -41,7 +42,7 @@ from ansible.template.safe_eval import safe_eval
from ansible.template.template import AnsibleJ2Template from ansible.template.template import AnsibleJ2Template
from ansible.template.vars import AnsibleJ2Vars from ansible.template.vars import AnsibleJ2Vars
from ansible.module_utils._text import to_native, to_text from ansible.module_utils._text import to_native, to_text
from ansible.vars.unsafe_proxy import UnsafeProxy, wrap_var
try: try:
from hashlib import sha1 from hashlib import sha1
@ -126,6 +127,62 @@ def _count_newlines_from_end(in_str):
# Uncommon cases: zero length string and string containing only newlines # Uncommon cases: zero length string and string containing only newlines
return i return i
class AnsibleEvalContext(EvalContext):
'''
A custom jinja2 EvalContext, which is currently unused and saved
here for possible future use.
'''
pass
class AnsibleContext(Context):
'''
A custom context, which intercepts resolve() calls and sets a flag
internally if any variable lookup returns an AnsibleUnsafe value. This
flag is checked post-templating, and (when set) will result in the
final templated result being wrapped via UnsafeProxy.
'''
def __init__(self, *args, **kwargs):
super(AnsibleContext, self).__init__(*args, **kwargs)
self.eval_ctx = AnsibleEvalContext(self.environment, self.name)
self.unsafe = False
def _is_unsafe(self, val):
'''
Our helper function, which will also recursively check dict and
list entries due to the fact that they may be repr'd and contain
a key or value which contains jinja2 syntax and would otherwise
lose the AnsibleUnsafe value.
'''
if isinstance(val, dict):
for key in val.keys():
if self._is_unsafe(val[key]):
return True
elif isinstance(val, list):
for item in val:
if self._is_unsafe(item):
return True
elif isinstance(val, string_types) and hasattr(val, '__UNSAFE__'):
return True
return False
def resolve(self, key):
'''
The intercepted resolve(), which uses the helper above to set the
internal flag whenever an unsafe variable value is returned.
'''
val = super(AnsibleContext, self).resolve(key)
if val is not None and not self.unsafe:
if self._is_unsafe(val):
self.unsafe = True
return val
class AnsibleEnvironment(Environment):
'''
Our custom environment, which simply allows us to override the class-level
values for the Template and Context classes used by jinja2 internally.
'''
context_class = AnsibleContext
template_class = AnsibleJ2Template
class Templar: class Templar:
''' '''
@ -159,14 +216,13 @@ class Templar:
self._fail_on_filter_errors = True self._fail_on_filter_errors = True
self._fail_on_undefined_errors = C.DEFAULT_UNDEFINED_VAR_BEHAVIOR self._fail_on_undefined_errors = C.DEFAULT_UNDEFINED_VAR_BEHAVIOR
self.environment = Environment( self.environment = AnsibleEnvironment(
trim_blocks=True, trim_blocks=True,
undefined=StrictUndefined, undefined=StrictUndefined,
extensions=self._get_extensions(), extensions=self._get_extensions(),
finalize=self._finalize, finalize=self._finalize,
loader=FileSystemLoader(self._basedir), loader=FileSystemLoader(self._basedir),
) )
self.environment.template_class = AnsibleJ2Template
self.SINGLE_VAR = re.compile(r"^%s\s*(\w*)\s*%s$" % (self.environment.variable_start_string, self.environment.variable_end_string)) self.SINGLE_VAR = re.compile(r"^%s\s*(\w*)\s*%s$" % (self.environment.variable_start_string, self.environment.variable_end_string))
@ -229,7 +285,7 @@ class Templar:
def _clean_data(self, orig_data): def _clean_data(self, orig_data):
''' remove jinja2 template tags from a string ''' ''' remove jinja2 template tags from a string '''
if not isinstance(orig_data, string_types) or hasattr(orig_data, '__ENCRYPTED__'): if not isinstance(orig_data, string_types) or hasattr(orig_data, '__ENCRYPTED__') or hasattr(orig_data, '__UNSAFE__'):
return orig_data return orig_data
with contextlib.closing(StringIO(orig_data)) as data: with contextlib.closing(StringIO(orig_data)) as data:
@ -292,11 +348,12 @@ class Templar:
# Don't template unsafe variables, instead drop them back down to their constituent type. # Don't template unsafe variables, instead drop them back down to their constituent type.
if hasattr(variable, '__UNSAFE__'): if hasattr(variable, '__UNSAFE__'):
if isinstance(variable, text_type): if isinstance(variable, text_type):
return self._clean_data(variable) rval = self._clean_data(variable)
else: else:
# Do we need to convert these into text_type as well? # Do we need to convert these into text_type as well?
# return self._clean_data(to_text(variable._obj, nonstring='passthru')) # return self._clean_data(to_text(variable._obj, nonstring='passthru'))
return self._clean_data(variable._obj) rval = self._clean_data(variable._obj)
return rval
try: try:
if convert_bare: if convert_bare:
@ -328,14 +385,23 @@ class Templar:
if cache and sha1_hash in self._cached_result: if cache and sha1_hash in self._cached_result:
result = self._cached_result[sha1_hash] result = self._cached_result[sha1_hash]
else: else:
result = self.do_template(variable, preserve_trailing_newlines=preserve_trailing_newlines, escape_backslashes=escape_backslashes, fail_on_undefined=fail_on_undefined, overrides=overrides) result = self.do_template(
variable,
preserve_trailing_newlines=preserve_trailing_newlines,
escape_backslashes=escape_backslashes,
fail_on_undefined=fail_on_undefined,
overrides=overrides,
)
if convert_data and not self._no_type_regex.match(variable): if convert_data and not self._no_type_regex.match(variable):
# if this looks like a dictionary or list, convert it to such using the safe_eval method # if this looks like a dictionary or list, convert it to such using the safe_eval method
if (result.startswith("{") and not result.startswith(self.environment.variable_start_string)) or \ if (result.startswith("{") and not result.startswith(self.environment.variable_start_string)) or \
result.startswith("[") or result in ("True", "False"): result.startswith("[") or result in ("True", "False"):
unsafe = hasattr(result, '__UNSAFE__')
eval_results = safe_eval(result, locals=self._available_variables, include_exceptions=True) eval_results = safe_eval(result, locals=self._available_variables, include_exceptions=True)
if eval_results[1] is None: if eval_results[1] is None:
result = eval_results[0] result = eval_results[0]
if unsafe:
result = wrap_var(result)
else: else:
# FIXME: if the safe_eval raised an error, should we do something with it? # FIXME: if the safe_eval raised an error, should we do something with it?
pass pass
@ -435,7 +501,6 @@ class Templar:
ran = None ran = None
if ran: if ran:
from ansible.vars.unsafe_proxy import UnsafeProxy, wrap_var
if wantlist: if wantlist:
ran = wrap_var(ran) ran = wrap_var(ran)
else: else:
@ -505,6 +570,8 @@ class Templar:
try: try:
res = j2_concat(rf) res = j2_concat(rf)
if new_context.unsafe:
res = wrap_var(res)
except TypeError as te: except TypeError as te:
if 'StrictUndefined' in to_native(te): if 'StrictUndefined' in to_native(te):
errmsg = "Unable to look up a name or access an attribute in template string (%s).\n" % to_native(data) errmsg = "Unable to look up a name or access an attribute in template string (%s).\n" % to_native(data)

View file

@ -33,5 +33,5 @@ class AnsibleJ2Template(jinja2.environment.Template):
''' '''
def new_context(self, vars=None, shared=False, locals=None): def new_context(self, vars=None, shared=False, locals=None):
return jinja2.runtime.Context(self.environment, vars.add_locals(locals), self.name, self.blocks) return self.environment.context_class(self.environment, vars.add_locals(locals), self.name, self.blocks)

View file

@ -82,7 +82,7 @@ class AnsibleJ2Vars:
# HostVars is special, return it as-is, as is the special variable # HostVars is special, return it as-is, as is the special variable
# 'vars', which contains the vars structure # 'vars', which contains the vars structure
from ansible.vars.hostvars import HostVars from ansible.vars.hostvars import HostVars
if isinstance(variable, dict) and varname == "vars" or isinstance(variable, HostVars): if isinstance(variable, dict) and varname == "vars" or isinstance(variable, HostVars) or hasattr(variable, '__UNSAFE__'):
return variable return variable
else: else:
value = None value = None

View file

@ -64,7 +64,6 @@ __all__ = ['UnsafeProxy', 'AnsibleUnsafe', 'AnsibleJSONUnsafeEncoder', 'AnsibleJ
class AnsibleUnsafe(object): class AnsibleUnsafe(object):
__UNSAFE__ = True __UNSAFE__ = True
class AnsibleUnsafeText(text_type, AnsibleUnsafe): class AnsibleUnsafeText(text_type, AnsibleUnsafe):
pass pass
@ -101,7 +100,7 @@ class AnsibleJSONUnsafeDecoder(json.JSONDecoder):
def _wrap_dict(v): def _wrap_dict(v):
for k in v.keys(): for k in v.keys():
if v[k] is not None: if v[k] is not None:
v[k] = wrap_var(v[k]) v[wrap_var(k)] = wrap_var(v[k])
return v return v