Allow callbacks from forks (#70501)

* POC for supporting callback events that come from the worker

* linting fixes. ci_complete

* fix up units. ci_complete

* Try moving the sentinel put higher. ci_complete

* safeguards. ci_complete

* Move queue killing to terminate

* LINTING. ci_complete

* Subclass Queue, to add helper send_callback method

* Just use _final_q instead of adding another queue and thread

* Revert a few changes

* Add helper for inserting a TaskResult into the _final_q

* Add changelog fragment

* Address rebase issue

* ci_complete

* Add test to assert async poll callback from fork

* Don't use full path

* ci_complete

* Use _results_lock as a context manager

* Add new generic lock decorator, and use it with send_callback
This commit is contained in:
Matt Martz 2020-08-17 10:51:01 -05:00 committed by GitHub
parent 92d59a58c0
commit 5821128995
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 143 additions and 30 deletions

View file

@ -0,0 +1,3 @@
minor_changes:
- callbacks - Add feature allowing forks to send callback events
(https://github.com/ansible/ansible/issues/14681)

View file

@ -166,41 +166,38 @@ class WorkerProcess(multiprocessing_context.Process):
display.debug("done running TaskExecutor() for %s/%s [%s]" % (self._host, self._task, self._task._uuid)) display.debug("done running TaskExecutor() for %s/%s [%s]" % (self._host, self._task, self._task._uuid))
self._host.vars = dict() self._host.vars = dict()
self._host.groups = [] self._host.groups = []
task_result = TaskResult(
# put the result on the result queue
display.debug("sending task result for task %s" % self._task._uuid)
self._final_q.send_task_result(
self._host.name, self._host.name,
self._task._uuid, self._task._uuid,
executor_result, executor_result,
task_fields=self._task.dump_attrs(), task_fields=self._task.dump_attrs(),
) )
# put the result on the result queue
display.debug("sending task result for task %s" % self._task._uuid)
self._final_q.put(task_result)
display.debug("done sending task result for task %s" % self._task._uuid) display.debug("done sending task result for task %s" % self._task._uuid)
except AnsibleConnectionFailure: except AnsibleConnectionFailure:
self._host.vars = dict() self._host.vars = dict()
self._host.groups = [] self._host.groups = []
task_result = TaskResult( self._final_q.send_task_result(
self._host.name, self._host.name,
self._task._uuid, self._task._uuid,
dict(unreachable=True), dict(unreachable=True),
task_fields=self._task.dump_attrs(), task_fields=self._task.dump_attrs(),
) )
self._final_q.put(task_result, block=False)
except Exception as e: except Exception as e:
if not isinstance(e, (IOError, EOFError, KeyboardInterrupt, SystemExit)) or isinstance(e, TemplateNotFound): if not isinstance(e, (IOError, EOFError, KeyboardInterrupt, SystemExit)) or isinstance(e, TemplateNotFound):
try: try:
self._host.vars = dict() self._host.vars = dict()
self._host.groups = [] self._host.groups = []
task_result = TaskResult( self._final_q.send_task_result(
self._host.name, self._host.name,
self._task._uuid, self._task._uuid,
dict(failed=True, exception=to_text(traceback.format_exc()), stdout=''), dict(failed=True, exception=to_text(traceback.format_exc()), stdout=''),
task_fields=self._task.dump_attrs(), task_fields=self._task.dump_attrs(),
) )
self._final_q.put(task_result, block=False)
except Exception: except Exception:
display.debug(u"WORKER EXCEPTION: %s" % to_text(e)) display.debug(u"WORKER EXCEPTION: %s" % to_text(e))
display.debug(u"WORKER TRACEBACK: %s" % to_text(traceback.format_exc())) display.debug(u"WORKER TRACEBACK: %s" % to_text(traceback.format_exc()))

View file

@ -377,14 +377,11 @@ class TaskExecutor:
'msg': 'Failed to template loop_control.label: %s' % to_text(e) 'msg': 'Failed to template loop_control.label: %s' % to_text(e)
}) })
self._final_q.put( self._final_q.send_task_result(
TaskResult(
self._host.name, self._host.name,
self._task._uuid, self._task._uuid,
res, res,
task_fields=task_fields, task_fields=task_fields,
),
block=False,
) )
results.append(res) results.append(res)
del task_vars[loop_var] del task_vars[loop_var]
@ -600,7 +597,6 @@ class TaskExecutor:
if self._task.async_val > 0: if self._task.async_val > 0:
if self._task.poll > 0 and not result.get('skipped') and not result.get('failed'): if self._task.poll > 0 and not result.get('skipped') and not result.get('failed'):
result = self._poll_async_result(result=result, templar=templar, task_vars=vars_copy) result = self._poll_async_result(result=result, templar=templar, task_vars=vars_copy)
# FIXME callback 'v2_runner_on_async_poll' here
# ensure no log is preserved # ensure no log is preserved
result["_ansible_no_log"] = self._play_context.no_log result["_ansible_no_log"] = self._play_context.no_log
@ -672,7 +668,7 @@ class TaskExecutor:
result['_ansible_retry'] = True result['_ansible_retry'] = True
result['retries'] = retries result['retries'] = retries
display.debug('Retrying task, attempt %d of %d' % (attempt, retries)) display.debug('Retrying task, attempt %d of %d' % (attempt, retries))
self._final_q.put(TaskResult(self._host.name, self._task._uuid, result, task_fields=self._task.dump_attrs()), block=False) self._final_q.send_task_result(self._host.name, self._task._uuid, result, task_fields=self._task.dump_attrs())
time.sleep(delay) time.sleep(delay)
self._handler = self._get_action_handler(connection=self._connection, templar=templar) self._handler = self._get_action_handler(connection=self._connection, templar=templar)
else: else:
@ -778,6 +774,15 @@ class TaskExecutor:
raise raise
else: else:
time_left -= self._task.poll time_left -= self._task.poll
self._final_q.send_callback(
'v2_runner_on_async_poll',
TaskResult(
self._host,
async_task,
async_result,
task_fields=self._task.dump_attrs(),
),
)
if int(async_result.get('finished', 0)) != 1: if int(async_result.get('finished', 0)) != 1:
if async_result.get('_ansible_parsed'): if async_result.get('_ansible_parsed'):

View file

@ -21,7 +21,9 @@ __metaclass__ = type
import os import os
import tempfile import tempfile
import threading
import time import time
import multiprocessing.queues
from ansible import constants as C from ansible import constants as C
from ansible import context from ansible import context
@ -29,7 +31,7 @@ from ansible.errors import AnsibleError
from ansible.executor.play_iterator import PlayIterator from ansible.executor.play_iterator import PlayIterator
from ansible.executor.stats import AggregateStats from ansible.executor.stats import AggregateStats
from ansible.executor.task_result import TaskResult from ansible.executor.task_result import TaskResult
from ansible.module_utils.six import string_types from ansible.module_utils.six import PY3, string_types
from ansible.module_utils._text import to_text, to_native from ansible.module_utils._text import to_text, to_native
from ansible.playbook.play_context import PlayContext from ansible.playbook.play_context import PlayContext
from ansible.plugins.loader import callback_loader, strategy_loader, module_loader from ansible.plugins.loader import callback_loader, strategy_loader, module_loader
@ -38,6 +40,7 @@ from ansible.template import Templar
from ansible.vars.hostvars import HostVars from ansible.vars.hostvars import HostVars
from ansible.vars.reserved import warn_if_reserved from ansible.vars.reserved import warn_if_reserved
from ansible.utils.display import Display from ansible.utils.display import Display
from ansible.utils.lock import lock_decorator
from ansible.utils.multiprocessing import context as multiprocessing_context from ansible.utils.multiprocessing import context as multiprocessing_context
@ -46,6 +49,36 @@ __all__ = ['TaskQueueManager']
display = Display() display = Display()
class CallbackSend:
    '''Plain message object describing one callback invocation.

    ``FinalQueue.send_callback`` places instances of this class on the
    queue; whoever reads the queue can use ``method_name``, ``args`` and
    ``kwargs`` to perform the actual call. Nothing is invoked here.
    '''

    def __init__(self, method_name, *args, **kwargs):
        # Keep the call description exactly as given: the positional
        # arguments as a tuple and the keyword arguments as a dict.
        self.method_name, self.args, self.kwargs = method_name, args, kwargs
class FinalQueue(multiprocessing.queues.Queue):
    '''Multiprocessing queue with helpers for enqueueing callback requests
    and task results.

    Both helpers enqueue with ``block=False``.
    '''

    def __init__(self, *args, **kwargs):
        # On Python 3 the queue is always constructed with the shared
        # multiprocessing context as ``ctx``; on Python 2 the arguments are
        # passed through untouched.
        if PY3:
            kwargs['ctx'] = multiprocessing_context
        super(FinalQueue, self).__init__(*args, **kwargs)

    def send_callback(self, method_name, *args, **kwargs):
        '''Enqueue a ``CallbackSend`` describing a callback invocation.

        :arg method_name: Name of the callback method to invoke
        :arg args: Positional arguments recorded on the ``CallbackSend``
        :kwarg kwargs: Keyword arguments recorded on the ``CallbackSend``
        '''
        self.put(
            CallbackSend(method_name, *args, **kwargs),
            block=False
        )

    def send_task_result(self, *args, **kwargs):
        '''Enqueue a ``TaskResult``.

        If the first positional argument is already a ``TaskResult`` it is
        enqueued as-is; otherwise all arguments are forwarded to the
        ``TaskResult`` constructor and the new object is enqueued.
        '''
        # Check ``args`` is non-empty before indexing it, so a keyword-only
        # call reaches the TaskResult constructor instead of failing here
        # with an unrelated IndexError.
        if args and isinstance(args[0], TaskResult):
            tr = args[0]
        else:
            tr = TaskResult(*args, **kwargs)
        self.put(
            tr,
            block=False
        )
class TaskQueueManager: class TaskQueueManager:
''' '''
@ -95,10 +128,12 @@ class TaskQueueManager:
self._unreachable_hosts = dict() self._unreachable_hosts = dict()
try: try:
self._final_q = multiprocessing_context.Queue() self._final_q = FinalQueue()
except OSError as e: except OSError as e:
raise AnsibleError("Unable to use multiprocessing, this is normally caused by lack of access to /dev/shm: %s" % to_native(e)) raise AnsibleError("Unable to use multiprocessing, this is normally caused by lack of access to /dev/shm: %s" % to_native(e))
self._callback_lock = threading.Lock()
# A temporary file (opened pre-fork) used by connection # A temporary file (opened pre-fork) used by connection
# plugins for inter-process locking. # plugins for inter-process locking.
self._connection_lockfile = tempfile.TemporaryFile() self._connection_lockfile = tempfile.TemporaryFile()
@ -316,6 +351,7 @@ class TaskQueueManager:
defunct = True defunct = True
return defunct return defunct
@lock_decorator(attr='_callback_lock')
def send_callback(self, method_name, *args, **kwargs): def send_callback(self, method_name, *args, **kwargs):
for callback_plugin in [self._stdout_callback] + self._callback_plugins: for callback_plugin in [self._stdout_callback] + self._callback_plugins:
# a plugin that set self.disabled to True will not be called # a plugin that set self.disabled to True will not be called

View file

@ -421,6 +421,16 @@ class CallbackModule(CallbackBase):
msg += "Result was: %s" % self._dump_results(result._result) msg += "Result was: %s" % self._dump_results(result._result)
self._display.display(msg, color=C.COLOR_DEBUG) self._display.display(msg, color=C.COLOR_DEBUG)
def v2_runner_on_async_poll(self, result):
    '''Display a one-line status message for an async poll result.

    :arg result: Object exposing ``_host.get_name()`` and a ``_result``
        mapping; the ``ansible_job_id``, ``started`` and ``finished``
        keys are read with ``.get()``, so missing keys render as ``None``.
    '''
    host_name = result._host.get_name()
    poll_data = result._result
    status_fields = tuple(
        poll_data.get(key) for key in ('ansible_job_id', 'started', 'finished')
    )
    message = 'ASYNC POLL on %s: jid=%s started=%s finished=%s' % ((host_name,) + status_fields)
    self._display.display(message, color=C.COLOR_DEBUG)
def v2_playbook_on_notify(self, handler, host): def v2_playbook_on_notify(self, handler, host):
if self._display.verbosity > 1: if self._display.verbosity > 1:
self._display.display("NOTIFIED HANDLER %s for %s" % (handler.get_name(), host), color=C.COLOR_VERBOSE, screen_only=True) self._display.display("NOTIFIED HANDLER %s for %s" % (handler.get_name(), host), color=C.COLOR_VERBOSE, screen_only=True)

View file

@ -37,6 +37,7 @@ from ansible.errors import AnsibleError, AnsibleFileNotFound, AnsibleParserError
from ansible.executor import action_write_locks from ansible.executor import action_write_locks
from ansible.executor.process.worker import WorkerProcess from ansible.executor.process.worker import WorkerProcess
from ansible.executor.task_result import TaskResult from ansible.executor.task_result import TaskResult
from ansible.executor.task_queue_manager import CallbackSend
from ansible.module_utils.six.moves import queue as Queue from ansible.module_utils.six.moves import queue as Queue
from ansible.module_utils.six import iteritems, itervalues, string_types from ansible.module_utils.six import iteritems, itervalues, string_types
from ansible.module_utils._text import to_text from ansible.module_utils._text import to_text
@ -92,8 +93,10 @@ def results_thread_main(strategy):
result = strategy._final_q.get() result = strategy._final_q.get()
if isinstance(result, StrategySentinel): if isinstance(result, StrategySentinel):
break break
else: elif isinstance(result, CallbackSend):
strategy._results_lock.acquire() strategy._tqm.send_callback(result.method_name, *result.args, **result.kwargs)
elif isinstance(result, TaskResult):
with strategy._results_lock:
# only handlers have the listen attr, so this must be a handler # only handlers have the listen attr, so this must be a handler
# we split up the results into two queues here to make sure # we split up the results into two queues here to make sure
# handler and regular result processing don't cross wires # handler and regular result processing don't cross wires
@ -101,7 +104,8 @@ def results_thread_main(strategy):
strategy._handler_results.append(result) strategy._handler_results.append(result)
else: else:
strategy._results.append(result) strategy._results.append(result)
strategy._results_lock.release() else:
display.warning('Received an invalid object (%s) in the result queue: %r' % (type(result), result))
except (IOError, EOFError): except (IOError, EOFError):
break break
except Queue.Empty: except Queue.Empty:

43
lib/ansible/utils/lock.py Normal file
View file

@ -0,0 +1,43 @@
# Copyright (c) 2020 Matt Martz <matt@sivel.net>
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from functools import wraps
def lock_decorator(attr='missing_lock_attr', lock=None):
    '''Build a decorator that runs the wrapped callable while holding a lock.

    The lock comes from one of two places:

    * ``lock`` - an explicit lock object; when given it is always used.
    * ``attr`` - the name of an attribute holding the lock, looked up at
      call time on the first positional argument of the wrapped callable
      (expected to be ``self`` or ``cls``). Only used when ``lock`` is
      ``None``.

    Written with ``threading.Lock`` in mind, but any object that works as
    a context manager can serve as the lock.

    Examples:

        @lock_decorator(attr='_callback_lock')
        def send_callback(...):

        @lock_decorator(lock=threading.Lock())
        def some_method(...):
    '''
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # ``lock`` is only read, never rebound, so no ``nonlocal``
            # (unavailable on Python 2) is needed; the lock actually used
            # goes into a separate local name.
            active_lock = getattr(args[0], attr) if lock is None else lock
            with active_lock:
                return func(*args, **kwargs)
        return wrapper
    return decorator

View file

@ -0,0 +1,7 @@
- hosts: localhost
gather_facts: false
tasks:
- name: Async poll callback test
command: sleep 5
async: 6
poll: 1

View file

@ -298,3 +298,11 @@
{{ ansible_python_interpreter|default('/usr/bin/python') }} -c 'import os; os.fdopen(os.dup(0), "r")' {{ ansible_python_interpreter|default('/usr/bin/python') }} -c 'import os; os.fdopen(os.dup(0), "r")'
async: 1 async: 1
poll: 1 poll: 1
- name: run async poll callback test playbook
command: ansible-playbook {{ role_path }}/callback_test.yml
register: callback_output
- assert:
that:
- '"ASYNC POLL on localhost" in callback_output.stdout'