diff --git a/lib/ansible/module_utils/facts/collector.py b/lib/ansible/module_utils/facts/collector.py index f55d1db8f2c..47f1bec8e04 100644 --- a/lib/ansible/module_utils/facts/collector.py +++ b/lib/ansible/module_utils/facts/collector.py @@ -36,11 +36,31 @@ import platform from ansible.module_utils.facts import timeout +class CycleFoundInFactDeps(Exception): + '''Indicates there is a cycle in fact collector deps + + If collector-B requires collector-A, and collector-A requires + collector-B, that is a cycle. In that case, there is no ordering + that will satisfy B before A and A and before B. That will cause this + error to be raised. + ''' + pass + + +class UnresolvedFactDep(ValueError): + pass + + +class CollectorNotFoundError(KeyError): + pass + + class BaseFactCollector: _fact_ids = set() _platform = 'Generic' name = None + required_facts = set() def __init__(self, collectors=None, namespace=None): '''Base class for things that collect facts. @@ -216,28 +236,112 @@ def build_fact_id_to_collector_map(collectors_for_platform): return fact_id_to_collector_map, aliases_map -def select_collector_classes(collector_names, all_fact_subsets, all_collector_classes): - # TODO: can be a set() - seen_collector_classes = [] +def select_collector_classes(collector_names, all_fact_subsets): + seen_collector_classes = set() selected_collector_classes = [] - for candidate_collector_class in all_collector_classes: - candidate_collector_name = candidate_collector_class.name - - if candidate_collector_name not in collector_names: - continue - - collector_classes = all_fact_subsets.get(candidate_collector_name, []) - + for collector_name in collector_names: + collector_classes = all_fact_subsets.get(collector_name, []) for collector_class in collector_classes: if collector_class not in seen_collector_classes: selected_collector_classes.append(collector_class) - seen_collector_classes.append(collector_class) + seen_collector_classes.add(collector_class) return selected_collector_classes +def _get_requires_by_collector_name(collector_name, all_fact_subsets): + required_facts = set() + + try: + collector_classes = all_fact_subsets[collector_name] + except KeyError: + raise CollectorNotFoundError('Fact collector "%s" not found' % collector_name) + for collector_class in collector_classes: + required_facts.update(collector_class.required_facts) + return required_facts + + +def find_unresolved_requires(collector_names, all_fact_subsets): + '''Find any collector names that have unresolved requires + + Returns a list of collector names that correspond to collector + classes whose .requires_facts() are not in collector_names. + ''' + unresolved = set() + + for collector_name in collector_names: + required_facts = _get_requires_by_collector_name(collector_name, all_fact_subsets) + for required_fact in required_facts: + if required_fact not in collector_names: + unresolved.add(required_fact) + + return unresolved + + +def resolve_requires(unresolved_requires, all_fact_subsets): + new_names = set() + failed = [] + for unresolved in unresolved_requires: + if unresolved in all_fact_subsets: + new_names.add(unresolved) + else: + failed.append(unresolved) + + if failed: + raise UnresolvedFactDep('unresolved fact dep %s' % ','.join(failed)) + return new_names + + +def build_dep_data(collector_names, all_fact_subsets): + dep_map = defaultdict(set) + for collector_name in collector_names: + collector_deps = set() + for collector in all_fact_subsets[collector_name]: + for dep in collector.required_facts: + collector_deps.add(dep) + dep_map[collector_name] = collector_deps + return dep_map + + +def tsort(dep_map): + sorted_list = [] + + unsorted_map = dep_map.copy() + + while unsorted_map: + acyclic = False + for node, edges in list(unsorted_map.items()): + for edge in edges: + if edge in unsorted_map: + break + else: + acyclic = True + del unsorted_map[node] + sorted_list.append((node, edges)) + + if not acyclic: + raise CycleFoundInFactDeps('Unable to tsort deps, there was a cycle in the graph. sorted=%s' % sorted_list) + + return sorted_list + + +def _solve_deps(collector_names, all_fact_subsets): + unresolved = collector_names.copy() + solutions = collector_names.copy() + + while True: + unresolved = find_unresolved_requires(solutions, all_fact_subsets) + if unresolved == set(): + break + + new_names = resolve_requires(unresolved, all_fact_subsets) + solutions.update(new_names) + + return solutions + + def collector_classes_from_gather_subset(all_collector_classes=None, valid_subsets=None, minimal_gather_subset=None, @@ -283,8 +387,14 @@ def collector_classes_from_gather_subset(all_collector_classes=None, aliases_map=aliases_map, platform_info=platform_info) - selected_collector_classes = select_collector_classes(collector_names, - all_fact_subsets, - all_collector_classes) + complete_collector_names = _solve_deps(collector_names, all_fact_subsets) + + dep_map = build_dep_data(complete_collector_names, all_fact_subsets) + + ordered_deps = tsort(dep_map) + ordered_collector_names = [x[0] for x in ordered_deps] + + selected_collector_classes = select_collector_classes(ordered_collector_names, + all_fact_subsets) return selected_collector_classes diff --git a/lib/ansible/module_utils/facts/hardware/hpux.py b/lib/ansible/module_utils/facts/hardware/hpux.py index e81adcbf244..ae72ed8e486 100644 --- a/lib/ansible/module_utils/facts/hardware/hpux.py +++ b/lib/ansible/module_utils/facts/hardware/hpux.py @@ -161,3 +161,5 @@ class HPUXHardware(Hardware): class HPUXHardwareCollector(HardwareCollector): _fact_class = HPUXHardware _platform = 'HP-UX' + + required_facts = set(['platform', 'distribution']) diff --git a/lib/ansible/module_utils/facts/hardware/linux.py b/lib/ansible/module_utils/facts/hardware/linux.py index 71b8d1eefe2..e781419d98f 100644 --- a/lib/ansible/module_utils/facts/hardware/linux.py +++ b/lib/ansible/module_utils/facts/hardware/linux.py @@ -710,3 +710,5 @@ class LinuxHardware(Hardware): class LinuxHardwareCollector(HardwareCollector): _platform = 'Linux' _fact_class = LinuxHardware + + required_facts = set(['platform']) diff --git a/lib/ansible/module_utils/facts/hardware/sunos.py b/lib/ansible/module_utils/facts/hardware/sunos.py index 5e62dc2fc35..888fb8dfacd 100644 --- a/lib/ansible/module_utils/facts/hardware/sunos.py +++ b/lib/ansible/module_utils/facts/hardware/sunos.py @@ -263,3 +263,5 @@ class SunOSHardware(Hardware): class SunOSHardwareCollector(HardwareCollector): _fact_class = SunOSHardware _platform = 'SunOS' + + required_facts = set(['platform']) diff --git a/lib/ansible/module_utils/facts/network/linux.py b/lib/ansible/module_utils/facts/network/linux.py index eddb1df5bd3..54e2745a2ee 100644 --- a/lib/ansible/module_utils/facts/network/linux.py +++ b/lib/ansible/module_utils/facts/network/linux.py @@ -310,3 +310,4 @@ class LinuxNetwork(Network): class LinuxNetworkCollector(NetworkCollector): _platform = 'Linux' _fact_class = LinuxNetwork + required_facts = set(['distribution', 'platform']) diff --git a/lib/ansible/module_utils/facts/system/distribution.py b/lib/ansible/module_utils/facts/system/distribution.py index 90dc6356e6c..565fb17c7fa 100644 --- a/lib/ansible/module_utils/facts/system/distribution.py +++ b/lib/ansible/module_utils/facts/system/distribution.py @@ -579,7 +579,8 @@ class DistributionFactCollector(BaseFactCollector): name = 'distribution' _fact_ids = set(['distribution_version', 'distribution_release', - 'distribution_major_version']) + 'distribution_major_version', + 'os_family']) def collect(self, module=None, collected_facts=None): collected_facts = collected_facts or {} diff --git a/lib/ansible/module_utils/facts/system/platform.py b/lib/ansible/module_utils/facts/system/platform.py index 74e8e67c71d..5d503d0d29f 100644 --- a/lib/ansible/module_utils/facts/system/platform.py +++ b/lib/ansible/module_utils/facts/system/platform.py @@ -35,6 +35,7 @@ class PlatformFactCollector(BaseFactCollector): 'kernel', 'machine', 'python_version', + 'architecture', 'machine_id']) def collect(self, module=None, collected_facts=None): diff --git a/lib/ansible/module_utils/facts/system/service_mgr.py b/lib/ansible/module_utils/facts/system/service_mgr.py index d983604cdb2..cb044fbf72c 100644 --- a/lib/ansible/module_utils/facts/system/service_mgr.py +++ b/lib/ansible/module_utils/facts/system/service_mgr.py @@ -38,6 +38,7 @@ if platform.system() != 'SunOS': class ServiceMgrFactCollector(BaseFactCollector): name = 'service_mgr' _fact_ids = set() + required_facts = set(['platform', 'distribution']) @staticmethod def is_systemd_managed(module): diff --git a/test/integration/targets/gathering_facts/cache_plugins/none.py b/test/integration/targets/gathering_facts/cache_plugins/none.py new file mode 100644 index 00000000000..5681dee0e49 --- /dev/null +++ b/test/integration/targets/gathering_facts/cache_plugins/none.py @@ -0,0 +1,50 @@ +# (c) 2014, Brian Coca, Josh Drake, et al +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.plugins.cache import BaseCacheModule + +DOCUMENTATION = ''' + cache: none + short_description: write-only cache (no cache) + description: + - No caching at all + version_added: historical + author: core team (@ansible-core) +''' + + +class CacheModule(BaseCacheModule): + def __init__(self, *args, **kwargs): + self.empty = {} + + def get(self, key): + return self.empty.get(key) + + def set(self, key, value): + return value + + def keys(self): + return self.empty.keys() + + def contains(self, key): + return key in self.empty + + def delete(self, key): + del self.emtpy[key] + + def flush(self): + self.empty = {} + + def copy(self): + return self.empty.copy() + + def __getstate__(self): + return self.copy() + + def __setstate__(self, data): + self.empty = data diff --git a/test/integration/targets/gathering_facts/runme.sh b/test/integration/targets/gathering_facts/runme.sh index 925910b2268..e4c7b3844a1 100755 --- a/test/integration/targets/gathering_facts/runme.sh +++ b/test/integration/targets/gathering_facts/runme.sh @@ -2,4 +2,6 @@ set -eux +# ANSIBLE_CACHE_PLUGINS=cache_plugins/ ANSIBLE_CACHE_PLUGIN=none ansible-playbook test_gathering_facts.yml -i ../../inventory -v "$@" ansible-playbook test_gathering_facts.yml -i ../../inventory -v "$@" +#ANSIBLE_CACHE_PLUGIN=base ansible-playbook test_gathering_facts.yml -i ../../inventory -v "$@" diff --git a/test/integration/targets/gathering_facts/test_gathering_facts.yml b/test/integration/targets/gathering_facts/test_gathering_facts.yml index 6769bf9ce70..9dd0960dba6 100644 --- a/test/integration/targets/gathering_facts/test_gathering_facts.yml +++ b/test/integration/targets/gathering_facts/test_gathering_facts.yml @@ -112,7 +112,7 @@ tasks: - setup: filter: "*env*" - register: fact_results + # register: fact_results - name: Test that retrieving all facts filtered to env works assert: @@ -129,7 +129,7 @@ tasks: - setup: filter: "ansible_user_id" - register: fact_results + # register: fact_results - name: Test that retrieving all facts filtered to specific fact ansible_user_id works assert: @@ -148,7 +148,7 @@ tasks: - setup: filter: "*" - register: fact_results + # register: fact_results - name: Test that retrieving all facts filtered to splat assert: @@ -165,7 +165,7 @@ tasks: - setup: filter: "" - register: fact_results + # register: fact_results - name: Test that retrieving all facts filtered to empty filter_spec works assert: diff --git a/test/units/module_utils/facts/test_collector.py b/test/units/module_utils/facts/test_collector.py index 2fbcfe6f0eb..76f37f3cebf 100644 --- a/test/units/module_utils/facts/test_collector.py +++ b/test/units/module_utils/facts/test_collector.py @@ -21,6 +21,7 @@ from __future__ import (absolute_import, division) __metaclass__ = type from collections import defaultdict +import pprint # for testing from ansible.compat.tests import unittest @@ -54,34 +55,22 @@ class TestFindCollectorsForPlatform(unittest.TestCase): class TestSelectCollectorNames(unittest.TestCase): + + def _assert_equal_detail(self, obj1, obj2, msg=None): + msg = 'objects are not equal\n%s\n\n!=\n\n%s' % (pprint.pformat(obj1), pprint.pformat(obj2)) + return self.assertEqual(obj1, obj2, msg) + def test(self): - collector_names = set(['distribution', 'all_ipv4_addresses', - 'local', 'pkg_mgr']) + collector_names = ['distribution', 'all_ipv4_addresses', + 'local', 'pkg_mgr'] all_fact_subsets = self._all_fact_subsets() - all_collector_classes = self._all_collector_classes() res = collector.select_collector_classes(collector_names, - all_fact_subsets, - all_collector_classes) + all_fact_subsets) expected = [default_collectors.DistributionFactCollector, default_collectors.PkgMgrFactCollector] - self.assertEqual(res, expected) - - def test_reverse(self): - collector_names = set(['distribution', 'all_ipv4_addresses', - 'local', 'pkg_mgr']) - all_fact_subsets = self._all_fact_subsets() - all_collector_classes = self._all_collector_classes() - all_collector_classes.reverse() - res = collector.select_collector_classes(collector_names, - all_fact_subsets, - all_collector_classes) - - expected = [default_collectors.PkgMgrFactCollector, - default_collectors.DistributionFactCollector] - - self.assertEqual(res, expected) + self._assert_equal_detail(res, expected) def test_default_collectors(self): platform_info = {'system': 'Generic'} @@ -95,14 +84,22 @@ class TestSelectCollectorNames(unittest.TestCase): collector_names = collector.get_collector_names(valid_subsets=all_valid_subsets, aliases_map=aliases_map, platform_info=platform_info) - collector.select_collector_classes(collector_names, - all_fact_subsets, - default_collectors.collectors) + complete_collector_names = collector._solve_deps(collector_names, all_fact_subsets) - def _all_collector_classes(self): - return [default_collectors.DistributionFactCollector, - default_collectors.PkgMgrFactCollector, - default_collectors.LinuxNetworkCollector] + dep_map = collector.build_dep_data(complete_collector_names, all_fact_subsets) + + ordered_deps = collector.tsort(dep_map) + ordered_collector_names = [x[0] for x in ordered_deps] + + res = collector.select_collector_classes(ordered_collector_names, + all_fact_subsets) + + self.assertTrue(res.index(default_collectors.ServiceMgrFactCollector) > + res.index(default_collectors.DistributionFactCollector), + res) + self.assertTrue(res.index(default_collectors.ServiceMgrFactCollector) > + res.index(default_collectors.PlatformFactCollector), + res) def _all_fact_subsets(self, data=None): all_fact_subsets = defaultdict(list) @@ -276,30 +273,249 @@ class TestGetCollectorNames(unittest.TestCase): gather_subset=['my_fact', 'not_a_valid_gather_subset']) +class TestFindUnresolvedRequires(unittest.TestCase): + def test(self): + names = ['network', 'virtual', 'env'] + all_fact_subsets = {'env': [default_collectors.EnvFactCollector], + 'network': [default_collectors.LinuxNetworkCollector], + 'virtual': [default_collectors.LinuxVirtualCollector]} + res = collector.find_unresolved_requires(names, all_fact_subsets) + # pprint.pprint(res) + + self.assertIsInstance(res, set) + self.assertEqual(res, set(['platform', 'distribution'])) + + def test_resolved(self): + names = ['network', 'virtual', 'env', 'platform', 'distribution'] + all_fact_subsets = {'env': [default_collectors.EnvFactCollector], + 'network': [default_collectors.LinuxNetworkCollector], + 'distribution': [default_collectors.DistributionFactCollector], + 'platform': [default_collectors.PlatformFactCollector], + 'virtual': [default_collectors.LinuxVirtualCollector]} + res = collector.find_unresolved_requires(names, all_fact_subsets) + # pprint.pprint(res) + + self.assertIsInstance(res, set) + self.assertEqual(res, set()) + + +class TestBuildDepData(unittest.TestCase): + def test(self): + names = ['network', 'virtual', 'env'] + all_fact_subsets = {'env': [default_collectors.EnvFactCollector], + 'network': [default_collectors.LinuxNetworkCollector], + 'virtual': [default_collectors.LinuxVirtualCollector]} + res = collector.build_dep_data(names, all_fact_subsets) + + # pprint.pprint(dict(res)) + self.assertIsInstance(res, defaultdict) + self.assertEqual(dict(res), + {'network': set(['platform', 'distribution']), + 'virtual': set(), + 'env': set()}) + + +class TestSolveDeps(unittest.TestCase): + def test_no_solution(self): + unresolved = set(['required_thing1', 'required_thing2']) + all_fact_subsets = {'env': [default_collectors.EnvFactCollector], + 'network': [default_collectors.LinuxNetworkCollector], + 'virtual': [default_collectors.LinuxVirtualCollector]} + + self.assertRaises(collector.CollectorNotFoundError, + collector._solve_deps, + unresolved, + all_fact_subsets) + + def test(self): + unresolved = set(['env', 'network']) + all_fact_subsets = {'env': [default_collectors.EnvFactCollector], + 'network': [default_collectors.LinuxNetworkCollector], + 'virtual': [default_collectors.LinuxVirtualCollector], + 'platform': [default_collectors.PlatformFactCollector], + 'distribution': [default_collectors.DistributionFactCollector]} + res = collector.resolve_requires(unresolved, all_fact_subsets) + + res = collector._solve_deps(unresolved, all_fact_subsets) + + self.assertIsInstance(res, set) + for goal in unresolved: + self.assertIn(goal, res) + + +class TestResolveRequires(unittest.TestCase): + def test_no_resolution(self): + unresolved = ['required_thing1', 'required_thing2'] + all_fact_subsets = {'env': [default_collectors.EnvFactCollector], + 'network': [default_collectors.LinuxNetworkCollector], + 'virtual': [default_collectors.LinuxVirtualCollector]} + self.assertRaisesRegexp(collector.UnresolvedFactDep, + 'unresolved fact dep.*required_thing2', + collector.resolve_requires, + unresolved, all_fact_subsets) + + def test(self): + unresolved = ['env', 'network'] + all_fact_subsets = {'env': [default_collectors.EnvFactCollector], + 'network': [default_collectors.LinuxNetworkCollector], + 'virtual': [default_collectors.LinuxVirtualCollector]} + res = collector.resolve_requires(unresolved, all_fact_subsets) + for goal in unresolved: + self.assertIn(goal, res) + + def test_exception(self): + unresolved = ['required_thing1'] + all_fact_subsets = {} + try: + collector.resolve_requires(unresolved, all_fact_subsets) + except collector.UnresolvedFactDep as exc: + self.assertIn(unresolved[0], '%s' % exc) + + +class TestTsort(unittest.TestCase): + def test(self): + dep_map = {'network': set(['distribution', 'platform']), + 'virtual': set(), + 'platform': set(['what_platform_wants']), + 'what_platform_wants': set(), + 'network_stuff': set(['network'])} + + res = collector.tsort(dep_map) + # pprint.pprint(res) + + self.assertIsInstance(res, list) + names = [x[0] for x in res] + self.assertTrue(names.index('network_stuff') > names.index('network')) + self.assertTrue(names.index('platform') > names.index('what_platform_wants')) + self.assertTrue(names.index('network') > names.index('platform')) + + def test_cycles(self): + dep_map = {'leaf1': set(), + 'leaf2': set(), + 'node1': set(['node2']), + 'node2': set(['node3']), + 'node3': set(['node1'])} + + self.assertRaises(collector.CycleFoundInFactDeps, + collector.tsort, + dep_map) + + def test_just_nodes(self): + dep_map = {'leaf1': set(), + 'leaf4': set(), + 'leaf3': set(), + 'leaf2': set()} + + res = collector.tsort(dep_map) + self.assertIsInstance(res, list) + names = [x[0] for x in res] + # not a lot to assert here, any order of the + # results is valid + self.assertEqual(set(names), set(dep_map.keys())) + + def test_self_deps(self): + dep_map = {'node1': set(['node1']), + 'node2': set(['node2'])} + self.assertRaises(collector.CycleFoundInFactDeps, + collector.tsort, + dep_map) + + def test_unsolvable(self): + dep_map = {'leaf1': set(), + 'node2': set(['leaf2'])} + + res = collector.tsort(dep_map) + self.assertIsInstance(res, list) + names = [x[0] for x in res] + self.assertEqual(set(names), set(dep_map.keys())) + + def test_chain(self): + dep_map = {'leaf1': set(['leaf2']), + 'leaf2': set(['leaf3']), + 'leaf3': set(['leaf4']), + 'leaf4': set(), + 'leaf5': set(['leaf1'])} + res = collector.tsort(dep_map) + self.assertIsInstance(res, list) + names = [x[0] for x in res] + self.assertEqual(set(names), set(dep_map.keys())) + + def test_multi_pass(self): + dep_map = {'leaf1': set(), + 'leaf2': set(['leaf3', 'leaf1', 'leaf4', 'leaf5']), + 'leaf3': set(['leaf4', 'leaf1']), + 'leaf4': set(['leaf1']), + 'leaf5': set(['leaf1'])} + res = collector.tsort(dep_map) + self.assertIsInstance(res, list) + names = [x[0] for x in res] + self.assertEqual(set(names), set(dep_map.keys())) + self.assertTrue(names.index('leaf1') < names.index('leaf2')) + for leaf in ('leaf2', 'leaf3', 'leaf4', 'leaf5'): + self.assertTrue(names.index('leaf1') < names.index(leaf)) + + class TestCollectorClassesFromGatherSubset(unittest.TestCase): + maxDiff = None + def _classes(self, all_collector_classes=None, valid_subsets=None, minimal_gather_subset=None, gather_subset=None, - gather_timeout=None): + gather_timeout=None, + platform_info=None): + platform_info = platform_info or {'system': 'Linux'} return collector.collector_classes_from_gather_subset(all_collector_classes=all_collector_classes, valid_subsets=valid_subsets, minimal_gather_subset=minimal_gather_subset, gather_subset=gather_subset, - gather_timeout=gather_timeout) + gather_timeout=gather_timeout, + platform_info=platform_info) def test_no_args(self): res = self._classes() self.assertIsInstance(res, list) self.assertEqual(res, []) - def test(self): + def test_not_all(self): res = self._classes(all_collector_classes=default_collectors.collectors, gather_subset=['!all']) self.assertIsInstance(res, list) self.assertEqual(res, []) + def test_all(self): + res = self._classes(all_collector_classes=default_collectors.collectors, + gather_subset=['all']) + self.assertIsInstance(res, list) + + def test_hardware(self): + res = self._classes(all_collector_classes=default_collectors.collectors, + gather_subset=['hardware']) + self.assertIsInstance(res, list) + self.assertIn(default_collectors.PlatformFactCollector, res) + self.assertIn(default_collectors.LinuxHardwareCollector, res) + + self.assertTrue(res.index(default_collectors.LinuxHardwareCollector) > + res.index(default_collectors.PlatformFactCollector)) + + def test_network(self): + res = self._classes(all_collector_classes=default_collectors.collectors, + gather_subset=['network']) + self.assertIsInstance(res, list) + self.assertIn(default_collectors.DistributionFactCollector, res) + self.assertIn(default_collectors.PlatformFactCollector, res) + self.assertIn(default_collectors.LinuxNetworkCollector, res) + + self.assertTrue(res.index(default_collectors.LinuxNetworkCollector) > + res.index(default_collectors.PlatformFactCollector)) + self.assertTrue(res.index(default_collectors.LinuxNetworkCollector) > + res.index(default_collectors.DistributionFactCollector)) + + # self.assertEqual(set(res, [default_collectors.DistributionFactCollector, + # default_collectors.PlatformFactCollector, + # default_collectors.LinuxNetworkCollector]) + def test_env(self): res = self._classes(all_collector_classes=default_collectors.collectors, gather_subset=['env'])