Remove uses of assert in production code (#32079)

* Remove uses of assert in production code

* Fix assertion

* Add code smell test for assertions, currently limited to lib/ansible

* Fix assertion

* Add docs for no-assert

* Remove new assert from enos

* Fix assert in module_utils.connection
This commit is contained in:
Matt Martz 2017-11-13 10:51:18 -06:00 committed by ansibot
parent 464ded80f5
commit 99d4f5bab4
38 changed files with 195 additions and 89 deletions

View file

@ -0,0 +1,16 @@
Sanity Tests » no-assert
========================
Do not use ``assert`` in production Ansible python code. When running Python
with optimizations, Python will remove ``assert`` statements, potentially
allowing for unexpected behavior throughout the Ansible code base.
Instead of using ``assert`` you should utilize simple ``if`` statements,
that result in raising an exception. There is a new exception called
``AnsibleAssertionError`` that inherits from ``AnsibleError`` and
``AssertionError``. When possible, utilize a more specific exception
than ``AnsibleAssertionError``.
Modules will not have access to ``AnsibleAssertionError`` and should instead
raise ``AssertionError``, a more specific exception, or just use
``module.fail_json`` at the failure point.

View file

@ -172,6 +172,11 @@ class AnsibleError(Exception):
return error_message return error_message
class AnsibleAssertionError(AnsibleError, AssertionError):
'''Invalid assertion'''
pass
class AnsibleOptionsError(AnsibleError): class AnsibleOptionsError(AnsibleError):
''' bad or incomplete options passed ''' ''' bad or incomplete options passed '''
pass pass

View file

@ -98,7 +98,8 @@ def get_connection(module):
def to_commands(module, commands): def to_commands(module, commands):
assert isinstance(commands, list), 'argument must be of type <list>' if not isinstance(commands, list):
raise AssertionError('argument must be of type <list>')
transform = EntityCollection(module, command_spec) transform = EntityCollection(module, command_spec)
commands = transform(commands) commands = transform(commands)

View file

@ -2248,7 +2248,8 @@ class AnsibleModule(object):
def fail_json(self, **kwargs): def fail_json(self, **kwargs):
''' return from the module, with an error message ''' ''' return from the module, with an error message '''
assert 'msg' in kwargs, "implementation error -- msg to explain the error is required" if 'msg' not in kwargs:
raise AssertionError("implementation error -- msg to explain the error is required")
kwargs['failed'] = True kwargs['failed'] = True
# add traceback if debug or high verbosity and it is missing # add traceback if debug or high verbosity and it is missing

View file

@ -95,7 +95,8 @@ class ConnectionError(Exception):
class Connection: class Connection:
def __init__(self, socket_path): def __init__(self, socket_path):
assert socket_path is not None, 'socket_path must be a value' if socket_path is None:
raise AssertionError('socket_path must be a value')
self.socket_path = socket_path self.socket_path = socket_path
def __getattr__(self, name): def __getattr__(self, name):

View file

@ -115,7 +115,8 @@ def get_config(module, flags=None):
def to_commands(module, commands): def to_commands(module, commands):
assert isinstance(commands, list), 'argument must be of type <list>' if not isinstance(commands, list):
raise AssertionError('argument must be of type <list>')
transform = EntityCollection(module, command_spec) transform = EntityCollection(module, command_spec)
commands = transform(commands) commands = transform(commands)

View file

@ -67,7 +67,8 @@ def get_connection(module):
def to_commands(module, commands): def to_commands(module, commands):
assert isinstance(commands, list), 'argument must be of type <list>' if not isinstance(commands, list):
raise AssertionError('argument must be of type <list>')
transform = EntityCollection(module, command_spec) transform = EntityCollection(module, command_spec)
commands = transform(commands) commands = transform(commands)

View file

@ -97,7 +97,8 @@ class ConfigLine(object):
return len(self._parents) > 0 return len(self._parents) > 0
def add_child(self, obj): def add_child(self, obj):
assert isinstance(obj, ConfigLine), 'child must be of type `ConfigLine`' if not isinstance(obj, ConfigLine):
raise AssertionError('child must be of type `ConfigLine`')
self._children.append(obj) self._children.append(obj)
@ -263,7 +264,8 @@ class NetworkConfig(object):
return item return item
def get_block(self, path): def get_block(self, path):
assert isinstance(path, list), 'path argument must be a list object' if not isinstance(path, list):
raise AssertionError('path argument must be a list object')
obj = self.get_object(path) obj = self.get_object(path)
if not obj: if not obj:
raise ValueError('path does not exist in config') raise ValueError('path does not exist in config')

View file

@ -222,8 +222,10 @@ def dict_diff(base, comparable):
:returns: new dict object with differences :returns: new dict object with differences
""" """
assert isinstance(base, dict), "`base` must be of type <dict>" if not isinstance(base, dict):
assert isinstance(comparable, dict), "`comparable` must be of type <dict>" raise AssertionError("`base` must be of type <dict>")
if not isinstance(comparable, dict):
raise AssertionError("`comparable` must be of type <dict>")
updates = dict() updates = dict()
@ -257,8 +259,10 @@ def dict_merge(base, other):
:returns: new combined dict object :returns: new combined dict object
""" """
assert isinstance(base, dict), "`base` must be of type <dict>" if not isinstance(base, dict):
assert isinstance(other, dict), "`other` must be of type <dict>" raise AssertionError("`base` must be of type <dict>")
if not isinstance(other, dict):
raise AssertionError("`other` must be of type <dict>")
combined = dict() combined = dict()
@ -306,7 +310,8 @@ def conditional(expr, val, cast=None):
op, arg = match.groups() op, arg = match.groups()
else: else:
op = 'eq' op = 'eq'
assert (' ' not in str(expr)), 'invalid expression: cannot contain spaces' if ' ' in str(expr):
raise AssertionError('invalid expression: cannot contain spaces')
arg = expr arg = expr
if cast is None and val is not None: if cast is None and val is not None:

View file

@ -273,7 +273,8 @@ def umc_module_for_edit(module, object_dn, superordinate=None):
def create_containers_and_parents(container_dn): def create_containers_and_parents(container_dn):
"""Create a container and if needed the parents containers""" """Create a container and if needed the parents containers"""
import univention.admin.uexceptions as uexcp import univention.admin.uexceptions as uexcp
assert container_dn.startswith("cn=") if not container_dn.startswith("cn="):
raise AssertionError()
try: try:
parent = ldap_dn_tree_parent(container_dn) parent = ldap_dn_tree_parent(container_dn)
obj = umc_module_for_add( obj = umc_module_for_add(

View file

@ -285,7 +285,8 @@ def check_dp_status(client, dp_id, status):
:returns: True or False :returns: True or False
""" """
assert isinstance(status, list) if not isinstance(status, list):
raise AssertionError()
if pipeline_field(client, dp_id, field="@pipelineState") in status: if pipeline_field(client, dp_id, field="@pipelineState") in status:
return True return True
else: else:

View file

@ -380,7 +380,8 @@ class ClcGroup(object):
changed: Boolean- whether a change was made, changed: Boolean- whether a change was made,
group: A clc group object for the group group: A clc group object for the group
""" """
assert self.root_group, "Implementation Error: Root Group not set" if not self.root_group:
raise AssertionError("Implementation Error: Root Group not set")
parent = parent_name if parent_name is not None else self.root_group.name parent = parent_name if parent_name is not None else self.root_group.name
description = group_description description = group_description
changed = False changed = False

View file

@ -237,7 +237,8 @@ class Droplet(JsonfyMixIn):
self.update_attr(json) self.update_attr(json)
def power_on(self): def power_on(self):
assert self.status == 'off', 'Can only power on a closed one.' if self.status != 'off':
raise AssertionError('Can only power on a closed one.')
json = self.manager.power_on_droplet(self.id) json = self.manager.power_on_droplet(self.id)
self.update_attr(json) self.update_attr(json)

View file

@ -424,8 +424,10 @@ class PyVmomiDeviceHelper(object):
diskspec.device.backing.diskMode = 'persistent' diskspec.device.backing.diskMode = 'persistent'
diskspec.device.controllerKey = scsi_ctl.device.key diskspec.device.controllerKey = scsi_ctl.device.key
assert self.next_disk_unit_number != 7 if self.next_disk_unit_number == 7:
assert disk_index != 7 raise AssertionError()
if disk_index == 7:
raise AssertionError()
""" """
Configure disk unit number. Configure disk unit number.
""" """
@ -1127,7 +1129,8 @@ class PyVmomiHelper(PyVmomi):
return datastore, datastore_name return datastore, datastore_name
def obj_has_parent(self, obj, parent): def obj_has_parent(self, obj, parent):
assert obj is not None and parent is not None if obj is None and parent is None:
raise AssertionError()
current_parent = obj current_parent = obj
while True: while True:
@ -1573,7 +1576,7 @@ def main():
result["failed"] = False result["failed"] = False
else: else:
# This should not happen # This should not happen
assert False raise AssertionError()
# VM doesn't exist # VM doesn't exist
else: else:
if module.params['state'] in ['poweredon', 'poweredoff', 'present', 'restarted', 'suspended']: if module.params['state'] in ['poweredon', 'poweredoff', 'present', 'restarted', 'suspended']:

View file

@ -342,7 +342,7 @@ class PyVmomiHelper(object):
task = vm.RemoveAllSnapshots() task = vm.RemoveAllSnapshots()
else: else:
# This should not happen # This should not happen
assert False raise AssertionError()
if task: if task:
self.wait_for_task(task) self.wait_for_task(task)

View file

@ -236,8 +236,9 @@ def set_acl(consul_client, configuration):
acls_as_json = decode_acls_as_json(consul_client.acl.list()) acls_as_json = decode_acls_as_json(consul_client.acl.list())
existing_acls_mapped_by_name = dict((acl.name, acl) for acl in acls_as_json if acl.name is not None) existing_acls_mapped_by_name = dict((acl.name, acl) for acl in acls_as_json if acl.name is not None)
existing_acls_mapped_by_token = dict((acl.token, acl) for acl in acls_as_json) existing_acls_mapped_by_token = dict((acl.token, acl) for acl in acls_as_json)
assert None not in existing_acls_mapped_by_token, "expecting ACL list to be associated to a token: %s" \ if None in existing_acls_mapped_by_token:
% existing_acls_mapped_by_token[None] raise AssertionError("expecting ACL list to be associated to a token: %s" %
existing_acls_mapped_by_token[None])
if configuration.token is None and configuration.name and configuration.name in existing_acls_mapped_by_name: if configuration.token is None and configuration.name and configuration.name in existing_acls_mapped_by_name:
# No token but name given so can get token from name # No token but name given so can get token from name
@ -246,8 +247,10 @@ def set_acl(consul_client, configuration):
if configuration.token and configuration.token in existing_acls_mapped_by_token: if configuration.token and configuration.token in existing_acls_mapped_by_token:
return update_acl(consul_client, configuration) return update_acl(consul_client, configuration)
else: else:
assert configuration.token not in existing_acls_mapped_by_token if configuration.token in existing_acls_mapped_by_token:
assert configuration.name not in existing_acls_mapped_by_name raise AssertionError()
if configuration.name in existing_acls_mapped_by_name:
raise AssertionError()
return create_acl(consul_client, configuration) return create_acl(consul_client, configuration)
@ -266,7 +269,8 @@ def update_acl(consul_client, configuration):
rules_as_hcl = encode_rules_as_hcl_string(configuration.rules) rules_as_hcl = encode_rules_as_hcl_string(configuration.rules)
updated_token = consul_client.acl.update( updated_token = consul_client.acl.update(
configuration.token, name=name, type=configuration.token_type, rules=rules_as_hcl) configuration.token, name=name, type=configuration.token_type, rules=rules_as_hcl)
assert updated_token == configuration.token if updated_token != configuration.token:
raise AssertionError()
return Output(changed=changed, token=configuration.token, rules=configuration.rules, operation=UPDATE_OPERATION) return Output(changed=changed, token=configuration.token, rules=configuration.rules, operation=UPDATE_OPERATION)
@ -379,12 +383,14 @@ def encode_rules_as_json(rules):
rules_as_json = defaultdict(dict) rules_as_json = defaultdict(dict)
for rule in rules: for rule in rules:
if rule.pattern is not None: if rule.pattern is not None:
assert rule.pattern not in rules_as_json[rule.scope] if rule.pattern in rules_as_json[rule.scope]:
raise AssertionError()
rules_as_json[rule.scope][rule.pattern] = { rules_as_json[rule.scope][rule.pattern] = {
_POLICY_JSON_PROPERTY: rule.policy _POLICY_JSON_PROPERTY: rule.policy
} }
else: else:
assert rule.scope not in rules_as_json if rule.scope in rules_as_json:
raise AssertionError()
rules_as_json[rule.scope] = rule.policy rules_as_json[rule.scope] = rule.policy
return rules_as_json return rules_as_json
@ -577,7 +583,8 @@ def get_consul_client(configuration):
token = configuration.management_token token = configuration.management_token
if token is None: if token is None:
token = configuration.token token = configuration.token
assert token is not None, "Expecting the management token to always be set" if token is None:
raise AssertionError("Expecting the management token to always be set")
return consul.Consul(host=configuration.host, port=configuration.port, scheme=configuration.scheme, return consul.Consul(host=configuration.host, port=configuration.port, scheme=configuration.scheme,
verify=configuration.validate_certs, token=token) verify=configuration.validate_certs, token=token)

View file

@ -225,16 +225,14 @@ except:
# Optional, only used for XML payload # Optional, only used for XML payload
try: try:
import lxml.etree import lxml.etree # noqa
assert lxml.etree # silence pyflakes
HAS_LXML_ETREE = True HAS_LXML_ETREE = True
except ImportError: except ImportError:
HAS_LXML_ETREE = False HAS_LXML_ETREE = False
# Optional, only used for XML payload # Optional, only used for XML payload
try: try:
from xmljson import cobra from xmljson import cobra # noqa
assert cobra # silence pyflakes
HAS_XMLJSON_COBRA = True HAS_XMLJSON_COBRA = True
except ImportError: except ImportError:
HAS_XMLJSON_COBRA = False HAS_XMLJSON_COBRA = False

View file

@ -249,9 +249,7 @@ class BalancerMember(object):
balancer_member_page = fetch_url(self.module, self.management_url) balancer_member_page = fetch_url(self.module, self.management_url)
try: if balancer_member_page[1]['status'] != 200:
assert balancer_member_page[1]['status'] == 200
except AssertionError:
self.module.fail_json(msg="Could not get balancer_member_page, check for connectivity! " + balancer_member_page[1]) self.module.fail_json(msg="Could not get balancer_member_page, check for connectivity! " + balancer_member_page[1])
else: else:
try: try:
@ -296,9 +294,7 @@ class BalancerMember(object):
request_body = request_body + str(values_mapping[k]) + '=0' request_body = request_body + str(values_mapping[k]) + '=0'
response = fetch_url(self.module, self.management_url, data=str(request_body)) response = fetch_url(self.module, self.management_url, data=str(request_body))
try: if response[1]['status'] != 200:
assert response[1]['status'] == 200
except AssertionError:
self.module.fail_json(msg="Could not set the member status! " + self.host + " " + response[1]['status']) self.module.fail_json(msg="Could not set the member status! " + self.host + " " + response[1]['status'])
attributes = property(get_member_attributes) attributes = property(get_member_attributes)
@ -323,9 +319,7 @@ class Balancer(object):
def fetch_balancer_page(self): def fetch_balancer_page(self):
""" Returns the balancer management html page as a string for later parsing.""" """ Returns the balancer management html page as a string for later parsing."""
page = fetch_url(self.module, str(self.url)) page = fetch_url(self.module, str(self.url))
try: if page[1]['status'] != 200:
assert page[1]['status'] == 200
except AssertionError:
self.module.fail_json(msg="Could not get balancer page! HTTP status response: " + str(page[1]['status'])) self.module.fail_json(msg="Could not get balancer page! HTTP status response: " + str(page[1]['status']))
else: else:
content = page[0].read() content = page[0].read()
@ -343,9 +337,7 @@ class Balancer(object):
else: else:
for element in soup.findAll('a')[1::1]: for element in soup.findAll('a')[1::1]:
balancer_member_suffix = str(element.get('href')) balancer_member_suffix = str(element.get('href'))
try: if not balancer_member_suffix:
assert balancer_member_suffix is not ''
except AssertionError:
self.module.fail_json(msg="Argument 'balancer_member_suffix' is empty!") self.module.fail_json(msg="Argument 'balancer_member_suffix' is empty!")
else: else:
yield BalancerMember(str(self.base_url + balancer_member_suffix), str(self.url), self.module) yield BalancerMember(str(self.base_url + balancer_member_suffix), str(self.url), self.module)

View file

@ -19,7 +19,7 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from ansible.errors import AnsibleParserError, AnsibleError from ansible.errors import AnsibleParserError, AnsibleError, AnsibleAssertionError
from ansible.module_utils.six import iteritems, string_types from ansible.module_utils.six import iteritems, string_types
from ansible.module_utils._text import to_text from ansible.module_utils._text import to_text
from ansible.parsing.splitter import parse_kv, split_args from ansible.parsing.splitter import parse_kv, split_args
@ -98,7 +98,8 @@ class ModuleArgsParser:
def __init__(self, task_ds=None): def __init__(self, task_ds=None):
task_ds = {} if task_ds is None else task_ds task_ds = {} if task_ds is None else task_ds
assert isinstance(task_ds, dict), "the type of 'task_ds' should be a dict, but is a %s" % type(task_ds) if not isinstance(task_ds, dict):
raise AnsibleAssertionError("the type of 'task_ds' should be a dict, but is a %s" % type(task_ds))
self._task_ds = task_ds self._task_ds = task_ds
def _split_module_string(self, module_string): def _split_module_string(self, module_string):

View file

@ -72,7 +72,7 @@ try:
except ImportError: except ImportError:
pass pass
from ansible.errors import AnsibleError from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible import constants as C from ansible import constants as C
from ansible.module_utils.six import PY3, binary_type from ansible.module_utils.six import PY3, binary_type
# Note: on py2, this zip is izip not the list based zip() builtin # Note: on py2, this zip is izip not the list based zip() builtin
@ -787,7 +787,10 @@ class VaultEditor:
fh.write(data) fh.write(data)
fh.write(data[:file_len % chunk_len]) fh.write(data[:file_len % chunk_len])
assert fh.tell() == file_len # FIXME remove this assert once we have unittests to check its accuracy # FIXME remove this assert once we have unittests to check its accuracy
if fh.tell() != file_len:
raise AnsibleAssertionError()
os.fsync(fh) os.fsync(fh)
def _shred_file(self, tmp_path): def _shred_file(self, tmp_path):

View file

@ -16,7 +16,7 @@ from jinja2.exceptions import UndefinedError
from ansible import constants as C from ansible import constants as C
from ansible.module_utils.six import iteritems, string_types, with_metaclass from ansible.module_utils.six import iteritems, string_types, with_metaclass
from ansible.module_utils.parsing.convert_bool import boolean from ansible.module_utils.parsing.convert_bool import boolean
from ansible.errors import AnsibleParserError, AnsibleUndefinedVariable from ansible.errors import AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError
from ansible.module_utils._text import to_text, to_native from ansible.module_utils._text import to_text, to_native
from ansible.playbook.attribute import Attribute, FieldAttribute from ansible.playbook.attribute import Attribute, FieldAttribute
from ansible.parsing.dataloader import DataLoader from ansible.parsing.dataloader import DataLoader
@ -209,7 +209,8 @@ class Base(with_metaclass(BaseMeta, object)):
def load_data(self, ds, variable_manager=None, loader=None): def load_data(self, ds, variable_manager=None, loader=None):
''' walk the input datastructure and assign any values ''' ''' walk the input datastructure and assign any values '''
assert ds is not None, 'ds (%s) should not be None but it is.' % ds if ds is None:
raise AnsibleAssertionError('ds (%s) should not be None but it is.' % ds)
# cache the datastructure internally # cache the datastructure internally
setattr(self, '_ds', ds) setattr(self, '_ds', ds)
@ -547,7 +548,8 @@ class Base(with_metaclass(BaseMeta, object)):
and extended. and extended.
''' '''
assert isinstance(data, dict), 'data (%s) should be a dict but is a %s' % (data, type(data)) if not isinstance(data, dict):
raise AnsibleAssertionError('data (%s) should be a dict but is a %s' % (data, type(data)))
for (name, attribute) in iteritems(self._valid_attrs): for (name, attribute) in iteritems(self._valid_attrs):
if name in data: if name in data:

View file

@ -21,7 +21,7 @@ __metaclass__ = type
import os import os
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleParserError, AnsibleUndefinedVariable, AnsibleFileNotFound from ansible.errors import AnsibleParserError, AnsibleUndefinedVariable, AnsibleFileNotFound, AnsibleAssertionError
from ansible.module_utils.six import string_types from ansible.module_utils.six import string_types
try: try:
@ -43,7 +43,8 @@ def load_list_of_blocks(ds, play, parent_block=None, role=None, task_include=Non
from ansible.playbook.task_include import TaskInclude from ansible.playbook.task_include import TaskInclude
from ansible.playbook.role_include import IncludeRole from ansible.playbook.role_include import IncludeRole
assert isinstance(ds, (list, type(None))), '%s should be a list or None but is %s' % (ds, type(ds)) if not isinstance(ds, (list, type(None))):
raise AnsibleAssertionError('%s should be a list or None but is %s' % (ds, type(ds)))
block_list = [] block_list = []
if ds: if ds:
@ -89,11 +90,13 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h
from ansible.playbook.handler_task_include import HandlerTaskInclude from ansible.playbook.handler_task_include import HandlerTaskInclude
from ansible.template import Templar from ansible.template import Templar
assert isinstance(ds, list), 'The ds (%s) should be a list but was a %s' % (ds, type(ds)) if not isinstance(ds, list):
raise AnsibleAssertionError('The ds (%s) should be a list but was a %s' % (ds, type(ds)))
task_list = [] task_list = []
for task_ds in ds: for task_ds in ds:
assert isinstance(task_ds, dict), 'The ds (%s) should be a dict but was a %s' % (ds, type(ds)) if not isinstance(task_ds, dict):
AnsibleAssertionError('The ds (%s) should be a dict but was a %s' % (ds, type(ds)))
if 'block' in task_ds: if 'block' in task_ds:
t = Block.load( t = Block.load(
@ -345,7 +348,8 @@ def load_list_of_roles(ds, play, current_role_path=None, variable_manager=None,
# we import here to prevent a circular dependency with imports # we import here to prevent a circular dependency with imports
from ansible.playbook.role.include import RoleInclude from ansible.playbook.role.include import RoleInclude
assert isinstance(ds, list), 'ds (%s) should be a list but was a %s' % (ds, type(ds)) if not isinstance(ds, list):
raise AnsibleAssertionError('ds (%s) should be a list but was a %s' % (ds, type(ds)))
roles = [] roles = []
for role_def in ds: for role_def in ds:

View file

@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError, AnsibleAssertionError
from ansible.module_utils.six import string_types from ansible.module_utils.six import string_types
from ansible.playbook.attribute import FieldAttribute from ansible.playbook.attribute import FieldAttribute
from ansible.playbook.base import Base from ansible.playbook.base import Base
@ -116,7 +116,8 @@ class Play(Base, Taggable, Become):
Adjusts play datastructure to cleanup old/legacy items Adjusts play datastructure to cleanup old/legacy items
''' '''
assert isinstance(ds, dict), 'while preprocessing data (%s), ds should be a dict but was a %s' % (ds, type(ds)) if not isinstance(ds, dict):
raise AnsibleAssertionError('while preprocessing data (%s), ds should be a dict but was a %s' % (ds, type(ds)))
# The use of 'user' in the Play datastructure was deprecated to # The use of 'user' in the Play datastructure was deprecated to
# line up with the same change for Tasks, due to the fact that # line up with the same change for Tasks, due to the fact that

View file

@ -21,7 +21,7 @@ __metaclass__ = type
import os import os
from ansible.errors import AnsibleParserError, AnsibleError from ansible.errors import AnsibleParserError, AnsibleError, AnsibleAssertionError
from ansible.module_utils.six import iteritems from ansible.module_utils.six import iteritems
from ansible.parsing.splitter import split_args, parse_kv from ansible.parsing.splitter import split_args, parse_kv
from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping
@ -105,7 +105,8 @@ class PlaybookInclude(Base, Conditional, Taggable):
up with what we expect the proper attributes to be up with what we expect the proper attributes to be
''' '''
assert isinstance(ds, dict), 'ds (%s) should be a dict but was a %s' % (ds, type(ds)) if not isinstance(ds, dict):
raise AnsibleAssertionError('ds (%s) should be a dict but was a %s' % (ds, type(ds)))
# the new, cleaned datastructure, which will have legacy # the new, cleaned datastructure, which will have legacy
# items reduced to a standard structure # items reduced to a standard structure

View file

@ -22,7 +22,7 @@ __metaclass__ = type
import collections import collections
import os import os
from ansible.errors import AnsibleError, AnsibleParserError from ansible.errors import AnsibleError, AnsibleParserError, AnsibleAssertionError
from ansible.module_utils.six import iteritems, binary_type, text_type from ansible.module_utils.six import iteritems, binary_type, text_type
from ansible.playbook.attribute import FieldAttribute from ansible.playbook.attribute import FieldAttribute
from ansible.playbook.base import Base from ansible.playbook.base import Base
@ -293,7 +293,8 @@ class Role(Base, Become, Conditional, Taggable):
def add_parent(self, parent_role): def add_parent(self, parent_role):
''' adds a role to the list of this roles parents ''' ''' adds a role to the list of this roles parents '''
assert isinstance(parent_role, Role) if not isinstance(parent_role, Role):
raise AnsibleAssertionError()
if parent_role not in self._parents: if parent_role not in self._parents:
self._parents.append(parent_role) self._parents.append(parent_role)

View file

@ -22,7 +22,7 @@ __metaclass__ = type
import os import os
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.module_utils.six import iteritems, string_types from ansible.module_utils.six import iteritems, string_types
from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping
from ansible.playbook.attribute import Attribute, FieldAttribute from ansible.playbook.attribute import Attribute, FieldAttribute
@ -72,7 +72,8 @@ class RoleDefinition(Base, Become, Conditional, Taggable):
if isinstance(ds, int): if isinstance(ds, int):
ds = "%s" % ds ds = "%s" % ds
assert isinstance(ds, dict) or isinstance(ds, string_types) or isinstance(ds, AnsibleBaseYAMLObject) if not isinstance(ds, dict) and not isinstance(ds, string_types) and not isinstance(ds, AnsibleBaseYAMLObject):
raise AnsibleAssertionError()
if isinstance(ds, dict): if isinstance(ds, dict):
ds = super(RoleDefinition, self).preprocess_data(ds) ds = super(RoleDefinition, self).preprocess_data(ds)

View file

@ -22,7 +22,7 @@ __metaclass__ = type
import os import os
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError
from ansible.module_utils.six import iteritems, string_types from ansible.module_utils.six import iteritems, string_types
from ansible.module_utils._text import to_native from ansible.module_utils._text import to_native
from ansible.parsing.mod_args import ModuleArgsParser from ansible.parsing.mod_args import ModuleArgsParser
@ -167,7 +167,8 @@ class Task(Base, Conditional, Taggable, Become):
keep it short. keep it short.
''' '''
assert isinstance(ds, dict), 'ds (%s) should be a dict but was a %s' % (ds, type(ds)) if not isinstance(ds, dict):
raise AnsibleAssertionError('ds (%s) should be a dict but was a %s' % (ds, type(ds)))
# the new, cleaned datastructure, which will have legacy # the new, cleaned datastructure, which will have legacy
# items reduced to a standard structure suitable for the # items reduced to a standard structure suitable for the

View file

@ -64,7 +64,7 @@ RETURN = """
import os import os
import sys import sys
from ansible.module_utils.six.moves.urllib.parse import urlparse from ansible.module_utils.six.moves.urllib.parse import urlparse
from ansible.errors import AnsibleError from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.plugins.lookup import LookupBase from ansible.plugins.lookup import LookupBase
try: try:
@ -131,7 +131,8 @@ class LookupModule(LookupBase):
for param in params[1:]: for param in params[1:]:
if param and len(param) > 0: if param and len(param) > 0:
name, value = param.split('=') name, value = param.split('=')
assert name in paramvals, "%s not a valid consul lookup parameter" % name if name not in paramvals:
raise AnsibleAssertionError("%s not a valid consul lookup parameter" % name)
paramvals[name] = value paramvals[name] = value
except (ValueError, AssertionError) as e: except (ValueError, AssertionError) as e:
raise AnsibleError(e) raise AnsibleError(e)

View file

@ -51,7 +51,7 @@ import codecs
import csv import csv
from collections import MutableSequence from collections import MutableSequence
from ansible.errors import AnsibleError from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.plugins.lookup import LookupBase from ansible.plugins.lookup import LookupBase
from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils._text import to_bytes, to_native, to_text
@ -124,7 +124,8 @@ class LookupModule(LookupBase):
try: try:
for param in params[1:]: for param in params[1:]:
name, value = param.split('=') name, value = param.split('=')
assert(name in paramvals) if name not in paramvals:
raise AnsibleAssertionError('%s not in paramvals' % name)
paramvals[name] = value paramvals[name] = value
except (ValueError, AssertionError) as e: except (ValueError, AssertionError) as e:
raise AnsibleError(e) raise AnsibleError(e)

View file

@ -65,7 +65,7 @@ import re
from collections import MutableSequence from collections import MutableSequence
from io import StringIO from io import StringIO
from ansible.errors import AnsibleError from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.module_utils.six.moves import configparser from ansible.module_utils.six.moves import configparser
from ansible.module_utils._text import to_bytes, to_text from ansible.module_utils._text import to_bytes, to_text
from ansible.plugins.lookup import LookupBase from ansible.plugins.lookup import LookupBase
@ -129,7 +129,8 @@ class LookupModule(LookupBase):
try: try:
for param in params[1:]: for param in params[1:]:
name, value = param.split('=') name, value = param.split('=')
assert(name in paramvals) if name not in paramvals:
raise AnsibleAssertionError('%s not in paramvals' % name)
paramvals[name] = value paramvals[name] = value
except (ValueError, AssertionError) as e: except (ValueError, AssertionError) as e:
raise AnsibleError(e) raise AnsibleError(e)

View file

@ -92,7 +92,7 @@ _raw:
import os import os
import string import string
from ansible.errors import AnsibleError from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.parsing.splitter import parse_kv from ansible.parsing.splitter import parse_kv
from ansible.plugins.lookup import LookupBase from ansible.plugins.lookup import LookupBase
@ -250,7 +250,8 @@ def _format_content(password, salt, encrypt=True):
return password return password
# At this point, the calling code should have assured us that there is a salt value. # At this point, the calling code should have assured us that there is a salt value.
assert salt, '_format_content was called with encryption requested but no salt value' if not salt:
raise AnsibleAssertionError('_format_content was called with encryption requested but no salt value')
return u'%s salt=%s' % (password, salt) return u'%s salt=%s' % (password, salt)

View file

@ -82,7 +82,7 @@ import subprocess
import time import time
from distutils import util from distutils import util
from ansible.errors import AnsibleError from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.utils.encrypt import random_password from ansible.utils.encrypt import random_password
from ansible.plugins.lookup import LookupBase from ansible.plugins.lookup import LookupBase
@ -138,7 +138,8 @@ class LookupModule(LookupBase):
try: try:
for param in params[1:]: for param in params[1:]:
name, value = param.split('=') name, value = param.split('=')
assert(name in self.paramvals) if name not in self.paramvals:
raise AnsibleAssertionError('%s not in paramvals' % name)
self.paramvals[name] = value self.paramvals[name] = value
except (ValueError, AssertionError) as e: except (ValueError, AssertionError) as e:
raise AnsibleError(e) raise AnsibleError(e)

View file

@ -33,7 +33,7 @@ _list:
""" """
import shelve import shelve
from ansible.errors import AnsibleError from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.plugins.lookup import LookupBase from ansible.plugins.lookup import LookupBase
from ansible.module_utils._text import to_bytes, to_text from ansible.module_utils._text import to_bytes, to_text
@ -63,7 +63,8 @@ class LookupModule(LookupBase):
try: try:
for param in params: for param in params:
name, value = param.split('=') name, value = param.split('=')
assert(name in paramvals) if name not in paramvals:
raise AnsibleAssertionError('%s not in paramvals' % name)
paramvals[name] = value paramvals[name] = value
except (ValueError, AssertionError) as e: except (ValueError, AssertionError) as e:

View file

@ -42,7 +42,7 @@ from jinja2.runtime import Context, StrictUndefined
from jinja2.utils import concat as j2_concat from jinja2.utils import concat as j2_concat
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleFilterError, AnsibleUndefinedVariable from ansible.errors import AnsibleError, AnsibleFilterError, AnsibleUndefinedVariable, AnsibleAssertionError
from ansible.module_utils.six import string_types, text_type from ansible.module_utils.six import string_types, text_type
from ansible.module_utils._text import to_native, to_text, to_bytes from ansible.module_utils._text import to_native, to_text, to_bytes
from ansible.plugins.loader import filter_loader, lookup_loader, test_loader from ansible.plugins.loader import filter_loader, lookup_loader, test_loader
@ -387,7 +387,8 @@ class Templar:
are being changed. are being changed.
''' '''
assert isinstance(variables, dict), "the type of 'variables' should be a dict but was a %s" % (type(variables)) if not isinstance(variables, dict):
raise AnsibleAssertionError("the type of 'variables' should be a dict but was a %s" % (type(variables)))
self._available_variables = variables self._available_variables = variables
self._cached_result = {} self._cached_result = {}

View file

@ -8,7 +8,7 @@ import multiprocessing
import random import random
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.module_utils.six import text_type from ansible.module_utils.six import text_type
from ansible.module_utils._text import to_text, to_bytes from ansible.module_utils._text import to_text, to_bytes
@ -67,7 +67,8 @@ def random_password(length=DEFAULT_PASSWORD_LENGTH, chars=C.DEFAULT_PASSWORD_CHA
:kwarg chars: The characters to choose from. The default is all ascii :kwarg chars: The characters to choose from. The default is all ascii
letters, ascii digits, and these symbols ``.,:-_`` letters, ascii digits, and these symbols ``.,:-_``
''' '''
assert isinstance(chars, text_type), '%s (%s) is not a text_type' % (chars, type(chars)) if not isinstance(chars, text_type):
raise AnsibleAssertionError('%s (%s) is not a text_type' % (chars, type(chars)))
random_generator = random.SystemRandom() random_generator = random.SystemRandom()
return u''.join(random_generator.choice(chars) for dummy in range(length)) return u''.join(random_generator.choice(chars) for dummy in range(length))

View file

@ -22,6 +22,7 @@ __metaclass__ = type
from collections import MutableMapping, MutableSet, MutableSequence from collections import MutableMapping, MutableSet, MutableSequence
from ansible.errors import AnsibleAssertionError
from ansible.module_utils.six import string_types from ansible.module_utils.six import string_types
from ansible.parsing.plugin_docs import read_docstring from ansible.parsing.plugin_docs import read_docstring
from ansible.parsing.yaml.loader import AnsibleLoader from ansible.parsing.yaml.loader import AnsibleLoader
@ -59,7 +60,8 @@ def add_fragments(doc, filename):
fragment_name, fragment_var = fragment_slug, 'DOCUMENTATION' fragment_name, fragment_var = fragment_slug, 'DOCUMENTATION'
fragment_class = fragment_loader.get(fragment_name) fragment_class = fragment_loader.get(fragment_name)
assert fragment_class is not None if fragment_class is None:
raise AnsibleAssertionError('fragment_class is None')
fragment_yaml = getattr(fragment_class, fragment_var, '{}') fragment_yaml = getattr(fragment_class, fragment_var, '{}')
fragment = AnsibleLoader(fragment_yaml, file_name=filename).get_single_data() fragment = AnsibleLoader(fragment_yaml, file_name=filename).get_single_data()

View file

@ -32,7 +32,7 @@ except ImportError:
from jinja2.exceptions import UndefinedError from jinja2.exceptions import UndefinedError
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleFileNotFound from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleFileNotFound, AnsibleAssertionError
from ansible.inventory.host import Host from ansible.inventory.host import Host
from ansible.inventory.helpers import sort_groups, get_group_vars from ansible.inventory.helpers import sort_groups, get_group_vars
from ansible.module_utils._text import to_native from ansible.module_utils._text import to_native
@ -132,7 +132,8 @@ class VariableManager:
@extra_vars.setter @extra_vars.setter
def extra_vars(self, value): def extra_vars(self, value):
''' ensures a clean copy of the extra_vars are used to set the value ''' ''' ensures a clean copy of the extra_vars are used to set the value '''
assert isinstance(value, MutableMapping), "the type of 'value' for extra_vars should be a MutableMapping, but is a %s" % type(value) if not isinstance(value, MutableMapping):
raise AnsibleAssertionError("the type of 'value' for extra_vars should be a MutableMapping, but is a %s" % type(value))
self._extra_vars = value.copy() self._extra_vars = value.copy()
def set_inventory(self, inventory): def set_inventory(self, inventory):
@ -146,7 +147,8 @@ class VariableManager:
@options_vars.setter @options_vars.setter
def options_vars(self, value): def options_vars(self, value):
''' ensures a clean copy of the options_vars are used to set the value ''' ''' ensures a clean copy of the options_vars are used to set the value '''
assert isinstance(value, dict), "the type of 'value' for options_vars should be a dict, but is a %s" % type(value) if not isinstance(value, dict):
raise AnsibleAssertionError("the type of 'value' for options_vars should be a dict, but is a %s" % type(value))
self._options_vars = value.copy() self._options_vars = value.copy()
def _preprocess_vars(self, a): def _preprocess_vars(self, a):
@ -592,7 +594,8 @@ class VariableManager:
Sets or updates the given facts for a host in the fact cache. Sets or updates the given facts for a host in the fact cache.
''' '''
assert isinstance(facts, dict), "the type of 'facts' to set for host_facts should be a dict but is a %s" % type(facts) if not isinstance(facts, dict):
raise AnsibleAssertionError("the type of 'facts' to set for host_facts should be a dict but is a %s" % type(facts))
if host.name not in self._fact_cache: if host.name not in self._fact_cache:
self._fact_cache[host.name] = facts self._fact_cache[host.name] = facts
@ -607,7 +610,8 @@ class VariableManager:
Sets or updates the given facts for a host in the fact cache. Sets or updates the given facts for a host in the fact cache.
''' '''
assert isinstance(facts, dict), "the type of 'facts' to set for nonpersistent_facts should be a dict but is a %s" % type(facts) if not isinstance(facts, dict):
raise AnsibleAssertionError("the type of 'facts' to set for nonpersistent_facts should be a dict but is a %s" % type(facts))
if host.name not in self._nonpersistent_fact_cache: if host.name not in self._nonpersistent_fact_cache:
self._nonpersistent_fact_cache[host.name] = facts self._nonpersistent_fact_cache[host.name] = facts

View file

@ -0,0 +1,40 @@
#!/usr/bin/env python
from __future__ import print_function
import os
import re
import sys
from collections import defaultdict
PATH = 'lib/ansible'
ASSERT_RE = re.compile(r'.*(?<![-:a-zA-Z#][ -])\bassert\b(?!:).*')
all_matches = defaultdict(list)
for dirpath, dirnames, filenames in os.walk(PATH):
for filename in filenames:
path = os.path.join(dirpath, filename)
if not os.path.isfile(path) or not path.endswith('.py'):
continue
with open(path, 'r') as f:
for i, line in enumerate(f.readlines()):
matches = ASSERT_RE.findall(line)
if matches:
all_matches[path].append((i + 1, line.index('assert') + 1, matches))
if all_matches:
print('Use of assert in production code is not recommended.')
print('Python will remove all assert statements if run with optimizations')
print('Alternatives:')
print(' if not isinstance(value, dict):')
print(' raise AssertionError("Expected a dict for value")')
for path, matches in all_matches.items():
for line_matches in matches:
for match in line_matches[2]:
print('%s:%d:%d: %s' % ((path,) + line_matches[:2] + (match,)))
sys.exit(1)