Perfy McPerferton (#58400)

* InventoryManager start of perf improvements

* 0 not 1

* More startswith to [0] improvements

* Remove unused var

* The hash doesn't need to be a string, start as a list, make it into a tuple

* set actually appears faster than frozenset, and these don't need to be frozen

* Cache hosts lists, to avoid extra get_hosts calls, pass to get_vars too

* negligible perf improvement, it could help with memory later

* Try the fast way, fall back to the safe way

* Revert to previous logic, linting fix

* Extend pre-caching to free

* Address test failures

* Hosts are strings

* Fix unit test

* host is a string

* update test assumption

* drop SharedPluginLoaderObj, pre-create a set, instead of 2 comparisons in the list comprehension

* Dedupe code

* Change to _hosts and _hosts_all in get_vars

* Add backwards compat for strategies that don't set host caches

* Add deprecation message to SharedPluginLoaderObj

* Remove unused SharedPluginLoaderObj import

* Update docs/comments

* Remove debugging

* Indicate what pattern_hash is

* That won't work

* Re-fix tests

* Update _set_hosts_cache to accept the play directly, use without refresh in get_hosts_remaining and get_failed_hosts for backwards compat

* Rename variable to avoid confusion

* On add_host only manipulate _hosts_cache_all

* Add warning docs around _hosts and _hosts_all args
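
The bullets above lean on a handful of CPython micro-optimizations: indexing the first character instead of calling ``startswith``, using a tuple of the pattern parts as a cache key instead of a joined string, and building a set once before a membership-heavy list comprehension. The following is a rough, standalone timing sketch of those ideas; the variable names and workloads are made up for illustration and are not taken from the Ansible code below.

```python
# Standalone timing sketch (not Ansible code) for the micro-optimizations
# named in the commit bullets above.
from timeit import timeit

pattern = "!webservers"

# str.startswith() is a method call; indexing the first character avoids it.
# Note: pattern[0] raises IndexError on an empty string, so callers must have
# already filtered out empty patterns.
print(timeit(lambda: pattern.startswith("!")))
print(timeit(lambda: pattern[0] == "!"))

# A tuple of the pattern parts is hashable and works directly as a dict key;
# joining the parts into a string first does extra work for no benefit.
parts = ["webservers", "dbservers", "!staging"]
cache = {}
print(timeit(lambda: cache.get(u":".join(parts))))
print(timeit(lambda: cache.get(tuple(parts))))

# Membership tests: build the set once outside the comprehension instead of
# scanning a list (or rebuilding a frozenset) for every element.
hosts = ["host%02d" % i for i in range(500)]
restriction = ["host%02d" % i for i in range(0, 500, 7)]
restriction_set = set(restriction)
print(timeit(lambda: [h for h in hosts if h in restriction], number=1000))
print(timeit(lambda: [h for h in hosts if h in restriction_set], number=1000))

# Slice assignment (hosts[:] = ...) replaces the list's contents in place
# rather than rebinding the name to a new list object.
hosts[:] = sorted(hosts)
print(hosts[:3])
```

On CPython the indexed, tuple-keyed, and set-based variants are typically somewhat faster per call; the saving is negligible in isolation but adds up when inventory pattern matching runs across large host lists.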
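The host-caching bullets amount to: compute the play's host names once per strategy run, reuse them everywhere (including passing them into ``get_vars``), and resolve a name back to a ``Host`` object with a cheap dict lookup, falling back to the slower ``get_host()`` call only on a miss. Below is a minimal, self-contained sketch of that shape; ``Host``, ``Inventory``, and ``StrategySketch`` are simplified stand-ins, not the real Ansible classes shown in the diffs further down.

```python
# Minimal sketch of the per-play host-name cache; the classes here are
# simplified stand-ins for illustration only.
class Host:
    def __init__(self, name):
        self.name = name


class Inventory:
    def __init__(self, names):
        self.hosts = {n: Host(n) for n in names}

    def get_hosts(self, pattern='all', order=None, ignore_restrictions=False):
        # The real method does pattern matching, restrictions, and ordering;
        # the sketch just returns everything.
        return list(self.hosts.values())

    def get_host(self, name):
        return self.hosts.get(name) or Host(name)


class StrategySketch:
    def __init__(self, inventory):
        self._inventory = inventory
        # Primed at the top of each strategy's run(); lists of host *names*.
        self._hosts_cache = []
        self._hosts_cache_all = []

    def _set_hosts_cache(self, play_hosts, play_order=None, refresh=True):
        # refresh=False lets legacy callers prime the cache lazily instead of
        # requiring every strategy to do it up front (backwards compat).
        if not refresh and all((self._hosts_cache, self._hosts_cache_all)):
            return
        pattern = play_hosts or 'all'
        self._hosts_cache_all = [
            h.name for h in self._inventory.get_hosts(pattern, ignore_restrictions=True)]
        self._hosts_cache = [
            h.name for h in self._inventory.get_hosts(play_hosts, order=play_order)]

    def get_hosts_left(self, unreachable):
        # "Try the fast way, fall back to the safe way": plain dict access
        # first, the slower get_host() lookup only on a miss.
        left = []
        for name in self._hosts_cache:
            if name in unreachable:
                continue
            try:
                left.append(self._inventory.hosts[name])
            except KeyError:
                left.append(self._inventory.get_host(name))
        return left


if __name__ == '__main__':
    strategy = StrategySketch(Inventory(['host01', 'host02', 'host03']))
    strategy._set_hosts_cache('all')
    print([h.name for h in strategy.get_hosts_left(unreachable={'host02'})])
```

Caching names rather than ``Host`` objects keeps the cache cheap to build and compare, while the dict-lookup fast path avoids repeated ``get_hosts``/``get_host`` pattern resolution inside hot loops.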
Matt Martz 2019-07-22 14:25:20 -05:00 committed by GitHub
parent 6adf0c581e
commit 284dafe476
9 changed files with 168 additions and 116 deletions

@ -48,6 +48,18 @@ IGNORED_EXTS = [b'%s$' % to_bytes(re.escape(x)) for x in C.INVENTORY_IGNORE_EXTS
IGNORED = re.compile(b'|'.join(IGNORED_ALWAYS + IGNORED_PATTERNS + IGNORED_EXTS))
PATTERN_WITH_SUBSCRIPT = re.compile(
r'''^
(.+) # A pattern expression ending with...
\[(?: # A [subscript] expression comprising:
(-?[0-9]+)| # A single positive or negative number
([0-9]+)([:-]) # Or an x:y or x: range.
([0-9]*)
)\]
$
''', re.X
)
def order_patterns(patterns):
''' takes a list of patterns and reorders them by modifier to apply them consistently '''
@ -57,9 +69,9 @@ def order_patterns(patterns):
pattern_intersection = []
pattern_exclude = []
for p in patterns:
if p.startswith("!"):
if p[0] == "!":
pattern_exclude.append(p)
elif p.startswith("&"):
elif p[0] == "&":
pattern_intersection.append(p)
elif p:
pattern_regular.append(p)
@ -316,7 +328,7 @@ class InventoryManager(object):
def _match_list(self, items, pattern_str):
# compile patterns
try:
if not pattern_str.startswith('~'):
if not pattern_str[0] == '~':
pattern = re.compile(fnmatch.translate(pattern_str))
else:
pattern = re.compile(pattern_str[1:])
@ -341,41 +353,45 @@ class InventoryManager(object):
# Check if pattern already computed
if isinstance(pattern, list):
pattern_hash = u":".join(pattern)
pattern_list = pattern[:]
else:
pattern_hash = pattern
pattern_list = [pattern]
if pattern_hash:
if pattern_list:
if not ignore_limits and self._subset:
pattern_hash += u":%s" % to_text(self._subset, errors='surrogate_or_strict')
pattern_list.extend(self._subset)
if not ignore_restrictions and self._restriction:
pattern_hash += u":%s" % to_text(self._restriction, errors='surrogate_or_strict')
pattern_list.extend(self._restriction)
# This is only used as a hash key in the self._hosts_patterns_cache dict
# a tuple is faster than stringifying
pattern_hash = tuple(pattern_list)
if pattern_hash not in self._hosts_patterns_cache:
patterns = split_host_pattern(pattern)
hosts = self._evaluate_patterns(patterns)
hosts[:] = self._evaluate_patterns(patterns)
# mainly useful for hostvars[host] access
if not ignore_limits and self._subset:
# exclude hosts not in a subset, if defined
subset_uuids = [s._uuid for s in self._evaluate_patterns(self._subset)]
hosts = [h for h in hosts if h._uuid in subset_uuids]
subset_uuids = set(s._uuid for s in self._evaluate_patterns(self._subset))
hosts[:] = [h for h in hosts if h._uuid in subset_uuids]
if not ignore_restrictions and self._restriction:
# exclude hosts mentioned in any restriction (ex: failed hosts)
hosts = [h for h in hosts if h.name in self._restriction]
hosts[:] = [h for h in hosts if h.name in self._restriction]
self._hosts_patterns_cache[pattern_hash] = deduplicate_list(hosts)
# sort hosts list if needed (should only happen when called from strategy)
if order in ['sorted', 'reverse_sorted']:
hosts = sorted(self._hosts_patterns_cache[pattern_hash][:], key=attrgetter('name'), reverse=(order == 'reverse_sorted'))
hosts[:] = sorted(self._hosts_patterns_cache[pattern_hash][:], key=attrgetter('name'), reverse=(order == 'reverse_sorted'))
elif order == 'reverse_inventory':
hosts = self._hosts_patterns_cache[pattern_hash][::-1]
hosts[:] = self._hosts_patterns_cache[pattern_hash][::-1]
else:
hosts = self._hosts_patterns_cache[pattern_hash][:]
hosts[:] = self._hosts_patterns_cache[pattern_hash][:]
if order == 'shuffle':
shuffle(hosts)
elif order not in [None, 'inventory']:
@ -398,12 +414,15 @@ class InventoryManager(object):
hosts.append(self._inventory.get_host(p))
else:
that = self._match_one_pattern(p)
if p.startswith("!"):
hosts = [h for h in hosts if h not in frozenset(that)]
elif p.startswith("&"):
hosts = [h for h in hosts if h in frozenset(that)]
if p[0] == "!":
that = set(that)
hosts = [h for h in hosts if h not in that]
elif p[0] == "&":
that = set(that)
hosts = [h for h in hosts if h in that]
else:
hosts.extend([h for h in that if h.name not in frozenset([y.name for y in hosts])])
existing_hosts = set(y.name for y in hosts)
hosts.extend([h for h in that if h.name not in existing_hosts])
return hosts
def _match_one_pattern(self, pattern):
@ -444,7 +463,7 @@ class InventoryManager(object):
Duplicate matches are always eliminated from the results.
"""
if pattern.startswith("&") or pattern.startswith("!"):
if pattern[0] in ("&", "!"):
pattern = pattern[1:]
if pattern not in self._pattern_cache:
@ -469,27 +488,15 @@ class InventoryManager(object):
"""
# Do not parse regexes for enumeration info
if pattern.startswith('~'):
if pattern[0] == '~':
return (pattern, None)
# We want a pattern followed by an integer or range subscript.
# (We can't be more restrictive about the expression because the
# fnmatch semantics permit [\[:\]] to occur.)
pattern_with_subscript = re.compile(
r'''^
(.+) # A pattern expression ending with...
\[(?: # A [subscript] expression comprising:
(-?[0-9]+)| # A single positive or negative number
([0-9]+)([:-]) # Or an x:y or x: range.
([0-9]*)
)\]
$
''', re.X
)
subscript = None
m = pattern_with_subscript.match(pattern)
m = PATTERN_WITH_SUBSCRIPT.match(pattern)
if m:
(pattern, idx, start, sep, end) = m.groups()
if idx:
@ -535,7 +542,7 @@ class InventoryManager(object):
results.extend(self._inventory.groups[groupname].get_hosts())
# check hosts if no groups matched or it is a regex/glob pattern
if not matching_groups or pattern.startswith('~') or any(special in pattern for special in ('.', '?', '*', '[')):
if not matching_groups or pattern[0] == '~' or any(special in pattern for special in ('.', '?', '*', '[')):
# pattern might match host
matching_hosts = self._match_list(self._inventory.hosts, pattern)
if matching_hosts:
@ -585,7 +592,7 @@ class InventoryManager(object):
return
elif not isinstance(restriction, list):
restriction = [restriction]
self._restriction = [h.name for h in restriction]
self._restriction = set(to_text(h.name) for h in restriction)
def subset(self, subset_pattern):
"""
@ -601,12 +608,12 @@ class InventoryManager(object):
results = []
# allow Unix style @filename data
for x in subset_patterns:
if x.startswith("@"):
if x[0] == "@":
fd = open(x[1:])
results.extend([l.strip() for l in fd.read().split("\n")])
results.extend([to_text(l.strip()) for l in fd.read().split("\n")])
fd.close()
else:
results.append(x)
results.append(to_text(x))
self._subset = results
def remove_restriction(self):

@ -10,7 +10,6 @@ from ansible import constants as C
from ansible.plugins.callback import CallbackBase
from ansible.utils.color import colorize, hostcolor
from ansible.template import Templar
from ansible.plugins.strategy import SharedPluginLoaderObj
from ansible.playbook.task_include import TaskInclude
DOCUMENTATION = '''

@ -45,7 +45,7 @@ from ansible.module_utils.connection import Connection, ConnectionError
from ansible.playbook.helpers import load_list_of_blocks
from ansible.playbook.included_file import IncludedFile
from ansible.playbook.task_include import TaskInclude
from ansible.plugins.loader import action_loader, connection_loader, filter_loader, lookup_loader, module_loader, test_loader
from ansible.plugins import loader as plugin_loader
from ansible.template import Templar
from ansible.utils.display import Display
from ansible.utils.vars import combine_vars
@ -60,21 +60,12 @@ class StrategySentinel:
pass
# TODO: this should probably be in the plugins/__init__.py, with
# a smarter mechanism to set all of the attributes based on
# the loaders created there
class SharedPluginLoaderObj:
def SharedPluginLoaderObj():
'''This only exists for backwards compat, do not use.
'''
A simple object to make pass the various plugin loaders to
the forked processes over the queue easier
'''
def __init__(self):
self.action_loader = action_loader
self.connection_loader = connection_loader
self.filter_loader = filter_loader
self.test_loader = test_loader
self.lookup_loader = lookup_loader
self.module_loader = module_loader
display.deprecated('SharedPluginLoaderObj is deprecated, please directly use ansible.plugins.loader',
version='2.11')
return plugin_loader
_sentinel = StrategySentinel()
@ -207,8 +198,29 @@ class StrategyBase:
# play completion
self._active_connections = dict()
# Caches for get_host calls, to avoid calling excessively
# These values should be set at the top of the ``run`` method of each
# strategy plugin. Use ``_set_hosts_cache`` to set these values
self._hosts_cache = []
self._hosts_cache_all = []
self.debugger_active = C.ENABLE_TASK_DEBUGGER
def _set_hosts_cache(self, play, refresh=True):
"""Responsible for setting _hosts_cache and _hosts_cache_all
See comment in ``__init__`` for the purpose of these caches
"""
if not refresh and all((self._hosts_cache, self._hosts_cache_all)):
return
if Templar(None).is_template(play.hosts):
_pattern = 'all'
else:
_pattern = play.hosts or 'all'
self._hosts_cache_all = [h.name for h in self._inventory.get_hosts(pattern=_pattern, ignore_restrictions=True)]
self._hosts_cache = [h.name for h in self._inventory.get_hosts(play.hosts, order=play.order)]
def cleanup(self):
# close active persistent connections
for sock in itervalues(self._active_connections):
@ -227,8 +239,12 @@ class StrategyBase:
# This should be safe, as everything should be ITERATING_COMPLETE by
# this point, though the strategy may not advance the hosts itself.
inv_hosts = self._inventory.get_hosts(iterator._play.hosts, order=iterator._play.order)
[iterator.get_next_task_for_host(host) for host in inv_hosts if host.name not in self._tqm._unreachable_hosts]
for host in self._hosts_cache:
if host not in self._tqm._unreachable_hosts:
try:
iterator.get_next_task_for_host(self._inventory.hosts[host])
except KeyError:
iterator.get_next_task_for_host(self._inventory.get_host(host))
# save the failed/unreachable hosts, as the run_handlers()
# method will clear that information during its execution
@ -258,19 +274,21 @@ class StrategyBase:
return self._tqm.RUN_OK
def get_hosts_remaining(self, play):
return [host for host in self._inventory.get_hosts(play.hosts)
if host.name not in self._tqm._failed_hosts and host.name not in self._tqm._unreachable_hosts]
self._set_hosts_cache(play, refresh=False)
ignore = set(self._tqm._failed_hosts).union(self._tqm._unreachable_hosts)
return [host for host in self._hosts_cache if host not in ignore]
def get_failed_hosts(self, play):
return [host for host in self._inventory.get_hosts(play.hosts) if host.name in self._tqm._failed_hosts]
self._set_hosts_cache(play, refresh=False)
return [host for host in self._hosts_cache if host in self._tqm._failed_hosts]
def add_tqm_variables(self, vars, play):
'''
Base class method to add extra variables/information to the list of task
vars sent through the executor engine regarding the task queue manager state.
'''
vars['ansible_current_hosts'] = [h.name for h in self.get_hosts_remaining(play)]
vars['ansible_failed_hosts'] = [h.name for h in self.get_failed_hosts(play)]
vars['ansible_current_hosts'] = self.get_hosts_remaining(play)
vars['ansible_failed_hosts'] = self.get_failed_hosts(play)
def _queue_task(self, host, task, task_vars, play_context):
''' handles queueing the task up to be sent to a worker '''
@ -294,11 +312,6 @@ class StrategyBase:
# and then queue the new task
try:
# create a dummy object with plugin loaders set as an easier
# way to share them with the forked processes
shared_loader_obj = SharedPluginLoaderObj()
queued = False
starting_worker = self._cur_worker
while True:
@ -311,7 +324,7 @@ class StrategyBase:
'play_context': play_context
}
worker_prc = WorkerProcess(self._final_q, task_vars, host, task, play_context, self._loader, self._variable_manager, shared_loader_obj)
worker_prc = WorkerProcess(self._final_q, task_vars, host, task, play_context, self._loader, self._variable_manager, plugin_loader)
self._workers[self._cur_worker] = worker_prc
self._tqm.send_callback('v2_runner_on_start', host, task)
worker_prc.start()
@ -334,24 +347,19 @@ class StrategyBase:
def get_task_hosts(self, iterator, task_host, task):
if task.run_once:
host_list = [host for host in self._inventory.get_hosts(iterator._play.hosts) if host.name not in self._tqm._unreachable_hosts]
host_list = [host for host in self._hosts_cache if host not in self._tqm._unreachable_hosts]
else:
host_list = [task_host]
host_list = [task_host.name]
return host_list
def get_delegated_hosts(self, result, task):
host_name = result.get('_ansible_delegated_vars', {}).get('ansible_delegated_host', None)
if host_name is not None:
actual_host = self._inventory.get_host(host_name)
if actual_host is None:
actual_host = Host(name=host_name)
else:
actual_host = Host(name=task.delegate_to)
return [actual_host]
return [host_name or task.delegate_to]
def get_handler_templar(self, handler_task, iterator):
handler_vars = self._variable_manager.get_vars(play=iterator._play, task=handler_task)
handler_vars = self._variable_manager.get_vars(play=iterator._play, task=handler_task,
_hosts=self._hosts_cache,
_hosts_all=self._hosts_cache_all)
return Templar(loader=self._loader, variables=handler_vars)
@debug_closure
@ -703,6 +711,7 @@ class StrategyBase:
# Check if host in inventory, add if not
if host_name not in self._inventory.hosts:
self._inventory.add_host(host_name, 'all')
self._hosts_cache_all.append(host_name)
new_host = self._inventory.hosts.get(host_name)
# Set/update the vars for this host
@ -882,7 +891,7 @@ class StrategyBase:
bypass_host_loop = False
try:
action = action_loader.get(handler.action, class_only=True)
action = plugin_loader.action_loader.get(handler.action, class_only=True)
if getattr(action, 'BYPASS_HOST_LOOP', False):
bypass_host_loop = True
except KeyError:
@ -893,7 +902,8 @@ class StrategyBase:
host_results = []
for host in notified_hosts:
if not iterator.is_failed(host) or iterator._play.force_handlers:
task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=handler)
task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=handler,
_hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all)
self.add_tqm_variables(task_vars, play=iterator._play)
templar = Templar(loader=self._loader, variables=task_vars)
if not handler.cached_name:
@ -993,7 +1003,8 @@ class StrategyBase:
meta_action = task.args.get('_raw_params')
def _evaluate_conditional(h):
all_vars = self._variable_manager.get_vars(play=iterator._play, host=h, task=task)
all_vars = self._variable_manager.get_vars(play=iterator._play, host=h, task=task,
_hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all)
templar = Templar(loader=self._loader, variables=all_vars)
return task.evaluate_conditional(templar, all_vars)
@ -1015,6 +1026,7 @@ class StrategyBase:
if task.when:
self._cond_not_supported_warn(meta_action)
self._inventory.refresh_inventory()
self._set_hosts_cache(iterator._play)
msg = "inventory successfully refreshed"
elif meta_action == 'clear_facts':
if _evaluate_conditional(target_host):
@ -1047,7 +1059,8 @@ class StrategyBase:
skipped = True
msg = "end_host conditional evaluated to false, continuing execution for %s" % target_host.name
elif meta_action == 'reset_connection':
all_vars = self._variable_manager.get_vars(play=iterator._play, host=target_host, task=task)
all_vars = self._variable_manager.get_vars(play=iterator._play, host=target_host, task=task,
_hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all)
templar = Templar(loader=self._loader, variables=all_vars)
# apply the given task's information to the connection info,
@ -1075,7 +1088,7 @@ class StrategyBase:
connection = Connection(self._active_connections[target_host])
del self._active_connections[target_host]
else:
connection = connection_loader.get(play_context.connection, play_context, os.devnull)
connection = plugin_loader.connection_loader.get(play_context.connection, play_context, os.devnull)
play_context.set_attributes_from_plugin(connection)
if connection:
@ -1104,9 +1117,12 @@ class StrategyBase:
''' returns list of available hosts for this iterator by filtering out unreachables '''
hosts_left = []
for host in self._inventory.get_hosts(iterator._play.hosts, order=iterator._play.order):
if host.name not in self._tqm._unreachable_hosts:
hosts_left.append(host)
for host in self._hosts_cache:
if host not in self._tqm._unreachable_hosts:
try:
hosts_left.append(self._inventory.hosts[host])
except KeyError:
hosts_left.append(self._inventory.get_host(host))
return hosts_left
def update_active_connections(self, results):

@ -82,6 +82,8 @@ class StrategyModule(StrategyBase):
# start with all workers being counted as being free
workers_free = len(self._workers)
self._set_hosts_cache(iterator._play)
work_to_do = True
while work_to_do and not self._tqm._terminated:
@ -129,7 +131,9 @@ class StrategyModule(StrategyBase):
action = None
display.debug("getting variables", host=host_name)
task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task)
task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task,
_hosts=self._hosts_cache,
_hosts_all=self._hosts_cache_all)
self.add_tqm_variables(task_vars, play=iterator._play)
templar = Templar(loader=self._loader, variables=task_vars)
display.debug("done getting variables", host=host_name)
@ -231,7 +235,9 @@ class StrategyModule(StrategyBase):
continue
for new_block in new_blocks:
task_vars = self._variable_manager.get_vars(play=iterator._play, task=new_block._parent)
task_vars = self._variable_manager.get_vars(play=iterator._play, task=new_block._parent,
_hosts=self._hosts_cache,
_hosts_all=self._hosts_cache_all)
final_block = new_block.filter_tagged_tasks(task_vars)
for host in hosts_left:
if host in included_file._hosts:

@ -205,6 +205,9 @@ class StrategyModule(StrategyBase):
# iterate over each task, while there is one left to run
result = self._tqm.RUN_OK
work_to_do = True
self._set_hosts_cache(iterator._play)
while work_to_do and not self._tqm._terminated:
try:
@ -275,7 +278,8 @@ class StrategyModule(StrategyBase):
break
display.debug("getting variables")
task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task)
task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task,
_hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all)
self.add_tqm_variables(task_vars, play=iterator._play)
templar = Templar(loader=self._loader, variables=task_vars)
display.debug("done getting variables")
@ -358,7 +362,9 @@ class StrategyModule(StrategyBase):
for new_block in new_blocks:
task_vars = self._variable_manager.get_vars(
play=iterator._play,
task=new_block._parent
task=new_block._parent,
_hosts=self._hosts_cache,
_hosts_all=self._hosts_cache_all,
)
display.debug("filtering new block on tags")
final_block = new_block.filter_tagged_tasks(task_vars)

@ -140,7 +140,8 @@ class VariableManager:
def set_inventory(self, inventory):
self._inventory = inventory
def get_vars(self, play=None, host=None, task=None, include_hostvars=True, include_delegate_to=True, use_cache=True):
def get_vars(self, play=None, host=None, task=None, include_hostvars=True, include_delegate_to=True, use_cache=True,
_hosts=None, _hosts_all=None):
'''
Returns the variables, with optional "context" given via the parameters
for the play, host, and task (which could possibly result in different
@ -158,6 +159,10 @@ class VariableManager:
- task->get_vars (if there is a task context)
- vars_cache[host] (if there is a host context)
- extra vars
``_hosts`` and ``_hosts_all`` should be considered private args, with only internal trusted callers relying
on the functionality they provide. These arguments may be removed at a later date without a deprecation
period and without warning.
'''
display.debug("in VariableManager get_vars()")
@ -169,6 +174,8 @@ class VariableManager:
task=task,
include_hostvars=include_hostvars,
include_delegate_to=include_delegate_to,
_hosts=_hosts,
_hosts_all=_hosts_all,
)
# default for all cases
@ -425,7 +432,8 @@ class VariableManager:
display.debug("done with get_vars()")
return all_vars
def _get_magic_variables(self, play, host, task, include_hostvars, include_delegate_to):
def _get_magic_variables(self, play, host, task, include_hostvars, include_delegate_to,
_hosts=None, _hosts_all=None):
'''
Returns a dictionary of so-called "magic" variables in Ansible,
which are special variables we set internally for use.
@ -470,9 +478,14 @@ class VariableManager:
else:
pattern = play.hosts or 'all'
# add the list of hosts in the play, as adjusted for limit/filters
variables['ansible_play_hosts_all'] = [x.name for x in self._inventory.get_hosts(pattern=pattern, ignore_restrictions=True)]
if not _hosts_all:
_hosts_all = [h.name for h in self._inventory.get_hosts(pattern=pattern, ignore_restrictions=True)]
if not _hosts:
_hosts = [h.name for h in self._inventory.get_hosts()]
variables['ansible_play_hosts_all'] = _hosts_all[:]
variables['ansible_play_hosts'] = [x for x in variables['ansible_play_hosts_all'] if x not in play._removed_hosts]
variables['ansible_play_batch'] = [x.name for x in self._inventory.get_hosts() if x.name not in play._removed_hosts]
variables['ansible_play_batch'] = [x for x in _hosts if x not in play._removed_hosts]
# DEPRECATED: play_hosts should be deprecated in favor of ansible_play_batch,
# however this would take work in the templating engine, so for now we'll add both
@ -622,19 +635,19 @@ class VariableManager:
raise AnsibleAssertionError("the type of 'facts' to set for host_facts should be a Mapping but is a %s" % type(facts))
try:
host_cache = self._fact_cache[host.name]
host_cache = self._fact_cache[host]
except KeyError:
# We get to set this as new
host_cache = facts
else:
if not isinstance(host_cache, MutableMapping):
raise TypeError('The object retrieved for {0} must be a MutableMapping but was'
' a {1}'.format(host.name, type(host_cache)))
' a {1}'.format(host, type(host_cache)))
# Update the existing facts
host_cache.update(facts)
# Save the facts back to the backing store
self._fact_cache[host.name] = host_cache
self._fact_cache[host] = host_cache
def set_nonpersistent_facts(self, host, facts):
'''
@ -645,18 +658,17 @@ class VariableManager:
raise AnsibleAssertionError("the type of 'facts' to set for nonpersistent_facts should be a Mapping but is a %s" % type(facts))
try:
self._nonpersistent_fact_cache[host.name].update(facts)
self._nonpersistent_fact_cache[host].update(facts)
except KeyError:
self._nonpersistent_fact_cache[host.name] = facts
self._nonpersistent_fact_cache[host] = facts
def set_host_variable(self, host, varname, value):
'''
Sets a value in the vars_cache for a host.
'''
host_name = host.get_name()
if host_name not in self._vars_cache:
self._vars_cache[host_name] = dict()
if varname in self._vars_cache[host_name] and isinstance(self._vars_cache[host_name][varname], MutableMapping) and isinstance(value, MutableMapping):
self._vars_cache[host_name] = combine_vars(self._vars_cache[host_name], {varname: value})
if host not in self._vars_cache:
self._vars_cache[host] = dict()
if varname in self._vars_cache[host] and isinstance(self._vars_cache[host][varname], MutableMapping) and isinstance(value, MutableMapping):
self._vars_cache[host] = combine_vars(self._vars_cache[host], {varname: value})
else:
self._vars_cache[host_name][varname] = value
self._vars_cache[host][varname] = value

@ -39,7 +39,7 @@ class TestPlaybookCLI(unittest.TestCase):
fake_loader = DictDataLoader({'foobar.yml': ""})
inventory = InventoryManager(loader=fake_loader, sources='testhost,')
variable_manager.set_host_facts(inventory.get_host('testhost'), {'canary': True})
variable_manager.set_host_facts('testhost', {'canary': True})
self.assertTrue('testhost' in variable_manager._fact_cache)
cli._flush_cache(inventory, variable_manager)

@ -147,6 +147,8 @@ class TestStrategyBase(unittest.TestCase):
mock_host.has_hostkey = True
mock_hosts.append(mock_host)
mock_hosts_names = [h.name for h in mock_hosts]
mock_inventory = MagicMock()
mock_inventory.get_hosts.return_value = mock_hosts
@ -158,17 +160,18 @@ class TestStrategyBase(unittest.TestCase):
mock_play.hosts = ["host%02d" % (i + 1) for i in range(0, 5)]
strategy_base = StrategyBase(tqm=mock_tqm)
strategy_base._hosts_cache = strategy_base._hosts_cache_all = mock_hosts_names
mock_tqm._failed_hosts = []
mock_tqm._unreachable_hosts = []
self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), mock_hosts)
self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts])
mock_tqm._failed_hosts = ["host01"]
self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), mock_hosts[1:])
self.assertEqual(strategy_base.get_failed_hosts(play=mock_play), [mock_hosts[0]])
self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts[1:]])
self.assertEqual(strategy_base.get_failed_hosts(play=mock_play), [mock_hosts[0].name])
mock_tqm._unreachable_hosts = ["host02"]
self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), mock_hosts[2:])
self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts[2:]])
strategy_base.cleanup()
@patch.object(WorkerProcess, 'run')

@ -58,18 +58,19 @@ class TestStrategyLinear(unittest.TestCase):
p = Playbook.load('test_play.yml', loader=fake_loader, variable_manager=mock_var_manager)
inventory = MagicMock()
inventory.hosts = {}
hosts = []
for i in range(0, 2):
host = MagicMock()
host.name = host.get_name.return_value = 'host%02d' % i
hosts.append(host)
mock_var_manager._fact_cache['host00'] = dict()
inventory = MagicMock()
inventory.hosts[host.name] = host
inventory.get_hosts.return_value = hosts
inventory.filter_hosts.return_value = hosts
mock_var_manager._fact_cache['host00'] = dict()
play_context = PlayContext(play=p._entries[0])
itr = PlayIterator(
@ -89,6 +90,8 @@ class TestStrategyLinear(unittest.TestCase):
)
tqm._initialize_processes(3)
strategy = StrategyModule(tqm)
strategy._hosts_cache = [h.name for h in hosts]
strategy._hosts_cache_all = [h.name for h in hosts]
# implicit meta: flush_handlers
hosts_left = strategy.get_hosts_left(itr)