diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index 43a79dd3d35..be9fedfb5dc 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -165,11 +165,9 @@ from ansible.module_utils.common.file import ( get_flags_from_attributes, ) from ansible.module_utils.common.sys_info import ( - get_platform, get_distribution, get_distribution_version, - load_platform_subclass, - get_all_subclasses, + get_platform_subclass, ) from ansible.module_utils.pycompat24 import get_exception, literal_eval from ansible.module_utils.six import ( @@ -184,6 +182,7 @@ from ansible.module_utils.six import ( ) from ansible.module_utils.six.moves import map, reduce, shlex_quote from ansible.module_utils._text import to_native, to_bytes, to_text +from ansible.module_utils.common._utils import get_all_subclasses as _get_all_subclasses from ansible.module_utils.parsing.convert_bool import BOOLEANS, BOOLEANS_FALSE, BOOLEANS_TRUE, boolean @@ -276,6 +275,42 @@ if not _PY_MIN: sys.exit(1) +# +# Deprecated functions +# + +def get_platform(): + ''' + **Deprecated** Use :py:func:`platform.system` directly. + + :returns: Name of the platform the module is running on in a native string + + Returns a native string that labels the platform ("Linux", "Solaris", etc). Currently, this is + the result of calling :py:func:`platform.system`. + ''' + return platform.system() + +# End deprecated functions + + +# +# Compat shims +# + +def load_platform_subclass(cls, *args, **kwargs): + """**Deprecated**: Use ansible.module_utils.common.sys_info.get_platform_subclass instead""" + platform_cls = get_platform_subclass(cls) + return super(cls, platform_cls).__new__(platform_cls) + + +def get_all_subclasses(cls): + """**Deprecated**: Use ansible.module_utils.common._utils.get_all_subclasses instead""" + return list(_get_all_subclasses(cls)) + + +# End compat shims + + def json_dict_unicode_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'): ''' Recursively convert dict keys and values to byte str diff --git a/lib/ansible/module_utils/common/_utils.py b/lib/ansible/module_utils/common/_utils.py new file mode 100644 index 00000000000..66df3167771 --- /dev/null +++ b/lib/ansible/module_utils/common/_utils.py @@ -0,0 +1,40 @@ +# Copyright (c) 2018, Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +""" +Modules in _utils are waiting to find a better home. If you need to use them, be prepared for them +to move to a different location in the future. +""" + + +def get_all_subclasses(cls): + ''' + Recursively search and find all subclasses of a given class + + :arg cls: A python class + :rtype: set + :returns: The set of python classes which are the subclasses of `cls`. + + In python, you can use a class's :py:meth:`__subclasses__` method to determine what subclasses + of a class exist. However, `__subclasses__` only goes one level deep. This function searches + each child class's `__subclasses__` method to find all of the descendent classes. It then + returns an iterable of the descendent classes. + ''' + # Retrieve direct subclasses + subclasses = set(cls.__subclasses__()) + to_visit = list(subclasses) + # Then visit all subclasses + while to_visit: + for sc in to_visit: + # The current class is now visited, so remove it from list + to_visit.remove(sc) + # Appending all subclasses to visit and keep a reference of available class + for ssc in sc.__subclasses__(): + if ssc not in subclasses: + to_visit.append(ssc) + subclasses.add(ssc) + return subclasses diff --git a/lib/ansible/module_utils/common/sys_info.py b/lib/ansible/module_utils/common/sys_info.py index d6782959a8c..0e4dd3c8799 100644 --- a/lib/ansible/module_utils/common/sys_info.py +++ b/lib/ansible/module_utils/common/sys_info.py @@ -9,21 +9,22 @@ import os import platform from ansible.module_utils import distro +from ansible.module_utils.common._utils import get_all_subclasses -# Backwards compat. New code should just use platform.system() -def get_platform(): - ''' - :rtype: NativeString - :returns: Name of the platform the module is running on - ''' - return platform.system() +__all__ = ('get_distribution', 'get_distribution_version', 'get_platform_subclass') def get_distribution(): ''' + Return the name of the distribution the module is running on + :rtype: NativeString or None :returns: Name of the distribution the module is running on + + This function attempts to determine what Linux distribution the code is running on and return + a string representing that value. If the distribution cannot be determined, it returns + ``OtherLinux``. If not run on Linux it returns None. ''' distribution = None @@ -42,9 +43,11 @@ def get_distribution(): def get_distribution_version(): ''' + Get the version of the Linux distribution the code is running on + :rtype: NativeString or None - :returns: A string representation of the version of the distribution. None if this is not run - on a Linux machine + :returns: A string representation of the version of the distribution. If it cannot determine + the version, it returns empty string. If this is not run on a Linux machine it returns None ''' version = None if platform.system() == 'Linux': @@ -82,33 +85,36 @@ def get_distribution_codename(): return codename -def get_all_subclasses(cls): +def get_platform_subclass(cls): ''' - used by modules like Hardware or Network fact classes to recursively retrieve all - subclasses of a given class not only the direct sub classes. - ''' - # Retrieve direct subclasses - subclasses = cls.__subclasses__() - to_visit = list(subclasses) - # Then visit all subclasses - while to_visit: - for sc in to_visit: - # The current class is now visited, so remove it from list - to_visit.remove(sc) - # Appending all subclasses to visit and keep a reference of available class - for ssc in sc.__subclasses__(): - subclasses.append(ssc) - to_visit.append(ssc) - return subclasses + Finds a subclass implementing desired functionality on the platform the code is running on + :arg cls: Class to find an appropriate subclass for + :returns: A class that implements the functionality on this platform -def load_platform_subclass(cls, *args, **kwargs): - ''' - used by modules like User to have different implementations based on detected platform. See User - module for an example. + Some Ansible modules have different implementations depending on the platform they run on. This + function is used to select between the various implementations and choose one. You can look at + the implementation of the Ansible :ref:`User module` module for an example of how to use this. + + This function replaces ``basic.load_platform_subclass()``. When you port code, you need to + change the callers to be explicit about instantiating the class. For instance, code in the + Ansible User module changed from:: + + .. code-block:: python + + # Old + class User: + def __new__(cls, args, kwargs): + return load_platform_subclass(User, args, kwargs) + + # New + class User: + def __new__(cls, args, kwargs): + new_cls = get_platform_subclass(User) + return super(cls, new_cls).__new__(new_cls, args, kwargs) ''' - this_platform = get_platform() + this_platform = platform.system() distribution = get_distribution() subclass = None @@ -124,4 +130,4 @@ def load_platform_subclass(cls, *args, **kwargs): if subclass is None: subclass = cls - return super(cls, subclass).__new__(subclass) + return subclass diff --git a/test/units/executor/module_common/test_recursive_finder.py b/test/units/executor/module_common/test_recursive_finder.py index f50c821cba3..2dc74fd40eb 100644 --- a/test/units/executor/module_common/test_recursive_finder.py +++ b/test/units/executor/module_common/test_recursive_finder.py @@ -47,6 +47,7 @@ MODULE_UTILS_BASIC_IMPORTS = frozenset((('_text',), ('common', 'file'), ('common', 'process'), ('common', 'sys_info'), + ('common', '_utils'), ('distro', '__init__'), ('distro', '_distro'), ('parsing', '__init__'), @@ -55,19 +56,20 @@ MODULE_UTILS_BASIC_IMPORTS = frozenset((('_text',), ('six', '__init__'), )) -MODULE_UTILS_BASIC_FILES = frozenset(('ansible/module_utils/parsing/__init__.py', - 'ansible/module_utils/common/process.py', +MODULE_UTILS_BASIC_FILES = frozenset(('ansible/module_utils/_text.py', 'ansible/module_utils/basic.py', - 'ansible/module_utils/six/__init__.py', - 'ansible/module_utils/_text.py', - 'ansible/module_utils/common/_collections_compat.py', - 'ansible/module_utils/parsing/convert_bool.py', 'ansible/module_utils/common/__init__.py', + 'ansible/module_utils/common/_collections_compat.py', 'ansible/module_utils/common/file.py', + 'ansible/module_utils/common/process.py', 'ansible/module_utils/common/sys_info.py', + 'ansible/module_utils/common/_utils.py', 'ansible/module_utils/distro/__init__.py', 'ansible/module_utils/distro/_distro.py', + 'ansible/module_utils/parsing/__init__.py', + 'ansible/module_utils/parsing/convert_bool.py', 'ansible/module_utils/pycompat24.py', + 'ansible/module_utils/six/__init__.py', )) ONLY_BASIC_IMPORT = frozenset((('basic',),)) diff --git a/test/units/module_utils/basic/test_platform_distribution.py b/test/units/module_utils/basic/test_platform_distribution.py index e7067b1a762..e62f33f4562 100644 --- a/test/units/module_utils/basic/test_platform_distribution.py +++ b/test/units/module_utils/basic/test_platform_distribution.py @@ -16,11 +16,11 @@ from units.compat.mock import patch from ansible.module_utils.six.moves import builtins # Functions being tested -from ansible.module_utils.common.sys_info import get_all_subclasses -from ansible.module_utils.common.sys_info import get_distribution -from ansible.module_utils.common.sys_info import get_distribution_version -from ansible.module_utils.common.sys_info import get_platform -from ansible.module_utils.common.sys_info import load_platform_subclass +from ansible.module_utils.basic import get_platform +from ansible.module_utils.basic import get_all_subclasses +from ansible.module_utils.basic import get_distribution +from ansible.module_utils.basic import get_distribution_version +from ansible.module_utils.basic import load_platform_subclass realimport = builtins.__import__ @@ -104,7 +104,7 @@ class TestLoadPlatformSubclass: def test_not_linux(self): # if neither match, the fallback should be the top-level class - with patch('ansible.module_utils.common.sys_info.get_platform', return_value="Foo"): + with patch('platform.system', return_value="Foo"): with patch('ansible.module_utils.common.sys_info.get_distribution', return_value=None): assert isinstance(load_platform_subclass(self.LinuxTest), self.LinuxTest) diff --git a/test/units/module_utils/common/test_sys_info.py b/test/units/module_utils/common/test_sys_info.py new file mode 100644 index 00000000000..ba2d7e92dcb --- /dev/null +++ b/test/units/module_utils/common/test_sys_info.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +# (c) 2012-2014, Michael DeHaan +# (c) 2016 Toshio Kuratomi +# (c) 2017-2018 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import pytest + +from units.compat.mock import patch + +from ansible.module_utils.six.moves import builtins + +# Functions being tested +from ansible.module_utils.common.sys_info import get_distribution +from ansible.module_utils.common.sys_info import get_distribution_version +from ansible.module_utils.common.sys_info import get_platform_subclass + + +realimport = builtins.__import__ + + +@pytest.fixture +def platform_linux(mocker): + mocker.patch('platform.system', return_value='Linux') + + +# +# get_distribution tests +# + +def test_get_distribution_not_linux(): + """If it's not Linux, then it has no distribution""" + with patch('platform.system', return_value='Foo'): + assert get_distribution() is None + + +@pytest.mark.usefixtures("platform_linux") +class TestGetDistribution: + """ Tests for get_distribution that have to find somethine""" + def test_distro_known(self): + with patch('ansible.module_utils.distro.name', return_value="foo"): + assert get_distribution() == "Foo" + + def test_distro_unknown(self): + with patch('ansible.module_utils.distro.name', return_value=""): + assert get_distribution() == "OtherLinux" + + def test_distro_amazon_part_of_another_name(self): + with patch('ansible.module_utils.distro.name', return_value="AmazonFooBar"): + assert get_distribution() == "Amazonfoobar" + + def test_distro_amazon_linux(self): + with patch('ansible.module_utils.distro.name', return_value="Amazon Linux AMI"): + assert get_distribution() == "Amazon" + + +# +# get_distribution_version tests +# + +def test_get_distribution_version_not_linux(): + """If it's not Linux, then it has no distribution""" + with patch('platform.system', return_value='Foo'): + assert get_distribution_version() is None + + +@pytest.mark.usefixtures("platform_linux") +def test_distro_found(): + with patch('ansible.module_utils.distro.version', return_value="1"): + assert get_distribution_version() == "1" + + +# +# Tests for get_platform_subclass +# + +class TestGetPlatformSubclass: + class LinuxTest: + pass + + class Foo(LinuxTest): + platform = "Linux" + distribution = None + + class Bar(LinuxTest): + platform = "Linux" + distribution = "Bar" + + def test_not_linux(self): + # if neither match, the fallback should be the top-level class + with patch('platform.system', return_value="Foo"): + with patch('ansible.module_utils.common.sys_info.get_distribution', return_value=None): + assert get_platform_subclass(self.LinuxTest) is self.LinuxTest + + @pytest.mark.usefixtures("platform_linux") + def test_get_distribution_none(self): + # match just the platform class, not a specific distribution + with patch('ansible.module_utils.common.sys_info.get_distribution', return_value=None): + assert get_platform_subclass(self.LinuxTest) is self.Foo + + @pytest.mark.usefixtures("platform_linux") + def test_get_distribution_found(self): + # match both the distribution and platform class + with patch('ansible.module_utils.common.sys_info.get_distribution', return_value="Bar"): + assert get_platform_subclass(self.LinuxTest) is self.Bar diff --git a/test/units/module_utils/common/test_utils.py b/test/units/module_utils/common/test_utils.py new file mode 100644 index 00000000000..ef952393a98 --- /dev/null +++ b/test/units/module_utils/common/test_utils.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# (c) 2018 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from ansible.module_utils.common.sys_info import get_all_subclasses + + +# +# Tests for get_all_subclasses +# + +class TestGetAllSubclasses: + class Base: + pass + + class BranchI(Base): + pass + + class BranchII(Base): + pass + + class BranchIA(BranchI): + pass + + class BranchIB(BranchI): + pass + + class BranchIIA(BranchII): + pass + + class BranchIIB(BranchII): + pass + + def test_bottom_level(self): + assert get_all_subclasses(self.BranchIIB) == set() + + def test_one_inheritance(self): + assert set(get_all_subclasses(self.BranchII)) == set([self.BranchIIA, self.BranchIIB]) + + def test_toplevel(self): + assert set(get_all_subclasses(self.Base)) == set([self.BranchI, self.BranchII, + self.BranchIA, self.BranchIB, + self.BranchIIA, self.BranchIIB])