Initial work to make paramiko connections work under v2

This commit is contained in:
James Cammarata 2015-04-24 02:47:56 -04:00
parent 8c08f1b302
commit 8574d40b98
7 changed files with 124 additions and 167 deletions

View file

@ -48,8 +48,9 @@ class ConnectionInformation:
self.remote_addr = None self.remote_addr = None
self.remote_user = None self.remote_user = None
self.password = passwords.get('conn_pass','') self.password = passwords.get('conn_pass','')
self.port = None self.port = 22
self.private_key_file = None self.private_key_file = C.DEFAULT_PRIVATE_KEY_FILE
self.timeout = C.DEFAULT_TIMEOUT
# privilege escalation # privilege escalation
self.become = None self.become = None
@ -119,9 +120,7 @@ class ConnectionInformation:
self.connection = options.connection self.connection = options.connection
self.remote_user = options.remote_user self.remote_user = options.remote_user
#if 'port' in options and options.port is not None: self.private_key_file = options.private_key_file
# self.port = options.port
self.private_key_file = None
# privilege escalation # privilege escalation
self.become = options.become self.become = options.become

View file

@ -51,7 +51,7 @@ class WorkerProcess(multiprocessing.Process):
for reading later. for reading later.
''' '''
def __init__(self, tqm, main_q, rslt_q, loader, new_stdin): def __init__(self, tqm, main_q, rslt_q, loader):
# takes a task queue manager as the sole param: # takes a task queue manager as the sole param:
self._main_q = main_q self._main_q = main_q
@ -59,13 +59,10 @@ class WorkerProcess(multiprocessing.Process):
self._loader = loader self._loader = loader
# dupe stdin, if we have one # dupe stdin, if we have one
self._new_stdin = sys.stdin
try: try:
fileno = sys.stdin.fileno() fileno = sys.stdin.fileno()
except ValueError: if fileno is not None:
fileno = None
self._new_stdin = new_stdin
if not new_stdin and fileno is not None:
try: try:
self._new_stdin = os.fdopen(os.dup(fileno)) self._new_stdin = os.fdopen(os.dup(fileno))
except OSError, e: except OSError, e:
@ -73,9 +70,9 @@ class WorkerProcess(multiprocessing.Process):
# not a valid file descriptor, so we just rely on # not a valid file descriptor, so we just rely on
# using the one that was passed in # using the one that was passed in
pass pass
except ValueError:
if self._new_stdin: # couldn't get stdin's fileno, so we just carry on
sys.stdin = self._new_stdin pass
super(WorkerProcess, self).__init__() super(WorkerProcess, self).__init__()
@ -118,7 +115,7 @@ class WorkerProcess(multiprocessing.Process):
# execute the task and build a TaskResult from the result # execute the task and build a TaskResult from the result
debug("running TaskExecutor() for %s/%s" % (host, task)) debug("running TaskExecutor() for %s/%s" % (host, task))
executor_result = TaskExecutor(host, task, job_vars, new_connection_info, self._loader, module_loader).run() executor_result = TaskExecutor(host, task, job_vars, new_connection_info, self._new_stdin, self._loader, module_loader).run()
debug("done running TaskExecutor() for %s/%s" % (host, task)) debug("done running TaskExecutor() for %s/%s" % (host, task))
task_result = TaskResult(host, task, executor_result) task_result = TaskResult(host, task, executor_result)

View file

@ -45,11 +45,12 @@ class TaskExecutor:
class. class.
''' '''
def __init__(self, host, task, job_vars, connection_info, loader, module_loader): def __init__(self, host, task, job_vars, connection_info, new_stdin, loader, module_loader):
self._host = host self._host = host
self._task = task self._task = task
self._job_vars = job_vars self._job_vars = job_vars
self._connection_info = connection_info self._connection_info = connection_info
self._new_stdin = new_stdin
self._loader = loader self._loader = loader
self._module_loader = module_loader self._module_loader = module_loader
@ -370,7 +371,7 @@ class TaskExecutor:
if conn_type == 'smart': if conn_type == 'smart':
conn_type = 'ssh' conn_type = 'ssh'
connection = connection_loader.get(conn_type, self._connection_info) connection = connection_loader.get(conn_type, self._connection_info, self._new_stdin)
if not connection: if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type) raise AnsibleError("the connection plugin '%s' was not found" % conn_type)

View file

@ -87,21 +87,10 @@ class TaskQueueManager:
self._workers = [] self._workers = []
for i in range(self._options.forks): for i in range(self._options.forks):
# duplicate stdin, if possible
new_stdin = None
if fileno is not None:
try:
new_stdin = os.fdopen(os.dup(fileno))
except OSError:
# couldn't dupe stdin, most likely because it's
# not a valid file descriptor, so we just rely on
# using the one that was passed in
pass
main_q = multiprocessing.Queue() main_q = multiprocessing.Queue()
rslt_q = multiprocessing.Queue() rslt_q = multiprocessing.Queue()
prc = WorkerProcess(self, main_q, rslt_q, loader, new_stdin) prc = WorkerProcess(self, main_q, rslt_q, loader)
prc.start() prc.start()
self._workers.append((prc, main_q, rslt_q)) self._workers.append((prc, main_q, rslt_q))

View file

@ -43,10 +43,12 @@ class ConnectionBase:
has_pipelining = False has_pipelining = False
become_methods = C.BECOME_METHODS become_methods = C.BECOME_METHODS
def __init__(self, connection_info, *args, **kwargs): def __init__(self, connection_info, new_stdin, *args, **kwargs):
# All these hasattrs allow subclasses to override these parameters # All these hasattrs allow subclasses to override these parameters
if not hasattr(self, '_connection_info'): if not hasattr(self, '_connection_info'):
self._connection_info = connection_info self._connection_info = connection_info
if not hasattr(self, '_new_stdin'):
self._new_stdin = new_stdin
if not hasattr(self, '_display'): if not hasattr(self, '_display'):
self._display = Display(verbosity=connection_info.verbosity) self._display = Display(verbosity=connection_info.verbosity)
if not hasattr(self, '_connected'): if not hasattr(self, '_connected'):

View file

@ -34,12 +34,13 @@ import traceback
import fcntl import fcntl
import re import re
import sys import sys
from termios import tcflush, TCIFLUSH from termios import tcflush, TCIFLUSH
from binascii import hexlify from binascii import hexlify
from ansible.callbacks import vvv
from ansible import errors
from ansible import utils
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
from ansible.plugins.connections import ConnectionBase
AUTHENTICITY_MSG=""" AUTHENTICITY_MSG="""
paramiko: The authenticity of host '%s' can't be established. paramiko: The authenticity of host '%s' can't be established.
@ -67,33 +68,38 @@ class MyAddPolicy(object):
local L{HostKeys} object, and saving it. This is used by L{SSHClient}. local L{HostKeys} object, and saving it. This is used by L{SSHClient}.
""" """
def __init__(self, runner): def __init__(self, new_stdin):
self.runner = runner self._new_stdin = new_stdin
def missing_host_key(self, client, hostname, key): def missing_host_key(self, client, hostname, key):
if C.HOST_KEY_CHECKING: if C.HOST_KEY_CHECKING:
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_EX) # FIXME: need to fix lock file stuff
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_EX) #fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_EX)
#fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_EX)
old_stdin = sys.stdin old_stdin = sys.stdin
sys.stdin = self.runner._new_stdin sys.stdin = self._new_stdin
fingerprint = hexlify(key.get_fingerprint())
ktype = key.get_name()
# clear out any premature input on sys.stdin # clear out any premature input on sys.stdin
tcflush(sys.stdin, TCIFLUSH) tcflush(sys.stdin, TCIFLUSH)
fingerprint = hexlify(key.get_fingerprint())
ktype = key.get_name()
inp = raw_input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint)) inp = raw_input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint))
sys.stdin = old_stdin sys.stdin = old_stdin
if inp not in ['yes','y','']:
fcntl.flock(self.runner.output_lockfile, fcntl.LOCK_UN)
fcntl.flock(self.runner.process_lockfile, fcntl.LOCK_UN)
raise errors.AnsibleError("host connection rejected by user")
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_UN) if inp not in ['yes','y','']:
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN) # FIXME: lock file stuff
#fcntl.flock(self.runner.output_lockfile, fcntl.LOCK_UN)
#fcntl.flock(self.runner.process_lockfile, fcntl.LOCK_UN)
raise AnsibleError("host connection rejected by user")
# FIXME: lock file stuff
#fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_UN)
#fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
key._added_by_ansible_this_time = True key._added_by_ansible_this_time = True
@ -110,28 +116,18 @@ class MyAddPolicy(object):
SSH_CONNECTION_CACHE = {} SSH_CONNECTION_CACHE = {}
SFTP_CONNECTION_CACHE = {} SFTP_CONNECTION_CACHE = {}
class Connection(object): class Connection(ConnectionBase):
''' SSH based connections with Paramiko ''' ''' SSH based connections with Paramiko '''
def __init__(self, runner, host, port, user, password, private_key_file, *args, **kwargs): @property
def transport(self):
self.ssh = None ''' used to identify this connection object from other classes '''
self.sftp = None return 'paramiko'
self.runner = runner
self.host = host
self.port = port or 22
self.user = user
self.password = password
self.private_key_file = private_key_file
self.has_pipelining = False
# TODO: add pbrun, pfexec
self.become_methods_supported=['sudo', 'su', 'pbrun']
def _cache_key(self): def _cache_key(self):
return "%s__%s__" % (self.host, self.user) return "%s__%s__" % (self._connection_info.remote_addr, self._connection_info.remote_user)
def connect(self): def _connect(self):
cache_key = self._cache_key() cache_key = self._cache_key()
if cache_key in SSH_CONNECTION_CACHE: if cache_key in SSH_CONNECTION_CACHE:
self.ssh = SSH_CONNECTION_CACHE[cache_key] self.ssh = SSH_CONNECTION_CACHE[cache_key]
@ -143,9 +139,9 @@ class Connection(object):
''' activates the connection object ''' ''' activates the connection object '''
if not HAVE_PARAMIKO: if not HAVE_PARAMIKO:
raise errors.AnsibleError("paramiko is not installed") raise AnsibleError("paramiko is not installed")
vvv("ESTABLISH CONNECTION FOR USER: %s on PORT %s TO %s" % (self.user, self.port, self.host), host=self.host) self._display.vvv("ESTABLISH CONNECTION FOR USER: %s on PORT %s TO %s" % (self._connection_info.remote_user, self._connection_info.port, self._connection_info.remote_addr), host=self._connection_info.remote_addr)
ssh = paramiko.SSHClient() ssh = paramiko.SSHClient()
@ -154,122 +150,95 @@ class Connection(object):
if C.HOST_KEY_CHECKING: if C.HOST_KEY_CHECKING:
ssh.load_system_host_keys() ssh.load_system_host_keys()
ssh.set_missing_host_key_policy(MyAddPolicy(self.runner)) ssh.set_missing_host_key_policy(MyAddPolicy(self._new_stdin))
allow_agent = True allow_agent = True
if self.password is not None: if self._connection_info.password is not None:
allow_agent = False allow_agent = False
try: try:
if self.private_key_file:
key_filename = os.path.expanduser(self.private_key_file)
elif self.runner.private_key_file:
key_filename = os.path.expanduser(self.runner.private_key_file)
else:
key_filename = None key_filename = None
ssh.connect(self.host, username=self.user, allow_agent=allow_agent, look_for_keys=True, if self._connection_info.private_key_file:
key_filename=key_filename, password=self.password, key_filename = os.path.expanduser(self._connection_info.private_key_file)
timeout=self.runner.timeout, port=self.port)
ssh.connect(
self._connection_info.remote_addr,
username=self._connection_info.remote_user,
allow_agent=allow_agent,
look_for_keys=True,
key_filename=key_filename,
password=self._connection_info.password,
timeout=self._connection_info.timeout,
port=self._connection_info.port
)
except Exception, e: except Exception, e:
msg = str(e) msg = str(e)
if "PID check failed" in msg: if "PID check failed" in msg:
raise errors.AnsibleError("paramiko version issue, please upgrade paramiko on the machine running ansible") raise AnsibleError("paramiko version issue, please upgrade paramiko on the machine running ansible")
elif "Private key file is encrypted" in msg: elif "Private key file is encrypted" in msg:
msg = 'ssh %s@%s:%s : %s\nTo connect as a different user, use -u <username>.' % ( msg = 'ssh %s@%s:%s : %s\nTo connect as a different user, use -u <username>.' % (
self.user, self.host, self.port, msg) self._connection_info.remote_user, self._connection_info.remote_addr, self._connection_info.port, msg)
raise errors.AnsibleConnectionFailed(msg) raise AnsibleConnectionFailure(msg)
else: else:
raise errors.AnsibleConnectionFailed(msg) raise AnsibleConnectionFailure(msg)
return ssh return ssh
def exec_command(self, cmd, tmp_path, become_user=None, sudoable=False, executable='/bin/sh', in_data=None): def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None):
''' run a command on the remote host ''' ''' run a command on the remote host '''
if self.runner.become and sudoable and self.runner.become_method not in self.become_methods_supported:
raise errors.AnsibleError("Internal Error: this module does not support running commands via %s" % self.runner.become_method)
if in_data: if in_data:
raise errors.AnsibleError("Internal Error: this module does not support optimized module pipelining") raise AnsibleError("Internal Error: this module does not support optimized module pipelining")
bufsize = 4096 bufsize = 4096
try: try:
self.ssh.get_transport().set_keepalive(5) self.ssh.get_transport().set_keepalive(5)
chan = self.ssh.get_transport().open_session() chan = self.ssh.get_transport().open_session()
except Exception, e: except Exception, e:
msg = "Failed to open session" msg = "Failed to open session"
if len(str(e)) > 0: if len(str(e)) > 0:
msg += ": %s" % str(e) msg += ": %s" % str(e)
raise errors.AnsibleConnectionFailed(msg) raise AnsibleConnectionFailure(msg)
no_prompt_out = ''
no_prompt_err = ''
if not (self.runner.become and sudoable):
if executable:
quoted_command = executable + ' -c ' + pipes.quote(cmd)
else:
quoted_command = cmd
vvv("EXEC %s" % quoted_command, host=self.host)
chan.exec_command(quoted_command)
else:
# sudo usually requires a PTY (cf. requiretty option), therefore # sudo usually requires a PTY (cf. requiretty option), therefore
# we give it one by default (pty=True in ansble.cfg), and we try # we give it one by default (pty=True in ansble.cfg), and we try
# to initialise from the calling environment # to initialise from the calling environment
if C.PARAMIKO_PTY: if C.PARAMIKO_PTY:
chan.get_pty(term=os.getenv('TERM', 'vt100'), chan.get_pty(term=os.getenv('TERM', 'vt100'), width=int(os.getenv('COLUMNS', 0)), height=int(os.getenv('LINES', 0)))
width=int(os.getenv('COLUMNS', 0)),
height=int(os.getenv('LINES', 0)))
if self.runner.become and sudoable:
shcmd, prompt, success_key = utils.make_become_cmd(cmd, become_user, executable, self.runner.become_method, '', self.runner.become_exe)
vvv("EXEC %s" % shcmd, host=self.host) self._display.vvv("EXEC %s" % cmd, host=self._connection_info.remote_addr)
no_prompt_out = ''
no_prompt_err = ''
become_output = '' become_output = ''
try: try:
chan.exec_command(cmd)
chan.exec_command(shcmd) if self._connection_info.become_pass:
if self.runner.become_pass:
while True: while True:
if success_key in become_output or \ if success_key in become_output or \
(prompt and become_output.endswith(prompt)) or \ (prompt and become_output.endswith(prompt)) or \
utils.su_prompts.check_su_prompt(become_output): utils.su_prompts.check_su_prompt(become_output):
break break
chunk = chan.recv(bufsize) chunk = chan.recv(bufsize)
if not chunk: if not chunk:
if 'unknown user' in become_output: if 'unknown user' in become_output:
raise errors.AnsibleError( raise AnsibleError(
'user %s does not exist' % become_user) 'user %s does not exist' % become_user)
else: else:
raise errors.AnsibleError('ssh connection ' + raise AnsibleError('ssh connection ' +
'closed waiting for password prompt') 'closed waiting for password prompt')
become_output += chunk become_output += chunk
if success_key not in become_output: if success_key not in become_output:
if self._connection_info.become:
if sudoable: chan.sendall(self._connection_info.become_pass + '\n')
chan.sendall(self.runner.become_pass + '\n')
else: else:
no_prompt_out += become_output no_prompt_out += become_output
no_prompt_err += become_output no_prompt_err += become_output
except socket.timeout: except socket.timeout:
raise AnsibleError('ssh timed out waiting for privilege escalation.\n' + become_output)
raise errors.AnsibleError('ssh timed out waiting for privilege escalation.\n' + become_output)
stdout = ''.join(chan.makefile('rb', bufsize)) stdout = ''.join(chan.makefile('rb', bufsize))
stderr = ''.join(chan.makefile_stderr('rb', bufsize)) stderr = ''.join(chan.makefile_stderr('rb', bufsize))
@ -279,24 +248,24 @@ class Connection(object):
def put_file(self, in_path, out_path): def put_file(self, in_path, out_path):
''' transfer a file from local to remote ''' ''' transfer a file from local to remote '''
vvv("PUT %s TO %s" % (in_path, out_path), host=self.host) self._display.vvv("PUT %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr)
if not os.path.exists(in_path): if not os.path.exists(in_path):
raise errors.AnsibleFileNotFound("file or module does not exist: %s" % in_path) raise AnsibleFileNotFound("file or module does not exist: %s" % in_path)
try: try:
self.sftp = self.ssh.open_sftp() self.sftp = self.ssh.open_sftp()
except Exception, e: except Exception, e:
raise errors.AnsibleError("failed to open a SFTP connection (%s)" % e) raise AnsibleError("failed to open a SFTP connection (%s)" % e)
try: try:
self.sftp.put(in_path, out_path) self.sftp.put(in_path, out_path)
except IOError: except IOError:
raise errors.AnsibleError("failed to transfer file to %s" % out_path) raise AnsibleError("failed to transfer file to %s" % out_path)
def _connect_sftp(self): def _connect_sftp(self):
cache_key = "%s__%s__" % (self.host, self.user) cache_key = "%s__%s__" % (self._connection_info.remote_addr, self._connection_info.remote_user)
if cache_key in SFTP_CONNECTION_CACHE: if cache_key in SFTP_CONNECTION_CACHE:
return SFTP_CONNECTION_CACHE[cache_key] return SFTP_CONNECTION_CACHE[cache_key]
else: else:
@ -306,17 +275,17 @@ class Connection(object):
def fetch_file(self, in_path, out_path): def fetch_file(self, in_path, out_path):
''' save a remote file to the specified path ''' ''' save a remote file to the specified path '''
vvv("FETCH %s TO %s" % (in_path, out_path), host=self.host) self._display.vvv("FETCH %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr)
try: try:
self.sftp = self._connect_sftp() self.sftp = self._connect_sftp()
except Exception, e: except Exception, e:
raise errors.AnsibleError("failed to open a SFTP connection (%s)", e) raise AnsibleError("failed to open a SFTP connection (%s)", e)
try: try:
self.sftp.get(in_path, out_path) self.sftp.get(in_path, out_path)
except IOError: except IOError:
raise errors.AnsibleError("failed to transfer file from %s" % in_path) raise AnsibleError("failed to transfer file from %s" % in_path)
def _any_keys_added(self): def _any_keys_added(self):

View file

@ -50,7 +50,7 @@ class Connection(ConnectionBase):
self._cp_dir = '/tmp' self._cp_dir = '/tmp'
#fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN) #fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
super(Connection, self).__init__(connection_info) super(Connection, self).__init__(connection_info, *args, **kwargs)
@property @property
def transport(self): def transport(self):