Makes host key checking the default behavior but can be disabled in ansible.cfg or by environment variable.

This commit is contained in:
Michael DeHaan 2013-07-03 16:47:20 -04:00
parent 4407ca8031
commit 9db4f7a9a6
5 changed files with 117 additions and 11 deletions

View file

@ -22,6 +22,9 @@ sudo_user = root
transport = paramiko
remote_port = 22
# uncomment this to disable SSH key host checking
#host_checking = False
# change this for alternative sudo implementations
sudo_exe = sudo

View file

@ -20,7 +20,25 @@ import pwd
import sys
import ConfigParser
def get_config(p, section, key, env_var, default):
# copied from utils, avoid circular reference fun :)
def mk_boolean(value):
val = str(value)
if val.lower() in [ "true", "t", "y", "1", "yes" ]:
return True
else:
return False
def get_config(p, section, key, env_var, default, boolean=False, integer=False):
''' return a configuration variable with casting '''
value = _get_config(p, section, key, env_var, default)
if boolean:
return mk_boolean(value)
if integer:
return int(value)
return value
def _get_config(p, section, key, env_var, default, boolean=True):
''' helper function for get_config '''
if env_var is not None:
value = os.environ.get(env_var, None)
if value is not None:
@ -81,20 +99,20 @@ DEFAULT_MODULE_LANG = get_config(p, DEFAULTS, 'module_lang', 'ANSIBLE
DEFAULT_TIMEOUT = get_config(p, DEFAULTS, 'timeout', 'ANSIBLE_TIMEOUT', 10)
DEFAULT_POLL_INTERVAL = get_config(p, DEFAULTS, 'poll_interval', 'ANSIBLE_POLL_INTERVAL', 15)
DEFAULT_REMOTE_USER = get_config(p, DEFAULTS, 'remote_user', 'ANSIBLE_REMOTE_USER', active_user)
DEFAULT_ASK_PASS = get_config(p, DEFAULTS, 'ask_pass', 'ANSIBLE_ASK_PASS', False)
DEFAULT_ASK_PASS = get_config(p, DEFAULTS, 'ask_pass', 'ANSIBLE_ASK_PASS', False, boolean=True)
DEFAULT_PRIVATE_KEY_FILE = shell_expand_path(get_config(p, DEFAULTS, 'private_key_file', 'ANSIBLE_PRIVATE_KEY_FILE', None))
DEFAULT_SUDO_USER = get_config(p, DEFAULTS, 'sudo_user', 'ANSIBLE_SUDO_USER', 'root')
DEFAULT_ASK_SUDO_PASS = get_config(p, DEFAULTS, 'ask_sudo_pass', 'ANSIBLE_ASK_SUDO_PASS', False)
DEFAULT_ASK_SUDO_PASS = get_config(p, DEFAULTS, 'ask_sudo_pass', 'ANSIBLE_ASK_SUDO_PASS', False, boolean=True)
DEFAULT_REMOTE_PORT = int(get_config(p, DEFAULTS, 'remote_port', 'ANSIBLE_REMOTE_PORT', 22))
DEFAULT_TRANSPORT = get_config(p, DEFAULTS, 'transport', 'ANSIBLE_TRANSPORT', 'paramiko')
DEFAULT_SCP_IF_SSH = get_config(p, 'ssh_connection', 'scp_if_ssh', 'ANSIBLE_SCP_IF_SSH', False)
DEFAULT_SCP_IF_SSH = get_config(p, 'ssh_connection', 'scp_if_ssh', 'ANSIBLE_SCP_IF_SSH', False, boolean=True)
DEFAULT_MANAGED_STR = get_config(p, DEFAULTS, 'ansible_managed', None, 'Ansible managed: {file} modified on %Y-%m-%d %H:%M:%S by {uid} on {host}')
DEFAULT_SYSLOG_FACILITY = get_config(p, DEFAULTS, 'syslog_facility', 'ANSIBLE_SYSLOG_FACILITY', 'LOG_USER')
DEFAULT_KEEP_REMOTE_FILES = get_config(p, DEFAULTS, 'keep_remote_files', 'ANSIBLE_KEEP_REMOTE_FILES', '0')
DEFAULT_SUDO_EXE = get_config(p, DEFAULTS, 'sudo_exe', 'ANSIBLE_SUDO_EXE', 'sudo')
DEFAULT_SUDO_FLAGS = get_config(p, DEFAULTS, 'sudo_flags', 'ANSIBLE_SUDO_FLAGS', '-H')
DEFAULT_HASH_BEHAVIOUR = get_config(p, DEFAULTS, 'hash_behaviour', 'ANSIBLE_HASH_BEHAVIOUR', 'replace')
DEFAULT_LEGACY_PLAYBOOK_VARIABLES = get_config(p, DEFAULTS, 'legacy_playbook_variables', 'ANSIBLE_LEGACY_PLAYBOOK_VARIABLES', 'yes')
DEFAULT_LEGACY_PLAYBOOK_VARIABLES = get_config(p, DEFAULTS, 'legacy_playbook_variables', 'ANSIBLE_LEGACY_PLAYBOOK_VARIABLES', True, boolean=True)
DEFAULT_JINJA2_EXTENSIONS = get_config(p, DEFAULTS, 'jinja2_extensions', 'ANSIBLE_JINJA2_EXTENSIONS', None)
DEFAULT_EXECUTABLE = get_config(p, DEFAULTS, 'executable', 'ANSIBLE_EXECUTABLE', '/bin/sh')
@ -110,7 +128,8 @@ ANSIBLE_NOCOWS = get_config(p, DEFAULTS, 'nocows', 'ANSIBLE_NOCO
ANSIBLE_SSH_ARGS = get_config(p, 'ssh_connection', 'ssh_args', 'ANSIBLE_SSH_ARGS', None)
ZEROMQ_PORT = int(get_config(p, 'fireball', 'zeromq_port', 'ANSIBLE_ZEROMQ_PORT', 5099))
DEFAULT_UNDEFINED_VAR_BEHAVIOR = get_config(p, DEFAULTS, 'error_on_undefined_vars', 'ANSIBLE_ERROR_ON_UNDEFINED_VARS', False)
DEFAULT_UNDEFINED_VAR_BEHAVIOR = get_config(p, DEFAULTS, 'error_on_undefined_vars', 'ANSIBLE_ERROR_ON_UNDEFINED_VARS', False, boolean=True)
HOST_KEY_CHECKING = get_config(p, DEFAULTS, 'host_key_checking', 'ANSIBLE_HOST_KEY_CHECKING', True, boolean=True)
# non-configurable things
DEFAULT_SUDO_PASS = None

View file

@ -21,9 +21,14 @@ import pipes
import socket
import random
import logging
import traceback
import fcntl
import sys
from binascii import hexlify
from ansible.callbacks import vvv
from ansible import errors
from ansible import utils
from ansible import constants as C
# prevent paramiko warning noise -- see http://stackoverflow.com/questions/3920502/
HAVE_PARAMIKO=False
@ -36,12 +41,29 @@ with warnings.catch_warnings():
except ImportError:
pass
class MyAutoAddPolicy(object):
"""
Modified version of AutoAddPolicy in paramiko so we can determine when keys are added.
Policy for automatically adding the hostname and new host key to the
local L{HostKeys} object, and saving it. This is used by L{SSHClient}.
"""
def missing_host_key(self, client, hostname, key):
key._added_by_ansible_this_time = True
# existing implementation below:
client._host_keys.add(hostname, key.get_name(), key)
if client._host_keys_filename is not None:
client.save_host_keys(client._host_keys_filename)
# keep connection objects on a per host basis to avoid repeated attempts to reconnect
SSH_CONNECTION_CACHE = {}
SFTP_CONNECTION_CACHE = {}
class Connection(object):
''' SSH based connections with Paramiko '''
@ -76,7 +98,12 @@ class Connection(object):
vvv("ESTABLISH CONNECTION FOR USER: %s on PORT %s TO %s" % (self.user, self.port, self.host), host=self.host)
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.keyfile = os.path.expanduser("~/.ssh/known_hosts")
if C.HOST_KEY_CHECKING:
ssh.load_system_host_keys()
ssh.set_missing_host_key_policy(MyAutoAddPolicy())
allow_agent = True
if self.password is not None:
@ -188,6 +215,41 @@ class Connection(object):
except IOError:
raise errors.AnsibleError("failed to transfer file from %s" % in_path)
def _save_ssh_host_keys(self, filename):
'''
not using the paramiko save_ssh_host_keys function as we want to add new SSH keys at the bottom so folks
don't complain about it :)
'''
added_any = False
for hostname, keys in self.ssh._host_keys.iteritems():
for keytype, key in keys.iteritems():
added_this_time = getattr(key, '_added_by_ansible_this_time', False)
if added_this_time:
added_any = True
break
if not added_any:
return
path = os.path.expanduser("~/.ssh")
if not os.path.exists(path):
os.makedirs(path)
f = open(filename, 'w')
for hostname, keys in self.ssh._host_keys.iteritems():
for keytype, key in keys.iteritems():
# was f.write
added_this_time = getattr(key, '_added_by_ansible_this_time', False)
if not added_this_time:
f.write("%s %s %s\n" % (hostname, keytype, key.get_base64()))
for hostname, keys in self.ssh._host_keys.iteritems():
for keytype, key in keys.iteritems():
added_this_time = getattr(key, '_added_by_ansible_this_time', False)
if added_this_time:
f.write("%s %s %s\n" % (hostname, keytype, key.get_base64()))
f.close()
def close(self):
''' terminate the connection '''
cache_key = self._cache_key()
@ -195,5 +257,24 @@ class Connection(object):
SFTP_CONNECTION_CACHE.pop(cache_key, None)
if self.sftp is not None:
self.sftp.close()
# add any new SSH host keys
lockfile = self.keyfile.replace("known_hosts",".known_hosts.lock")
KEY_LOCK = open(lockfile, 'w')
fcntl.flock(KEY_LOCK, fcntl.LOCK_EX)
try:
# just in case any were added recently
self.ssh.load_system_host_keys()
self.ssh._host_keys.update(self.ssh._system_host_keys)
#self.ssh.save_host_keys(self.keyfile)
self._save_ssh_host_keys(self.keyfile)
except:
# unable to save keys, including scenario when key was invalid
# and caught earlier
traceback.print_exc()
pass
fcntl.flock(KEY_LOCK, fcntl.LOCK_UN)
self.ssh.close()

View file

@ -53,7 +53,10 @@ class Connection(object):
self.common_args += ["-o", "ControlMaster=auto",
"-o", "ControlPersist=60s",
"-o", "ControlPath=/tmp/ansible-ssh-%h-%p-%r"]
if not C.HOST_KEY_CHECKING:
self.common_args += ["-o", "StrictHostKeyChecking=no"]
if self.port is not None:
self.common_args += ["-o", "Port=%d" % (self.port)]
if self.private_key_file is not None:

View file

@ -20,7 +20,7 @@ import os.path
import sys
import glob
import imp
import ansible.constants as C
from ansible import constants as C
from ansible import errors
MODULE_CACHE = {}