Improve interlaced output prevention when asking for host key approval.

This commit is contained in:
Michael DeHaan 2013-07-04 18:17:45 -04:00
parent c55adc9ac9
commit 2cb7c30834
4 changed files with 91 additions and 73 deletions

View file

@ -83,11 +83,23 @@ def log_lockfile():
LOG_LOCK = open(log_lockfile(), 'w')
def log_flock():
fcntl.flock(LOG_LOCK, fcntl.LOCK_EX)
def log_flock(runner):
fcntl.lockf(LOG_LOCK, fcntl.LOCK_EX)
if runner is not None:
try:
fcntl.lockf(runner.output_lockfile, fcntl.LOCK_EX)
except OSError, e:
# already got closed?
pass
def log_unflock():
fcntl.flock(LOG_LOCK, fcntl.LOCK_UN)
def log_unflock(runner):
fcntl.lockf(LOG_LOCK, fcntl.LOCK_UN)
if runner is not None:
try:
fcntl.lockf(runner.output_lockfile, fcntl.LOCK_UN)
except OSError, e:
# already got closed?
pass
def set_play(callback, play):
''' used to notify callback plugins of context '''
@ -101,9 +113,9 @@ def set_task(callback, task):
for callback_plugin in callback_plugins:
callback_plugin.task = task
def display(msg, color=None, stderr=False, screen_only=False, log_only=False):
def display(msg, color=None, stderr=False, screen_only=False, log_only=False, runner=None):
# prevent a very rare case of interlaced multiprocess I/O
log_flock()
log_flock(runner)
msg2 = msg
if color:
msg2 = stringc(msg, color)
@ -120,7 +132,7 @@ def display(msg, color=None, stderr=False, screen_only=False, log_only=False):
logger.error(msg)
else:
logger.info(msg)
log_unflock()
log_unflock(runner)
def call_callback_module(method_name, *args, **kwargs):
@ -346,7 +358,7 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
def on_unreachable(self, host, res):
if type(res) == dict:
res = res.get('msg','')
display("%s | FAILED => %s" % (host, res), stderr=True, color='red')
display("%s | FAILED => %s" % (host, res), stderr=True, color='red', runner=self.runner)
if self.options.tree:
utils.write_tree_file(
self.options.tree, host,
@ -355,15 +367,15 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
super(CliRunnerCallbacks, self).on_unreachable(host, res)
def on_skipped(self, host, item=None):
display("%s | skipped" % (host))
display("%s | skipped" % (host), runner=self.runner)
super(CliRunnerCallbacks, self).on_skipped(host, item)
def on_error(self, host, err):
display("err: [%s] => %s\n" % (host, err), stderr=True)
display("err: [%s] => %s\n" % (host, err), stderr=True, runner=self.runner)
super(CliRunnerCallbacks, self).on_error(host, err)
def on_no_hosts(self):
display("no hosts matched\n", stderr=True)
display("no hosts matched\n", stderr=True, runner=self.runner)
super(CliRunnerCallbacks, self).on_no_hosts()
def on_async_poll(self, host, res, jid, clock):
@ -371,27 +383,27 @@ class CliRunnerCallbacks(DefaultRunnerCallbacks):
self._async_notified[jid] = clock + 1
if self._async_notified[jid] > clock:
self._async_notified[jid] = clock
display("<job %s> polling, %ss remaining" % (jid, clock))
display("<job %s> polling, %ss remaining" % (jid, clock), runner=self.runner)
super(CliRunnerCallbacks, self).on_async_poll(host, res, jid, clock)
def on_async_ok(self, host, res, jid):
display("<job %s> finished on %s => %s"%(jid, host, utils.jsonify(res,format=True)))
display("<job %s> finished on %s => %s"%(jid, host, utils.jsonify(res,format=True)), runner=self.runner)
super(CliRunnerCallbacks, self).on_async_ok(host, res, jid)
def on_async_failed(self, host, res, jid):
display("<job %s> FAILED on %s => %s"%(jid, host, utils.jsonify(res,format=True)), color='red', stderr=True)
display("<job %s> FAILED on %s => %s"%(jid, host, utils.jsonify(res,format=True)), color='red', stderr=True, runner=self.runner)
super(CliRunnerCallbacks, self).on_async_failed(host,res,jid)
def _on_any(self, host, result):
result2 = result.copy()
result2.pop('invocation', None)
(msg, color) = host_report_msg(host, self.options.module_name, result2, self.options.one_line)
display(msg, color=color)
display(msg, color=color, runner=self.runner)
if self.options.tree:
utils.write_tree_file(self.options.tree, host, utils.jsonify(result2,format=True))
def on_file_diff(self, host, diff):
display(utils.get_diff(diff))
display(utils.get_diff(diff), runner=self.runner)
super(CliRunnerCallbacks, self).on_file_diff(host, diff)
########################################################################
@ -412,11 +424,12 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
msg = "fatal: [%s] => (item=%s) => %s" % (host, item, results)
else:
msg = "fatal: [%s] => %s" % (host, results)
display(msg, color='red')
display(msg, color='red', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_unreachable(host, results)
def on_failed(self, host, results, ignore_errors=False):
results2 = results.copy()
results2.pop('invocation', None)
@ -433,21 +446,22 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
msg = "failed: [%s] => (item=%s) => %s" % (host, item, utils.jsonify(results2))
else:
msg = "failed: [%s] => %s" % (host, utils.jsonify(results2))
display(msg, color='red')
display(msg, color='red', runner=self.runner)
if stderr:
display("stderr: %s" % stderr, color='red')
display("stderr: %s" % stderr, color='red', runner=self.runner)
if stdout:
display("stdout: %s" % stdout, color='red')
display("stdout: %s" % stdout, color='red', runner=self.runner)
if returned_msg:
display("msg: %s" % returned_msg, color='red')
display("msg: %s" % returned_msg, color='red', runner=self.runner)
if not parsed and module_msg:
display("invalid output was: %s" % module_msg, color='red')
display("invalid output was: %s" % module_msg, color='red', runner=self.runner)
if ignore_errors:
display("...ignoring", color='cyan')
display("...ignoring", color='cyan', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_failed(host, results, ignore_errors=ignore_errors)
def on_ok(self, host, host_result):
item = host_result.get('item', None)
host_result2 = host_result.copy()
@ -477,9 +491,9 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
if msg != '':
if not changed:
display(msg, color='green')
display(msg, color='green', runner=self.runner)
else:
display(msg, color='yellow')
display(msg, color='yellow', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_ok(host, host_result)
def on_error(self, host, err):
@ -491,7 +505,7 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
else:
msg = "err: [%s] => %s" % (host, err)
display(msg, color='red', stderr=True)
display(msg, color='red', stderr=True, runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_error(host, err)
def on_skipped(self, host, item=None):
@ -500,11 +514,11 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
msg = "skipping: [%s] => (item=%s)" % (host, item)
else:
msg = "skipping: [%s]" % host
display(msg, color='cyan')
display(msg, color='cyan', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_skipped(host, item)
def on_no_hosts(self):
display("FATAL: no hosts matched or all hosts have already failed -- aborting\n", color='red')
display("FATAL: no hosts matched or all hosts have already failed -- aborting\n", color='red', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_no_hosts()
def on_async_poll(self, host, res, jid, clock):
@ -513,21 +527,21 @@ class PlaybookRunnerCallbacks(DefaultRunnerCallbacks):
if self._async_notified[jid] > clock:
self._async_notified[jid] = clock
msg = "<job %s> polling, %ss remaining"%(jid, clock)
display(msg, color='cyan')
display(msg, color='cyan', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_async_poll(host,res,jid,clock)
def on_async_ok(self, host, res, jid):
msg = "<job %s> finished on %s"%(jid, host)
display(msg, color='cyan')
display(msg, color='cyan', runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_async_ok(host, res, jid)
def on_async_failed(self, host, res, jid):
msg = "<job %s> FAILED on %s" % (jid, host)
display(msg, color='red', stderr=True)
display(msg, color='red', stderr=True, runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_async_failed(host,res,jid)
def on_file_diff(self, host, diff):
display(utils.get_diff(diff))
display(utils.get_diff(diff), runner=self.runner)
super(PlaybookRunnerCallbacks, self).on_file_diff(host, diff)
########################################################################

View file

@ -51,6 +51,9 @@ except ImportError:
HAS_ATFORK=False
multiprocessing_runner = None
OUTPUT_LOCKFILE = tempfile.TemporaryFile()
PROCESS_LOCKFILE = tempfile.TemporaryFile()
################################################
@ -134,8 +137,9 @@ class Runner(object):
error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR # ex. False
):
# used to lock multiprocess inputs that wish to share stdin
self.lockfile = tempfile.NamedTemporaryFile()
# used to lock multiprocess inputs and outputs at various levels
self.output_lockfile = OUTPUT_LOCKFILE
self.process_lockfile = PROCESS_LOCKFILE
if not complex_args:
complex_args = {}
@ -884,9 +888,6 @@ class Runner(object):
else:
results = [ self._executor(h, None) for h in hosts ]
self.lockfile.close()
return self._partition_results(results)
# *****************************************************

View file

@ -72,9 +72,9 @@ class MyAddPolicy(object):
if C.HOST_KEY_CHECKING:
KEY_LOCK = self.runner.lockfile
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX)
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_EX)
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_EX)
old_stdin = sys.stdin
sys.stdin = self.runner._new_stdin
fingerprint = hexlify(key.get_fingerprint())
@ -86,10 +86,12 @@ class MyAddPolicy(object):
inp = raw_input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint))
sys.stdin = old_stdin
if inp not in ['yes','y','']:
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
fcntl.flock(self.runner.output_lockfile, fcntl.LOCK_UN)
fcntl.flock(self.runner.process_lockfile, fcntl.LOCK_UN)
raise errors.AnsibleError("host connection rejected by user")
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_UN)
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
key._added_by_ansible_this_time = True
@ -257,22 +259,23 @@ class Connection(object):
except IOError:
raise errors.AnsibleError("failed to transfer file from %s" % in_path)
def _any_keys_added(self):
added_any = False
for hostname, keys in self.ssh._host_keys.iteritems():
for keytype, key in keys.iteritems():
added_this_time = getattr(key, '_added_by_ansible_this_time', False)
if added_this_time:
return True
return False
def _save_ssh_host_keys(self, filename):
'''
not using the paramiko save_ssh_host_keys function as we want to add new SSH keys at the bottom so folks
don't complain about it :)
'''
added_any = False
for hostname, keys in self.ssh._host_keys.iteritems():
for keytype, key in keys.iteritems():
added_this_time = getattr(key, '_added_by_ansible_this_time', False)
if added_this_time:
added_any = True
break
if not added_any:
return
if not self._any_keys_added():
return False
path = os.path.expanduser("~/.ssh")
if not os.path.exists(path):
@ -300,23 +303,22 @@ class Connection(object):
if self.sftp is not None:
self.sftp.close()
# add any new SSH host keys
lockfile = self.keyfile.replace("known_hosts",".known_hosts.lock")
KEY_LOCK = open(lockfile, 'w')
fcntl.flock(KEY_LOCK, fcntl.LOCK_EX)
try:
# just in case any were added recently
self.ssh.load_system_host_keys()
self.ssh._host_keys.update(self.ssh._system_host_keys)
#self.ssh.save_host_keys(self.keyfile)
self._save_ssh_host_keys(self.keyfile)
except:
# unable to save keys, including scenario when key was invalid
# and caught earlier
traceback.print_exc()
pass
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
if self._any_keys_added():
# add any new SSH host keys -- warning -- this could be slow
lockfile = self.keyfile.replace("known_hosts",".known_hosts.lock")
KEY_LOCK = open(lockfile, 'w')
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX)
try:
# just in case any were added recently
self.ssh.load_system_host_keys()
self.ssh._host_keys.update(self.ssh._system_host_keys)
self._save_ssh_host_keys(self.keyfile)
except:
# unable to save keys, including scenario when key was invalid
# and caught earlier
traceback.print_exc()
pass
fcntl.lockf(KEY_LOCK, fcntl.LOCK_UN)
self.ssh.close()

View file

@ -131,8 +131,9 @@ class Connection(object):
if C.HOST_KEY_CHECKING and not_in_host_file:
# lock around the initial SSH connectivity so the user prompt about whether to add
# the host to known hosts is not intermingled with multiprocess output.
KEY_LOCK = self.runner.lockfile
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX)
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_EX)
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_EX)
try:
@ -191,8 +192,8 @@ class Connection(object):
if C.HOST_KEY_CHECKING and not_in_host_file:
# lock around the initial SSH connectivity so the user prompt about whether to add
# the host to known hosts is not intermingled with multiprocess output.
KEY_LOCK = self.runner.lockfile
fcntl.lockf(KEY_LOCK, fcntl.LOCK_EX)
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_UN)
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
if p.returncode != 0 and stderr.find('Bad configuration option: ControlPersist') != -1:
raise errors.AnsibleError('using -c ssh on certain older ssh versions may not support ControlPersist, set ANSIBLE_SSH_ARGS="" (or ansible_ssh_args in the config file) before running again')