Improve interlaced output prevention when asking for host key approval.

This commit is contained in:
Michael DeHaan 2013-07-04 18:17:45 -04:00
parent c55adc9ac9
commit 2cb7c30834
4 changed files with 91 additions and 73 deletions

View file

@ -83,11 +83,23 @@ def log_lockfile():
LOG_LOCK = open(log_lockfile(), 'w') LOG_LOCK = open(log_lockfile(), 'w')
def log_flock(): def log_flock(runner):
fcntl.flock(LOG_LOCK, fcntl.LOCK_EX) fcntl.lockf(LOG_LOCK, fcntl.LOCK_EX)
if runner is not None:
try:
fcntl.lockf(runner.output_lockfile, fcntl.LOCK_EX)
except OSError, e:
# already got closed?
pass
def log_unflock(): def log_unflock(runner):
fcntl.flock(LOG_LOCK, fcntl.LOCK_UN) fcntl.lockf(LOG_LOCK, fcntl.LOCK_UN)
if runner is not None:
try:
fcntl.lockf(runner.output_lockfile, fcntl.LOCK_UN)
except OSError, e:
# already got closed?
pass
def set_play(callback, play): def set_play(callback, play):
''' used to notify callback plugins of context ''' ''' used to notify callback plugins of context '''
@ -101,9 +113,9 @@ def set_task(callback, task):
for callback_plugin in callback_plugins: for callback_plugin in callback_plugins:
callback_plugin.task = task callback_plugin.task = task
def display(msg, color=None, stderr=False, screen_only=False, log_only=False): def display(msg, color=None, stderr=False, screen_only=False, log_only=False, runner=None):
# prevent a very rare case of interlaced multiprocess I/O # prevent a very rare case of interlaced multiprocess I/O
log_flock() log_flock(runner)
msg2 = msg msg2 = msg
if color: if color:
msg2 = stringc(msg, color) msg2 = stringc(msg, color)
@ -120,7 +132,7 @@ def display(msg, color=None, stderr=False, screen_only=False, log_only=False):
logger.error(msg) logger.error(msg)
else: else:
logger.info(msg) logger.info(msg)
log_unflock() log_unflock(runner)
def call_callback_module(method_name, *args, **kwargs): def call_callback_module(method_name, *args, **kwargs):
@ -346,7 +358,7 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
def on_unreachable(self, host, res): def on_unreachable(self, host, res):
if type(res) == dict: if type(res) == dict:
res = res.get('msg','') res = res.get('msg','')
display("%s | FAILED => %s" % (host, res), stderr=True, color='red') display("%s | FAILED => %s" % (host, res), stderr=True, color='red', runner=self.runner)
if self.options.tree: if self.options.tree:
utils.write_tree_file( utils.write_tree_file(
self.options.tree, host, self.options.tree, host,
@ -355,15 +367,15 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
super(CliRunnerCallbacks, self).on_unreachable(host, res) super(CliRunnerCallbacks, self).on_unreachable(host, res)
def on_skipped(self, host, item=None): def on_skipped(self, host, item=None):
display("%s | skipped" % (host)) display("%s | skipped" % (host), runner=self.runner)
super(CliRunnerCallbacks, self).on_skipped(host, item) super(CliRunnerCallbacks, self).on_skipped(host, item)
def on_error(self, host, err): def on_error(self, host, err):
display("err: [%s] => %s\n" % (host, err), stderr=True) display("err: [%s] => %s\n" % (host, err), stderr=True, runner=self.runner)
super(CliRunnerCallbacks, self).on_error(host, err) super(CliRunnerCallbacks, self).on_error(host, err)
def on_no_hosts(self): def on_no_hosts(self):
display("no hosts matched\n", stderr=True) display("no hosts matched\n", stderr=True, runner=self.runner)
super(CliRunnerCallbacks, self).on_no_hosts() super(CliRunnerCallbacks, self).on_no_hosts()
def on_async_poll(self, host, res, jid, clock): def on_async_poll(self, host, res, jid, clock):
@ -371,27 +383,27 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
self._async_notified[jid] = clock + 1 self._async_notified[jid] = clock + 1
if self._async_notified[jid] > clock: if self._async_notified[jid] > clock:
self._async_notified[jid] = clock self._async_notified[jid] = clock
display("<job %s> polling, %ss remaining" % (jid, clock)) display("<job %s> polling, %ss remaining" % (jid, clock), runner=self.runner)
super(CliRunnerCallbacks, self).on_async_poll(host, res, jid, clock) super(CliRunnerCallbacks, self).on_async_poll(host, res, jid, clock)
def on_async_ok(self, host, res, jid): def on_async_ok(self, host, res, jid):
display("<job %s> finished on %s => %s"%(jid, host, utils.jsonify(res,format=True))) display("<job %s> finished on %s => %s"%(jid, host, utils.jsonify(res,format=True)), runner=self.runner)
super(CliRunnerCallbacks, self).on_async_ok(host, res, jid) super(CliRunnerCallbacks, self).on_async_ok(host, res, jid)
def on_async_failed(self, host, res, jid): def on_async_failed(self, host, res, jid):
display("<job %s> FAILED on %s => %s"%(jid, host, utils.jsonify(res,format=True)), color='red', stderr=True) display("<job %s> FAILED on %s => %s"%(jid, host, utils.jsonify(res,format=True)), color='red', stderr=True, runner=self.runner)
super(CliRunnerCallbacks, self).on_async_failed(host,res,jid) super(CliRunnerCallbacks, self).on_async_failed(host,res,jid)
def _on_any(self, host, result): def _on_any(self, host, result):
result2 = result.copy() result2 = result.copy()
result2.pop('invocation', None) result2.pop('invocation', None)
(msg, color) = host_report_msg(host, self.options.module_name, result2, self.options.one_line) (msg, color) = host_report_msg(host, self.options.module_name, result2, self.options.one_line)
display(msg, color=color) display(msg, color=color, runner=self.runner)
if self.options.tree: if self.options.tree:
utils.write_tree_file(self.options.tree, host, utils.jsonify(result2,format=True)) utils.write_tree_file(self.options.tree, host, utils.jsonify(result2,format=True))
def on_file_diff(self, host, diff): def on_file_diff(self, host, diff):
display(utils.get_diff(diff)) display(utils.get_diff(diff), runner=self.runner)
super(CliRunnerCallbacks, self).on_file_diff(host, diff) super(CliRunnerCallbacks, self).on_file_diff(host, diff)
######################################################################## ########################################################################
@ -412,11 +424,12 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
msg = "fatal: [%s] => (item=%s) => %s" % (host, item, results) msg = "fatal: [%s] => (item=%s) => %s" % (host, item, results)
else: else:
msg = "fatal: [%s] => %s" % (host, results) msg = "fatal: [%s] => %s" % (host, results)
display(msg, color='red') display(msg, color='red', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_unreachable(host, results) super(PlaybookRunnerCallbacks, self).on_unreachable(host, results)
def on_failed(self, host, results, ignore_errors=False): def on_failed(self, host, results, ignore_errors=False):
results2 = results.copy() results2 = results.copy()
results2.pop('invocation', None) results2.pop('invocation', None)
@ -433,21 +446,22 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
msg = "failed: [%s] => (item=%s) => %s" % (host, item, utils.jsonify(results2)) msg = "failed: [%s] => (item=%s) => %s" % (host, item, utils.jsonify(results2))
else: else:
msg = "failed: [%s] => %s" % (host, utils.jsonify(results2)) msg = "failed: [%s] => %s" % (host, utils.jsonify(results2))
display(msg, color='red') display(msg, color='red', runner=self.runner)
if stderr: if stderr:
display("stderr: %s" % stderr, color='red') display("stderr: %s" % stderr, color='red', runner=self.runner)
if stdout: if stdout:
display("stdout: %s" % stdout, color='red') display("stdout: %s" % stdout, color='red', runner=self.runner)
if returned_msg: if returned_msg:
display("msg: %s" % returned_msg, color='red') display("msg: %s" % returned_msg, color='red', runner=self.runner)
if not parsed and module_msg: if not parsed and module_msg:
display("invalid output was: %s" % module_msg, color='red') display("invalid output was: %s" % module_msg, color='red', runner=self.runner)
if ignore_errors: if ignore_errors:
display("...ignoring", color='cyan') display("...ignoring", color='cyan', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_failed(host, results, ignore_errors=ignore_errors) super(PlaybookRunnerCallbacks, self).on_failed(host, results, ignore_errors=ignore_errors)
def on_ok(self, host, host_result): def on_ok(self, host, host_result):
item = host_result.get('item', None) item = host_result.get('item', None)
host_result2 = host_result.copy() host_result2 = host_result.copy()
@ -477,9 +491,9 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
if msg != '': if msg != '':
if not changed: if not changed:
display(msg, color='green') display(msg, color='green', runner=self.runner)
else: else:
display(msg, color='yellow') display(msg, color='yellow', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_ok(host, host_result) super(PlaybookRunnerCallbacks, self).on_ok(host, host_result)
def on_error(self, host, err): def on_error(self, host, err):
@ -491,7 +505,7 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
else: else:
msg = "err: [%s] => %s" % (host, err) msg = "err: [%s] => %s" % (host, err)
display(msg, color='red', stderr=True) display(msg, color='red', stderr=True, runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_error(host, err) super(PlaybookRunnerCallbacks, self).on_error(host, err)
def on_skipped(self, host, item=None): def on_skipped(self, host, item=None):
@ -500,11 +514,11 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
msg = "skipping: [%s] => (item=%s)" % (host, item) msg = "skipping: [%s] => (item=%s)" % (host, item)
else: else:
msg = "skipping: [%s]" % host msg = "skipping: [%s]" % host
display(msg, color='cyan') display(msg, color='cyan', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_skipped(host, item) super(PlaybookRunnerCallbacks, self).on_skipped(host, item)
def on_no_hosts(self): def on_no_hosts(self):
display("FATAL: no hosts matched or all hosts have already failed -- aborting\n", color='red') display("FATAL: no hosts matched or all hosts have already failed -- aborting\n", color='red', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_no_hosts() super(PlaybookRunnerCallbacks, self).on_no_hosts()
def on_async_poll(self, host, res, jid, clock): def on_async_poll(self, host, res, jid, clock):
@ -513,21 +527,21 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
if self._async_notified[jid] > clock: if self._async_notified[jid] > clock:
self._async_notified[jid] = clock self._async_notified[jid] = clock
msg = "<job %s> polling, %ss remaining"%(jid, clock) msg = "<job %s> polling, %ss remaining"%(jid, clock)
display(msg, color='cyan') display(msg, color='cyan', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_async_poll(host,res,jid,clock) super(PlaybookRunnerCallbacks, self).on_async_poll(host,res,jid,clock)
def on_async_ok(self, host, res, jid): def on_async_ok(self, host, res, jid):
msg = "<job %s> finished on %s"%(jid, host) msg = "<job %s> finished on %s"%(jid, host)
display(msg, color='cyan') display(msg, color='cyan', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_async_ok(host, res, jid) super(PlaybookRunnerCallbacks, self).on_async_ok(host, res, jid)
def on_async_failed(self, host, res, jid): def on_async_failed(self, host, res, jid):
msg = "<job %s> FAILED on %s" % (jid, host) msg = "<job %s> FAILED on %s" % (jid, host)
display(msg, color='red', stderr=True) display(msg, color='red', stderr=True, runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_async_failed(host,res,jid) super(PlaybookRunnerCallbacks, self).on_async_failed(host,res,jid)
def on_file_diff(self, host, diff): def on_file_diff(self, host, diff):
display(utils.get_diff(diff)) display(utils.get_diff(diff), runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_file_diff(host, diff) super(PlaybookRunnerCallbacks, self).on_file_diff(host, diff)
######################################################################## ########################################################################

View file

@ -51,6 +51,9 @@ except ImportError:
HAS_ATFORK=False HAS_ATFORK=False
multiprocessing_runner = None multiprocessing_runner = None
OUTPUT_LOCKFILE = tempfile.TemporaryFile()
PROCESS_LOCKFILE = tempfile.TemporaryFile()
################################################ ################################################
@ -134,8 +137,9 @@ class Runner(object):
error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR # ex. False error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR # ex. False
): ):
# used to lock multiprocess inputs that wish to share stdin # used to lock multiprocess inputs and outputs at various levels
self.lockfile = tempfile.NamedTemporaryFile() self.output_lockfile = OUTPUT_LOCKFILE
self.process_lockfile = PROCESS_LOCKFILE
if not complex_args: if not complex_args:
complex_args = {} complex_args = {}
@ -884,9 +888,6 @@ class Runner(object):
else: else:
results = [ self._executor(h, None) for h in hosts ] results = [ self._executor(h, None) for h in hosts ]
self.lockfile.close()
return self._partition_results(results) return self._partition_results(results)
# ***************************************************** # *****************************************************

View file

@ -72,9 +72,9 @@ class MyAddPolicy(object):
if C.HOST_KEY_CHECKING: if C.HOST_KEY_CHECKING:
KEY_LOCK = self.runner.lockfile fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_EX)
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX) fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_EX)
old_stdin = sys.stdin old_stdin = sys.stdin
sys.stdin = self.runner._new_stdin sys.stdin = self.runner._new_stdin
fingerprint = hexlify(key.get_fingerprint()) fingerprint = hexlify(key.get_fingerprint())
@ -86,10 +86,12 @@ class MyAddPolicy(object):
inp = raw_input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint)) inp = raw_input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint))
sys.stdin = old_stdin sys.stdin = old_stdin
if inp not in ['yes','y','']: if inp not in ['yes','y','']:
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN) fcntl.flock(self.runner.output_lockfile, fcntl.LOCK_UN)
fcntl.flock(self.runner.process_lockfile, fcntl.LOCK_UN)
raise errors.AnsibleError("host connection rejected by user") raise errors.AnsibleError("host connection rejected by user")
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN) fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_UN)
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
key._added_by_ansible_this_time = True key._added_by_ansible_this_time = True
@ -257,22 +259,23 @@ class Connection(object):
except IOError: except IOError:
raise errors.AnsibleError("failed to transfer file from %s" % in_path) raise errors.AnsibleError("failed to transfer file from %s" % in_path)
def _any_keys_added(self):
added_any = False
for hostname, keys in self.ssh._host_keys.iteritems():
for keytype, key in keys.iteritems():
added_this_time = getattr(key, '_added_by_ansible_this_time', False)
if added_this_time:
return True
return False
def _save_ssh_host_keys(self, filename): def _save_ssh_host_keys(self, filename):
''' '''
not using the paramiko save_ssh_host_keys function as we want to add new SSH keys at the bottom so folks not using the paramiko save_ssh_host_keys function as we want to add new SSH keys at the bottom so folks
don't complain about it :) don't complain about it :)
''' '''
added_any = False if not self._any_keys_added():
for hostname, keys in self.ssh._host_keys.iteritems(): return False
for keytype, key in keys.iteritems():
added_this_time = getattr(key, '_added_by_ansible_this_time', False)
if added_this_time:
added_any = True
break
if not added_any:
return
path = os.path.expanduser("~/.ssh") path = os.path.expanduser("~/.ssh")
if not os.path.exists(path): if not os.path.exists(path):
@ -300,23 +303,22 @@ class Connection(object):
if self.sftp is not None: if self.sftp is not None:
self.sftp.close() self.sftp.close()
# add any new SSH host keys if self._any_keys_added():
lockfile = self.keyfile.replace("known_hosts",".known_hosts.lock") # add any new SSH host keys -- warning -- this could be slow
KEY_LOCK = open(lockfile, 'w') lockfile = self.keyfile.replace("known_hosts",".known_hosts.lock")
fcntl.flock(KEY_LOCK, fcntl.LOCK_EX) KEY_LOCK = open(lockfile, 'w')
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX)
try: try:
# just in case any were added recently # just in case any were added recently
self.ssh.load_system_host_keys() self.ssh.load_system_host_keys()
self.ssh._host_keys.update(self.ssh._system_host_keys) self.ssh._host_keys.update(self.ssh._system_host_keys)
#self.ssh.save_host_keys(self.keyfile) self._save_ssh_host_keys(self.keyfile)
self._save_ssh_host_keys(self.keyfile) except:
except: # unable to save keys, including scenario when key was invalid
# unable to save keys, including scenario when key was invalid # and caught earlier
# and caught earlier traceback.print_exc()
traceback.print_exc() pass
pass fcntl.lockf(KEY_LOCK, fcntl.LOCK_UN)
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
self.ssh.close() self.ssh.close()

View file

@ -131,8 +131,9 @@ class Connection(object):
if C.HOST_KEY_CHECKING and not_in_host_file: if C.HOST_KEY_CHECKING and not_in_host_file:
# lock around the initial SSH connectivity so the user prompt about whether to add # lock around the initial SSH connectivity so the user prompt about whether to add
# the host to known hosts is not intermingled with multiprocess output. # the host to known hosts is not intermingled with multiprocess output.
KEY_LOCK = self.runner.lockfile fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_EX)
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX) fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_EX)
try: try:
@ -191,8 +192,8 @@ class Connection(object):
if C.HOST_KEY_CHECKING and not_in_host_file: if C.HOST_KEY_CHECKING and not_in_host_file:
# lock around the initial SSH connectivity so the user prompt about whether to add # lock around the initial SSH connectivity so the user prompt about whether to add
# the host to known hosts is not intermingled with multiprocess output. # the host to known hosts is not intermingled with multiprocess output.
KEY_LOCK = self.runner.lockfile fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_UN)
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX) fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
if p.returncode != 0 and stderr.find('Bad configuration option: ControlPersist') != -1: if p.returncode != 0 and stderr.find('Bad configuration option: ControlPersist') != -1:
raise errors.AnsibleError('using -c ssh on certain older ssh versions may not support ControlPersist, set ANSIBLE_SSH_ARGS="" (or ansible_ssh_args in the config file) before running again') raise errors.AnsibleError('using -c ssh on certain older ssh versions may not support ControlPersist, set ANSIBLE_SSH_ARGS="" (or ansible_ssh_args in the config file) before running again')