From 7641d32f8ec0f18722ded8b2bea253a705d3917c Mon Sep 17 00:00:00 2001
From: Matt Clay <mclay@redhat.com>
Date: Mon, 15 Jun 2020 15:09:15 -0700
Subject: [PATCH] Fix ansible-test import sanity test issues. (#70084)

---
 .../ansible-test-sanity-import-fixes.yml      |  3 ++
 .../_data/sanity/import/importer.py           | 29 ++++++++++++++-----
 2 files changed, 24 insertions(+), 8 deletions(-)
 create mode 100644 changelogs/fragments/ansible-test-sanity-import-fixes.yml

diff --git a/changelogs/fragments/ansible-test-sanity-import-fixes.yml b/changelogs/fragments/ansible-test-sanity-import-fixes.yml
new file mode 100644
index 00000000000..9f7f010c90b
--- /dev/null
+++ b/changelogs/fragments/ansible-test-sanity-import-fixes.yml
@@ -0,0 +1,3 @@
+bugfixes:
+    - ansible-test - The ``import`` sanity test now correctly blocks access to python modules, not just packages, in the ``ansible`` package.
+    - ansible-test - The ``import`` sanity test now correctly provides an empty ``ansible`` package.
diff --git a/test/lib/ansible_test/_data/sanity/import/importer.py b/test/lib/ansible_test/_data/sanity/import/importer.py
index 21c19a064b7..ef8db71b518 100755
--- a/test/lib/ansible_test/_data/sanity/import/importer.py
+++ b/test/lib/ansible_test/_data/sanity/import/importer.py
@@ -19,6 +19,7 @@ def main():
     import subprocess
     import sys
     import traceback
+    import types
     import warnings
 
     ansible_path = os.path.dirname(os.path.dirname(ansible.__file__))
@@ -94,14 +95,23 @@ def main():
 
         collection_loader = _AnsibleCollectionFinder(paths=[collection_root])
         collection_loader._install()  # pylint: disable=protected-access
-
-        # remove all modules under the ansible package
-        list(map(sys.modules.pop, [m for m in sys.modules if m.partition('.')[0] == 'ansible']))
-
     else:
         # do not support collection loading when not testing a collection
         collection_loader = None
 
+    # remove all modules under the ansible package
+    list(map(sys.modules.pop, [m for m in sys.modules if m.partition('.')[0] == ansible.__name__]))
+
+    # pre-load an empty ansible package to prevent unwanted code in __init__.py from loading
+    # this more accurately reflects the environment that AnsiballZ runs modules under
+    # it also avoids issues with imports in the ansible package that are not allowed
+    ansible_module = types.ModuleType(ansible.__name__)
+    ansible_module.__file__ = ansible.__file__
+    ansible_module.__path__ = ansible.__path__
+    ansible_module.__package__ = ansible.__package__
+
+    sys.modules[ansible.__name__] = ansible_module
+
     class ImporterAnsibleModuleException(Exception):
         """Exception thrown during initialization of ImporterAnsibleModule."""
 
@@ -133,7 +143,7 @@ def main():
                 if is_name_in_namepace(fullname, ['ansible.module_utils', self.name]):
                     return None  # module_utils and module under test are always allowed
 
-                if os.path.exists(convert_ansible_name_to_absolute_path(fullname)):
+                if any(os.path.exists(candidate_path) for candidate_path in convert_ansible_name_to_absolute_paths(fullname)):
                     return self  # blacklist ansible files that exist
 
                 return None  # ansible file does not exist, do not blacklist
@@ -318,12 +328,15 @@ def main():
         for module in sorted(changed):
             report_message(path, 0, 0, 'reload', 'reloading of "%s" in sys.modules is not supported' % module, messages)
 
-    def convert_ansible_name_to_absolute_path(name):
+    def convert_ansible_name_to_absolute_paths(name):
         """Calculate the module path from the given name.
         :type name: str
-        :rtype: str
+        :rtype: list[str]
         """
-        return os.path.join(ansible_path, name.replace('.', os.path.sep))
+        return [
+            os.path.join(ansible_path, name.replace('.', os.path.sep)),
+            os.path.join(ansible_path, name.replace('.', os.path.sep)) + '.py',
+        ]
 
     def convert_relative_path_to_name(path):
         """Calculate the module name from the given path.