move some of basic into common (#48078)

* move file functions into common.file

* move *_PERM_BITS and mark as private (_*_PERM_BITS)

* move get_{platform, distribution, distribution_version} get_all_subclasses and load_platform_subclass into common.sys_info

* forgot get_distribution_version, properly rename get_all_subclasses

* add common/sys_info.py to recursive finder test

* update module paths in test_platform_distribution.py

* update docstrings, _get_all_subclasses -> get_all_subclasses

* forgot to update names

* remove trailing whitespace
This commit is contained in:
Andreas Calminder 2018-12-07 19:21:11 +01:00 committed by Toshio Kuratomi
parent 3a4d476512
commit 876b637208
5 changed files with 184 additions and 119 deletions

View file

@ -156,7 +156,21 @@ from ansible.module_utils.common._collections_compat import (
Set, MutableSet, Set, MutableSet,
) )
from ansible.module_utils.common.process import get_bin_path from ansible.module_utils.common.process import get_bin_path
from ansible.module_utils.common.file import is_executable from ansible.module_utils.common.file import (
_PERM_BITS as PERM_BITS,
_EXEC_PERM_BITS as EXEC_PERM_BITS,
_DEFAULT_PERM as DEFAULT_PERM,
is_executable,
format_attributes,
get_flags_from_attributes,
)
from ansible.module_utils.common.sys_info import (
get_platform,
get_distribution,
get_distribution_version,
load_platform_subclass,
get_all_subclasses,
)
from ansible.module_utils.pycompat24 import get_exception, literal_eval from ansible.module_utils.pycompat24 import get_exception, literal_eval
from ansible.module_utils.six import ( from ansible.module_utils.six import (
PY2, PY2,
@ -249,11 +263,6 @@ MODE_OPERATOR_RE = re.compile(r'[+=-]')
USERS_RE = re.compile(r'[^ugo]') USERS_RE = re.compile(r'[^ugo]')
PERMS_RE = re.compile(r'[^rwxXstugo]') PERMS_RE = re.compile(r'[^rwxXstugo]')
PERM_BITS = 0o7777 # file mode permission bits
EXEC_PERM_BITS = 0o0111 # execute permission bits
DEFAULT_PERM = 0o0666 # default file permission bits
# Used for determining if the system is running a new enough python version # Used for determining if the system is running a new enough python version
# and should only restrict on our documented minimum versions # and should only restrict on our documented minimum versions
_PY3_MIN = sys.version_info[:2] >= (3, 5) _PY3_MIN = sys.version_info[:2] >= (3, 5)
@ -267,91 +276,6 @@ if not _PY_MIN:
sys.exit(1) sys.exit(1)
def get_platform():
''' what's the platform? example: Linux is a platform. '''
return platform.system()
def get_distribution():
''' return the distribution name '''
if platform.system() == 'Linux':
try:
supported_dists = platform._supported_dists + ('arch', 'alpine', 'devuan')
distribution = platform.linux_distribution(supported_dists=supported_dists)[0].capitalize()
if not distribution and os.path.isfile('/etc/system-release'):
distribution = platform.linux_distribution(supported_dists=['system'])[0].capitalize()
if 'Amazon' in distribution:
distribution = 'Amazon'
else:
distribution = 'OtherLinux'
except:
# FIXME: MethodMissing, I assume?
distribution = platform.dist()[0].capitalize()
else:
distribution = None
return distribution
def get_distribution_version():
''' return the distribution version '''
if platform.system() == 'Linux':
try:
distribution_version = platform.linux_distribution()[1]
if not distribution_version and os.path.isfile('/etc/system-release'):
distribution_version = platform.linux_distribution(supported_dists=['system'])[1]
except:
# FIXME: MethodMissing, I assume?
distribution_version = platform.dist()[1]
else:
distribution_version = None
return distribution_version
def get_all_subclasses(cls):
'''
used by modules like Hardware or Network fact classes to retrieve all subclasses of a given class.
__subclasses__ return only direct sub classes. This one go down into the class tree.
'''
# Retrieve direct subclasses
subclasses = cls.__subclasses__()
to_visit = list(subclasses)
# Then visit all subclasses
while to_visit:
for sc in to_visit:
# The current class is now visited, so remove it from list
to_visit.remove(sc)
# Appending all subclasses to visit and keep a reference of available class
for ssc in sc.__subclasses__():
subclasses.append(ssc)
to_visit.append(ssc)
return subclasses
def load_platform_subclass(cls, *args, **kwargs):
'''
used by modules like User to have different implementations based on detected platform. See User
module for an example.
'''
this_platform = get_platform()
distribution = get_distribution()
subclass = None
# get the most specific superclass for this platform
if distribution is not None:
for sc in get_all_subclasses(cls):
if sc.distribution is not None and sc.distribution == distribution and sc.platform == this_platform:
subclass = sc
if subclass is None:
for sc in get_all_subclasses(cls):
if sc.platform == this_platform and sc.distribution is None:
subclass = sc
if subclass is None:
subclass = cls
return super(cls, subclass).__new__(subclass)
def json_dict_unicode_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'): def json_dict_unicode_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'):
''' Recursively convert dict keys and values to byte str ''' Recursively convert dict keys and values to byte str
@ -745,22 +669,6 @@ def _lenient_lowercase(lst):
return lowered return lowered
def format_attributes(attributes):
attribute_list = []
for attr in attributes:
if attr in FILE_ATTRIBUTES:
attribute_list.append(FILE_ATTRIBUTES[attr])
return attribute_list
def get_flags_from_attributes(attributes):
flags = []
for key, attr in FILE_ATTRIBUTES.items():
if attr in attributes:
flags.append(key)
return ''.join(flags)
def _json_encode_fallback(obj): def _json_encode_fallback(obj):
if isinstance(obj, Set): if isinstance(obj, Set):
return list(obj) return list(obj)

View file

@ -27,8 +27,39 @@ except ImportError:
HAVE_SELINUX = False HAVE_SELINUX = False
class LockTimeout(Exception): FILE_ATTRIBUTES = {
pass 'A': 'noatime',
'a': 'append',
'c': 'compressed',
'C': 'nocow',
'd': 'nodump',
'D': 'dirsync',
'e': 'extents',
'E': 'encrypted',
'h': 'blocksize',
'i': 'immutable',
'I': 'indexed',
'j': 'journalled',
'N': 'inline',
's': 'zero',
'S': 'synchronous',
't': 'notail',
'T': 'blockroot',
'u': 'undelete',
'X': 'compressedraw',
'Z': 'compresseddirty',
}
# Used for parsing symbolic file perms
MODE_OPERATOR_RE = re.compile(r'[+=-]')
USERS_RE = re.compile(r'[^ugo]')
PERMS_RE = re.compile(r'[^rwxXstugo]')
_PERM_BITS = 0o7777 # file mode permission bits
_EXEC_PERM_BITS = 0o0111 # execute permission bits
_DEFAULT_PERM = 0o0666 # default file permission bits
def is_executable(path): def is_executable(path):
@ -45,6 +76,34 @@ def is_executable(path):
return ((stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) & os.stat(path)[stat.ST_MODE]) return ((stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) & os.stat(path)[stat.ST_MODE])
def format_attributes(attributes):
attribute_list = [FILE_ATTRIBUTES.get(attr) for attr in attributes if attr in FILE_ATTRIBUTES]
return attribute_list
def get_flags_from_attributes(attributes):
flags = [key for key, attr in FILE_ATTRIBUTES.items() if attr in attributes]
return ''.join(flags)
def get_file_arg_spec():
arg_spec = dict(
mode=dict(type='raw'),
owner=dict(),
group=dict(),
seuser=dict(),
serole=dict(),
selevel=dict(),
setype=dict(),
attributes=dict(aliases=['attr']),
)
return arg_spec
class LockTimeout(Exception):
pass
class FileLock: class FileLock:
''' '''
Currently FileLock is implemented via fcntl.flock on a lock file, however this Currently FileLock is implemented via fcntl.flock on a lock file, however this

View file

@ -0,0 +1,96 @@
import os
import platform
def get_platform():
'''
:rtype: NativeString
:returns: Name of the platform the module is running on
'''
return platform.system()
def get_distribution():
'''
:rtype: NativeString or None
:returns: Name of the distribution the module is running on
'''
distribution = None
additional_linux = ('alpine', 'arch', 'devuan')
supported_dists = platform._supported_dists + additional_linux
if platform.system() == 'Linux':
try:
distribution = platform.linux_distribution(supported_dists=supported_dists)[0].capitalize()
if not distribution and os.path.isfile('/etc/system-release'):
distribution = platform.linux_distribution(supported_dists=['system'])[0].capitalize()
if 'Amazon' in distribution:
distribution = 'Amazon'
else:
distribution = 'OtherLinux'
except:
# FIXME: MethodMissing, I assume?
distribution = platform.dist()[0].capitalize()
return distribution
def get_distribution_version():
'''
:rtype: NativeString or None
:returns: A string representation of the version of the distribution
'''
distribution_version = None
if platform.system() == 'Linux':
try:
distribution_version = platform.linux_distribution()[1]
if not distribution_version and os.path.isfile('/etc/system-release'):
distribution_version = platform.linux_distribution(supported_dists=['system'])[1]
except Exception:
# FIXME: MethodMissing, I assume?
distribution_version = platform.dist()[1]
return distribution_version
def get_all_subclasses(cls):
'''
used by modules like Hardware or Network fact classes to recursively retrieve all
subclasses of a given class not only the direct sub classes.
'''
# Retrieve direct subclasses
subclasses = cls.__subclasses__()
to_visit = list(subclasses)
# Then visit all subclasses
while to_visit:
for sc in to_visit:
# The current class is now visited, so remove it from list
to_visit.remove(sc)
# Appending all subclasses to visit and keep a reference of available class
for ssc in sc.__subclasses__():
subclasses.append(ssc)
to_visit.append(ssc)
return subclasses
def load_platform_subclass(cls, *args, **kwargs):
'''
used by modules like User to have different implementations based on detected platform. See User
module for an example.
'''
this_platform = get_platform()
distribution = get_distribution()
subclass = None
# get the most specific superclass for this platform
if distribution is not None:
for sc in get_all_subclasses(cls):
if sc.distribution is not None and sc.distribution == distribution and sc.platform == this_platform:
subclass = sc
if subclass is None:
for sc in get_all_subclasses(cls):
if sc.platform == this_platform and sc.distribution is None:
subclass = sc
if subclass is None:
subclass = cls
return super(cls, subclass).__new__(subclass)

View file

@ -46,6 +46,7 @@ MODULE_UTILS_BASIC_IMPORTS = frozenset((('_text',),
('common', '_collections_compat'), ('common', '_collections_compat'),
('common', 'file'), ('common', 'file'),
('common', 'process'), ('common', 'process'),
('common', 'sys_info'),
('parsing', '__init__'), ('parsing', '__init__'),
('parsing', 'convert_bool'), ('parsing', 'convert_bool'),
('pycompat24',), ('pycompat24',),
@ -61,6 +62,7 @@ MODULE_UTILS_BASIC_FILES = frozenset(('ansible/module_utils/parsing/__init__.py'
'ansible/module_utils/parsing/convert_bool.py', 'ansible/module_utils/parsing/convert_bool.py',
'ansible/module_utils/common/__init__.py', 'ansible/module_utils/common/__init__.py',
'ansible/module_utils/common/file.py', 'ansible/module_utils/common/file.py',
'ansible/module_utils/common/sys_info.py',
'ansible/module_utils/pycompat24.py', 'ansible/module_utils/pycompat24.py',
)) ))

View file

@ -18,11 +18,11 @@ realimport = builtins.__import__
class TestPlatform(ModuleTestCase): class TestPlatform(ModuleTestCase):
def test_module_utils_basic_get_platform(self): def test_module_utils_basic_get_platform(self):
with patch('platform.system', return_value='foo'): with patch('platform.system', return_value='foo'):
from ansible.module_utils.basic import get_platform from ansible.module_utils.common.sys_info import get_platform
self.assertEqual(get_platform(), 'foo') self.assertEqual(get_platform(), 'foo')
def test_module_utils_basic_get_distribution(self): def test_module_utils_basic_get_distribution(self):
from ansible.module_utils.basic import get_distribution from ansible.module_utils.common.sys_info import get_distribution
with patch('platform.system', return_value='Foo'): with patch('platform.system', return_value='Foo'):
self.assertEqual(get_distribution(), None) self.assertEqual(get_distribution(), None)
@ -55,7 +55,7 @@ class TestPlatform(ModuleTestCase):
self.assertEqual(get_distribution(), "Bar") self.assertEqual(get_distribution(), "Bar")
def test_module_utils_basic_get_distribution_version(self): def test_module_utils_basic_get_distribution_version(self):
from ansible.module_utils.basic import get_distribution_version from ansible.module_utils.common.sys_info import get_distribution_version
with patch('platform.system', return_value='Foo'): with patch('platform.system', return_value='Foo'):
self.assertEqual(get_distribution_version(), None) self.assertEqual(get_distribution_version(), None)
@ -90,19 +90,19 @@ class TestPlatform(ModuleTestCase):
platform = "Linux" platform = "Linux"
distribution = "Bar" distribution = "Bar"
from ansible.module_utils.basic import load_platform_subclass from ansible.module_utils.common.sys_info import load_platform_subclass
# match just the platform class, not a specific distribution # match just the platform class, not a specific distribution
with patch('ansible.module_utils.basic.get_platform', return_value="Linux"): with patch('ansible.module_utils.common.sys_info.get_platform', return_value="Linux"):
with patch('ansible.module_utils.basic.get_distribution', return_value=None): with patch('ansible.module_utils.common.sys_info.get_distribution', return_value=None):
self.assertIs(type(load_platform_subclass(LinuxTest)), Foo) self.assertIs(type(load_platform_subclass(LinuxTest)), Foo)
# match both the distribution and platform class # match both the distribution and platform class
with patch('ansible.module_utils.basic.get_platform', return_value="Linux"): with patch('ansible.module_utils.common.sys_info.get_platform', return_value="Linux"):
with patch('ansible.module_utils.basic.get_distribution', return_value="Bar"): with patch('ansible.module_utils.common.sys_info.get_distribution', return_value="Bar"):
self.assertIs(type(load_platform_subclass(LinuxTest)), Bar) self.assertIs(type(load_platform_subclass(LinuxTest)), Bar)
# if neither match, the fallback should be the top-level class # if neither match, the fallback should be the top-level class
with patch('ansible.module_utils.basic.get_platform', return_value="Foo"): with patch('ansible.module_utils.common.sys_info.get_platform', return_value="Foo"):
with patch('ansible.module_utils.basic.get_distribution', return_value=None): with patch('ansible.module_utils.common.sys_info.get_distribution', return_value=None):
self.assertIs(type(load_platform_subclass(LinuxTest)), LinuxTest) self.assertIs(type(load_platform_subclass(LinuxTest)), LinuxTest)