Fixes to safe_eval

This commit is contained in:
James Cammarata 2014-03-31 17:33:40 -05:00
parent 6a1dcca4be
commit a4df906fc9

View file

@ -29,6 +29,7 @@ from ansible.utils.plugins import *
from ansible.utils import template from ansible.utils import template
from ansible.callbacks import display from ansible.callbacks import display
import ansible.constants as C import ansible.constants as C
import ast
import time import time
import StringIO import StringIO
import stat import stat
@ -974,51 +975,95 @@ def is_list_of_strings(items):
return False return False
return True return True
def safe_eval(str, locals=None, include_exceptions=False): def safe_eval(expr, locals={}, include_exceptions=False):
''' '''
this is intended for allowing things like: this is intended for allowing things like:
with_items: a_list_variable with_items: a_list_variable
where Jinja2 would return a string where Jinja2 would return a string
but we do not want to allow it to call functions (outside of Jinja2, where but we do not want to allow it to call functions (outside of Jinja2, where
the env is constrained) the env is constrained)
Based on:
http://stackoverflow.com/questions/12523516/using-ast-and-whitelists-to-make-pythons-eval-safe
''' '''
# FIXME: is there a more native way to do this?
def is_set(var): # this is the whitelist of AST nodes we are going to
return not var.startswith("$") and not '{{' in var # allow in the evaluation. Any node type other than
# those listed here will raise an exception in our custom
# visitor class defined below.
SAFE_NODES = set(
(
ast.Expression,
ast.Compare,
ast.Str,
ast.List,
ast.Tuple,
ast.Dict,
ast.Call,
ast.Load,
ast.BinOp,
ast.UnaryOp,
ast.Num,
ast.Name,
ast.Add,
ast.Sub,
ast.Mult,
ast.Div,
)
)
def is_unset(var): # AST node types were expanded after 2.6
return var.startswith("$") or '{{' in var if not sys.version.startswith('2.6'):
SAFE_NODES.union(
set(
(ast.Set,)
)
)
# do not allow method calls to modules # builtin functions that are not safe to call
if not isinstance(str, basestring): INVALID_CALLS = (
'classmethod', 'compile', 'delattr', 'eval', 'execfile', 'file',
'filter', 'help', 'input', 'object', 'open', 'raw_input', 'reduce',
'reload', 'repr', 'setattr', 'staticmethod', 'super', 'type',
)
class CleansingNodeVisitor(ast.NodeVisitor):
def generic_visit(self, node):
if type(node) not in SAFE_NODES:
#raise Exception("invalid expression (%s) type=%s" % (expr, type(node)))
raise Exception("invalid expression (%s)" % expr)
super(CleansingNodeVisitor, self).generic_visit(node)
def visit_Call(self, call):
if call.func.id in INVALID_CALLS:
raise Exception("invalid function: %s" % call.func.id)
if not isinstance(expr, basestring):
# already templated to a datastructure, perhaps? # already templated to a datastructure, perhaps?
if include_exceptions: if include_exceptions:
return (str, None) return (expr, None)
return str return expr
if re.search(r'\w\.\w+\(', str):
if include_exceptions:
return (str, None)
return str
# do not allow imports
if re.search(r'import \w+', str):
if include_exceptions:
return (str, None)
return str
try: try:
result = None parsed_tree = ast.parse(expr, mode='eval')
if not locals: cnv = CleansingNodeVisitor()
result = eval(str) cnv.visit(parsed_tree)
else: compiled = compile(parsed_tree, expr, 'eval')
result = eval(str, None, locals) result = eval(compiled, {}, locals)
if include_exceptions: if include_exceptions:
return (result, None) return (result, None)
else: else:
return result return result
except SyntaxError, e:
# special handling for syntax errors, we just return
# the expression string back as-is
if include_exceptions:
return (expr, None)
return expr
except Exception, e: except Exception, e:
if include_exceptions: if include_exceptions:
return (str, e) return (expr, e)
return str return expr
def listify_lookup_plugin_terms(terms, basedir, inject): def listify_lookup_plugin_terms(terms, basedir, inject):
@ -1030,7 +1075,7 @@ def listify_lookup_plugin_terms(terms, basedir, inject):
# with_items: {{ alist }} # with_items: {{ alist }}
stripped = terms.strip() stripped = terms.strip()
if not (stripped.startswith('{') or stripped.startswith('[')) and not stripped.startswith("/"): if not (stripped.startswith('{') or stripped.startswith('[')) and not stripped.startswith("/") and not stripped.startswith('set(['):
# if not already a list, get ready to evaluate with Jinja2 # if not already a list, get ready to evaluate with Jinja2
# not sure why the "/" is in above code :) # not sure why the "/" is in above code :)
try: try: