Migrate most uses of if type() to if isinstance()

Also convert those checks to use abcs instead of dict and list.

Make a sentinel class for strategies to report when they've reache the end
This commit is contained in:
Toshio Kuratomi 2017-03-26 09:24:30 -07:00
parent 64fe7402ff
commit 6bad4e57bd
12 changed files with 49 additions and 34 deletions

View file

@ -1378,7 +1378,7 @@ class Ec2Inventory(object):
elif key == 'ec2__previous_state': elif key == 'ec2__previous_state':
instance_vars['ec2_previous_state'] = instance.previous_state or '' instance_vars['ec2_previous_state'] = instance.previous_state or ''
instance_vars['ec2_previous_state_code'] = instance.previous_state_code instance_vars['ec2_previous_state_code'] = instance.previous_state_code
elif type(value) in [int, bool]: elif isinstance(value, (int, bool)):
instance_vars[key] = value instance_vars[key] = value
elif isinstance(value, six.string_types): elif isinstance(value, six.string_types):
instance_vars[key] = value.strip() instance_vars[key] = value.strip()
@ -1483,7 +1483,7 @@ class Ec2Inventory(object):
# Target: Everything # Target: Everything
# Preserve booleans and integers # Preserve booleans and integers
elif type(value) in [int, bool]: elif isinstance(value, (int, bool)):
host_info[key] = value host_info[key] = value
# Target: Everything # Target: Everything

View file

@ -385,7 +385,7 @@ class PacketInventory(object):
device_vars[key] = device.state or '' device_vars[key] = device.state or ''
elif key == 'packet_hostname': elif key == 'packet_hostname':
device_vars[key] = value device_vars[key] = value
elif type(value) in [int, bool]: elif isinstance(value, (int, bool)):
device_vars[key] = value device_vars[key] = value
elif isinstance(value, six.string_types): elif isinstance(value, six.string_types):
device_vars[key] = value.strip() device_vars[key] = value.strip()

View file

@ -43,13 +43,16 @@
import argparse import argparse
import os.path import os.path
import sys import sys
import paramiko from collections import MutableSequence
try: try:
import json import json
except ImportError: except ImportError:
import simplejson as json import simplejson as json
import paramiko
SSH_CONF = '~/.ssh/config' SSH_CONF = '~/.ssh/config'
_key = 'ssh_config' _key = 'ssh_config'
@ -68,7 +71,7 @@ def get_config():
cfg.parse(f) cfg.parse(f)
ret_dict = {} ret_dict = {}
for d in cfg._config: for d in cfg._config:
if type(d['host']) is list: if isinstance(d['host'], MutableSequence):
alias = d['host'][0] alias = d['host'][0]
else: else:
alias = d['host'] alias = d['host']
@ -93,7 +96,7 @@ def print_list():
# If the attribute is a list, just take the first element. # If the attribute is a list, just take the first element.
# Private key is returned in a list for some reason. # Private key is returned in a list for some reason.
attr = attributes[ssh_opt] attr = attributes[ssh_opt]
if type(attr) is list: if isinstance(attr, MutableSequence):
attr = attr[0] attr = attr[0]
tmp_dict[ans_opt] = attr tmp_dict[ans_opt] = attr
if tmp_dict: if tmp_dict:

View file

@ -19,8 +19,7 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from ansible.plugins.callback.default import CallbackModule as CallbackModule_default from collections import MutableMapping, MutableSequence
from ansible.utils.color import colorize, hostcolor
HAS_OD = False HAS_OD = False
try: try:
@ -29,6 +28,10 @@ try:
except ImportError: except ImportError:
pass pass
from ansible.module_utils.six import binary_type, text_type
from ansible.plugins.callback.default import CallbackModule as CallbackModule_default
from ansible.utils.color import colorize, hostcolor
try: try:
from __main__ import display from __main__ import display
except ImportError: except ImportError:
@ -235,7 +238,7 @@ class CallbackModule_dense(CallbackModule_default):
# Remove empty attributes (list, dict, str) # Remove empty attributes (list, dict, str)
for attr in result.copy(): for attr in result.copy():
if type(result[attr]) in (list, dict, basestring, unicode): if isinstance(result[attr], (MutableSequence, MutableMapping, binary_type, text_type)):
if not result[attr]: if not result[attr]:
del(result[attr]) del(result[attr])

View file

@ -22,6 +22,7 @@ __metaclass__ = type
import os import os
import time import time
import json import json
from collections import MutableMapping
from ansible.module_utils._text import to_bytes from ansible.module_utils._text import to_bytes
from ansible.plugins.callback import CallbackBase from ansible.plugins.callback import CallbackBase
@ -54,7 +55,7 @@ class CallbackModule(CallbackBase):
os.makedirs("/var/log/ansible/hosts") os.makedirs("/var/log/ansible/hosts")
def log(self, host, category, data): def log(self, host, category, data):
if type(data) == dict: if isinstance(data, MutableMapping):
if '_ansible_verbose_override' in data: if '_ansible_verbose_override' in data:
# avoid logging extraneous data # avoid logging extraneous data
data = 'omitted' data = 'omitted'

View file

@ -31,6 +31,7 @@ import re
import string import string
import sys import sys
import uuid import uuid
from collections import MutableMapping, MutableSequence
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from random import Random, SystemRandom, shuffle from random import Random, SystemRandom, shuffle
@ -108,14 +109,13 @@ def to_nice_json(a, indent=4, *args, **kw):
def to_bool(a): def to_bool(a):
''' return a bool for the arg ''' ''' return a bool for the arg '''
if a is None or type(a) == bool: if a is None or isinstance(a, bool):
return a return a
if isinstance(a, string_types): if isinstance(a, string_types):
a = a.lower() a = a.lower()
if a in ['yes', 'on', '1', 'true', 1]: if a in ('yes', 'on', '1', 'true', 1):
return True return True
else: return False
return False
def to_datetime(string, format="%Y-%d-%m %H:%M:%S"): def to_datetime(string, format="%Y-%d-%m %H:%M:%S"):
return datetime.strptime(string, format) return datetime.strptime(string, format)
@ -402,10 +402,10 @@ def extract(item, container, morekeys=None):
def failed(*a, **kw): def failed(*a, **kw):
''' Test if task result yields failed ''' ''' Test if task result yields failed '''
item = a[0] item = a[0]
if type(item) != dict: if not isinstance(item, MutableMapping):
raise errors.AnsibleFilterError("|failed expects a dictionary") raise errors.AnsibleFilterError("|failed expects a dictionary")
rc = item.get('rc',0) rc = item.get('rc', 0)
failed = item.get('failed',False) failed = item.get('failed', False)
if rc != 0 or failed: if rc != 0 or failed:
return True return True
else: else:
@ -418,13 +418,13 @@ def success(*a, **kw):
def changed(*a, **kw): def changed(*a, **kw):
''' Test if task result yields changed ''' ''' Test if task result yields changed '''
item = a[0] item = a[0]
if type(item) != dict: if not isinstance(item, MutableMapping):
raise errors.AnsibleFilterError("|changed expects a dictionary") raise errors.AnsibleFilterError("|changed expects a dictionary")
if not 'changed' in item: if not 'changed' in item:
changed = False changed = False
if ('results' in item # some modules return a 'results' key if ('results' in item # some modules return a 'results' key
and type(item['results']) == list and isinstance(item['results'], MutableSequence)
and type(item['results'][0]) == dict): and isinstance(item['results'][0], MutableMapping)):
for result in item['results']: for result in item['results']:
changed = changed or result.get('changed', False) changed = changed or result.get('changed', False)
else: else:
@ -434,7 +434,7 @@ def changed(*a, **kw):
def skipped(*a, **kw): def skipped(*a, **kw):
''' Test if task result yields skipped ''' ''' Test if task result yields skipped '''
item = a[0] item = a[0]
if type(item) != dict: if not isinstance(item, MutableMapping):
raise errors.AnsibleFilterError("|skipped expects a dictionary") raise errors.AnsibleFilterError("|skipped expects a dictionary")
skipped = item.get('skipped', False) skipped = item.get('skipped', False)
return skipped return skipped

View file

@ -19,6 +19,7 @@ __metaclass__ = type
import codecs import codecs
import csv import csv
from collections import MutableSequence
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.plugins.lookup import LookupBase from ansible.plugins.lookup import LookupBase
@ -102,7 +103,7 @@ class LookupModule(LookupBase):
lookupfile = self.find_file_in_search_path(variables, 'files', paramvals['file']) lookupfile = self.find_file_in_search_path(variables, 'files', paramvals['file'])
var = self.read_csv(lookupfile, key, paramvals['delimiter'], paramvals['encoding'], paramvals['default'], paramvals['col']) var = self.read_csv(lookupfile, key, paramvals['delimiter'], paramvals['encoding'], paramvals['default'], paramvals['col'])
if var is not None: if var is not None:
if type(var) is list: if isinstance(var, MutableSequence):
for v in var: for v in var:
ret.append(v) ret.append(v)
else: else:

View file

@ -22,11 +22,12 @@ from ansible.plugins.lookup import LookupBase
import socket import socket
try: try:
import dns.exception
import dns.name
import dns.resolver import dns.resolver
import dns.reversename import dns.reversename
from dns.rdatatype import (A, AAAA, CNAME, DLV, DNAME, DNSKEY, DS, HINFO, LOC, from dns.rdatatype import (A, AAAA, CNAME, DLV, DNAME, DNSKEY, DS, HINFO, LOC,
MX, NAPTR, NS, NSEC3PARAM, PTR, RP, SOA, SPF, SRV, SSHFP, TLSA, TXT) MX, NAPTR, NS, NSEC3PARAM, PTR, RP, SOA, SPF, SRV, SSHFP, TLSA, TXT)
import dns.exception
HAVE_DNS = True HAVE_DNS = True
except ImportError: except ImportError:
HAVE_DNS = False HAVE_DNS = False
@ -70,7 +71,7 @@ def make_rdata_dict(rdata):
for f in fields: for f in fields:
val = rdata.__getattribute__(f) val = rdata.__getattribute__(f)
if type(val) == dns.name.Name: if isinstance(val, dns.name.Name):
val = dns.name.Name.to_text(val) val = dns.name.Name.to_text(val)
if rdata.rdtype == DLV and f == 'digest': if rdata.rdtype == DLV and f == 'digest':

View file

@ -17,9 +17,10 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from io import StringIO
import os import os
import re import re
from collections import MutableSequence
from io import StringIO
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.module_utils.six.moves import configparser from ansible.module_utils.six.moves import configparser
@ -110,7 +111,7 @@ class LookupModule(LookupBase):
else: else:
var = self.read_ini(path, key, paramvals['section'], paramvals['default'], paramvals['re']) var = self.read_ini(path, key, paramvals['section'], paramvals['default'], paramvals['re'])
if var is not None: if var is not None:
if type(var) is list: if isinstance(var, MutableSequence):
for v in var: for v in var:
ret.append(v) ret.append(v)
else: else:

View file

@ -53,6 +53,9 @@ except ImportError:
__all__ = ['StrategyBase'] __all__ = ['StrategyBase']
class StrategySentinel:
pass
# TODO: this should probably be in the plugins/__init__.py, with # TODO: this should probably be in the plugins/__init__.py, with
# a smarter mechanism to set all of the attributes based on # a smarter mechanism to set all of the attributes based on
# the loaders created there # the loaders created there
@ -70,12 +73,12 @@ class SharedPluginLoaderObj:
self.module_loader = module_loader self.module_loader = module_loader
_sentinel = object() _sentinel = StrategySentinel()
def results_thread_main(strategy): def results_thread_main(strategy):
while True: while True:
try: try:
result = strategy._final_q.get() result = strategy._final_q.get()
if type(result) == object: if isinstance(result, StrategySentinel):
break break
else: else:
strategy._results_lock.acquire() strategy._results_lock.acquire()

View file

@ -21,6 +21,7 @@ __metaclass__ = type
import re import re
import operator as py_operator import operator as py_operator
from collections import MutableMapping, MutableSequence
from distutils.version import LooseVersion, StrictVersion from distutils.version import LooseVersion, StrictVersion
from ansible import errors from ansible import errors
@ -28,7 +29,7 @@ from ansible import errors
def failed(*a, **kw): def failed(*a, **kw):
''' Test if task result yields failed ''' ''' Test if task result yields failed '''
item = a[0] item = a[0]
if type(item) != dict: if not isinstance(item, MutableMapping):
raise errors.AnsibleFilterError("|failed expects a dictionary") raise errors.AnsibleFilterError("|failed expects a dictionary")
rc = item.get('rc',0) rc = item.get('rc',0)
failed = item.get('failed',False) failed = item.get('failed',False)
@ -44,13 +45,13 @@ def success(*a, **kw):
def changed(*a, **kw): def changed(*a, **kw):
''' Test if task result yields changed ''' ''' Test if task result yields changed '''
item = a[0] item = a[0]
if type(item) != dict: if not isinstance(item, MutableMapping):
raise errors.AnsibleFilterError("|changed expects a dictionary") raise errors.AnsibleFilterError("|changed expects a dictionary")
if not 'changed' in item: if not 'changed' in item:
changed = False changed = False
if ('results' in item # some modules return a 'results' key if ('results' in item # some modules return a 'results' key
and type(item['results']) == list and isinstance(item['results'], MutableSequence)
and type(item['results'][0]) == dict): and isinstance(item['results'][0], MutableMapping)):
for result in item['results']: for result in item['results']:
changed = changed or result.get('changed', False) changed = changed or result.get('changed', False)
else: else:
@ -60,7 +61,7 @@ def changed(*a, **kw):
def skipped(*a, **kw): def skipped(*a, **kw):
''' Test if task result yields skipped ''' ''' Test if task result yields skipped '''
item = a[0] item = a[0]
if type(item) != dict: if not isinstance(item, MutableMapping):
raise errors.AnsibleFilterError("|skipped expects a dictionary") raise errors.AnsibleFilterError("|skipped expects a dictionary")
skipped = item.get('skipped', False) skipped = item.get('skipped', False)
return skipped return skipped

View file

@ -4,6 +4,7 @@ import yaml
import inspect import inspect
import collections import collections
from ansible.module_utils.six import string_types
from ansible.modules.cloud.openstack import os_server from ansible.modules.cloud.openstack import os_server
@ -26,7 +27,7 @@ def params_from_doc(func):
for task in cfg: for task in cfg:
for module, params in task.items(): for module, params in task.items():
for k, v in params.items(): for k, v in params.items():
if k in ['nics'] and type(v) == str: if k in ['nics'] and isinstance(v, string_types):
params[k] = [v] params[k] = [v]
task[module] = collections.defaultdict(str, task[module] = collections.defaultdict(str,
params) params)