Initial work to make paramiko connections work under v2
This commit is contained in:
parent
8c08f1b302
commit
8574d40b98
7 changed files with 124 additions and 167 deletions
|
@ -44,12 +44,13 @@ class ConnectionInformation:
|
|||
passwords = {}
|
||||
|
||||
# connection
|
||||
self.connection = None
|
||||
self.remote_addr = None
|
||||
self.remote_user = None
|
||||
self.password = passwords.get('conn_pass','')
|
||||
self.port = None
|
||||
self.private_key_file = None
|
||||
self.connection = None
|
||||
self.remote_addr = None
|
||||
self.remote_user = None
|
||||
self.password = passwords.get('conn_pass','')
|
||||
self.port = 22
|
||||
self.private_key_file = C.DEFAULT_PRIVATE_KEY_FILE
|
||||
self.timeout = C.DEFAULT_TIMEOUT
|
||||
|
||||
# privilege escalation
|
||||
self.become = None
|
||||
|
@ -119,9 +120,7 @@ class ConnectionInformation:
|
|||
self.connection = options.connection
|
||||
|
||||
self.remote_user = options.remote_user
|
||||
#if 'port' in options and options.port is not None:
|
||||
# self.port = options.port
|
||||
self.private_key_file = None
|
||||
self.private_key_file = options.private_key_file
|
||||
|
||||
# privilege escalation
|
||||
self.become = options.become
|
||||
|
|
|
@ -51,7 +51,7 @@ class WorkerProcess(multiprocessing.Process):
|
|||
for reading later.
|
||||
'''
|
||||
|
||||
def __init__(self, tqm, main_q, rslt_q, loader, new_stdin):
|
||||
def __init__(self, tqm, main_q, rslt_q, loader):
|
||||
|
||||
# takes a task queue manager as the sole param:
|
||||
self._main_q = main_q
|
||||
|
@ -59,23 +59,20 @@ class WorkerProcess(multiprocessing.Process):
|
|||
self._loader = loader
|
||||
|
||||
# dupe stdin, if we have one
|
||||
self._new_stdin = sys.stdin
|
||||
try:
|
||||
fileno = sys.stdin.fileno()
|
||||
if fileno is not None:
|
||||
try:
|
||||
self._new_stdin = os.fdopen(os.dup(fileno))
|
||||
except OSError, e:
|
||||
# couldn't dupe stdin, most likely because it's
|
||||
# not a valid file descriptor, so we just rely on
|
||||
# using the one that was passed in
|
||||
pass
|
||||
except ValueError:
|
||||
fileno = None
|
||||
|
||||
self._new_stdin = new_stdin
|
||||
if not new_stdin and fileno is not None:
|
||||
try:
|
||||
self._new_stdin = os.fdopen(os.dup(fileno))
|
||||
except OSError, e:
|
||||
# couldn't dupe stdin, most likely because it's
|
||||
# not a valid file descriptor, so we just rely on
|
||||
# using the one that was passed in
|
||||
pass
|
||||
|
||||
if self._new_stdin:
|
||||
sys.stdin = self._new_stdin
|
||||
# couldn't get stdin's fileno, so we just carry on
|
||||
pass
|
||||
|
||||
super(WorkerProcess, self).__init__()
|
||||
|
||||
|
@ -118,7 +115,7 @@ class WorkerProcess(multiprocessing.Process):
|
|||
|
||||
# execute the task and build a TaskResult from the result
|
||||
debug("running TaskExecutor() for %s/%s" % (host, task))
|
||||
executor_result = TaskExecutor(host, task, job_vars, new_connection_info, self._loader, module_loader).run()
|
||||
executor_result = TaskExecutor(host, task, job_vars, new_connection_info, self._new_stdin, self._loader, module_loader).run()
|
||||
debug("done running TaskExecutor() for %s/%s" % (host, task))
|
||||
task_result = TaskResult(host, task, executor_result)
|
||||
|
||||
|
|
|
@ -45,11 +45,12 @@ class TaskExecutor:
|
|||
class.
|
||||
'''
|
||||
|
||||
def __init__(self, host, task, job_vars, connection_info, loader, module_loader):
|
||||
def __init__(self, host, task, job_vars, connection_info, new_stdin, loader, module_loader):
|
||||
self._host = host
|
||||
self._task = task
|
||||
self._job_vars = job_vars
|
||||
self._connection_info = connection_info
|
||||
self._new_stdin = new_stdin
|
||||
self._loader = loader
|
||||
self._module_loader = module_loader
|
||||
|
||||
|
@ -370,7 +371,7 @@ class TaskExecutor:
|
|||
if conn_type == 'smart':
|
||||
conn_type = 'ssh'
|
||||
|
||||
connection = connection_loader.get(conn_type, self._connection_info)
|
||||
connection = connection_loader.get(conn_type, self._connection_info, self._new_stdin)
|
||||
if not connection:
|
||||
raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
|
||||
|
||||
|
|
|
@ -87,21 +87,10 @@ class TaskQueueManager:
|
|||
|
||||
self._workers = []
|
||||
for i in range(self._options.forks):
|
||||
# duplicate stdin, if possible
|
||||
new_stdin = None
|
||||
if fileno is not None:
|
||||
try:
|
||||
new_stdin = os.fdopen(os.dup(fileno))
|
||||
except OSError:
|
||||
# couldn't dupe stdin, most likely because it's
|
||||
# not a valid file descriptor, so we just rely on
|
||||
# using the one that was passed in
|
||||
pass
|
||||
|
||||
main_q = multiprocessing.Queue()
|
||||
rslt_q = multiprocessing.Queue()
|
||||
|
||||
prc = WorkerProcess(self, main_q, rslt_q, loader, new_stdin)
|
||||
prc = WorkerProcess(self, main_q, rslt_q, loader)
|
||||
prc.start()
|
||||
|
||||
self._workers.append((prc, main_q, rslt_q))
|
||||
|
|
|
@ -43,10 +43,12 @@ class ConnectionBase:
|
|||
has_pipelining = False
|
||||
become_methods = C.BECOME_METHODS
|
||||
|
||||
def __init__(self, connection_info, *args, **kwargs):
|
||||
def __init__(self, connection_info, new_stdin, *args, **kwargs):
|
||||
# All these hasattrs allow subclasses to override these parameters
|
||||
if not hasattr(self, '_connection_info'):
|
||||
self._connection_info = connection_info
|
||||
if not hasattr(self, '_new_stdin'):
|
||||
self._new_stdin = new_stdin
|
||||
if not hasattr(self, '_display'):
|
||||
self._display = Display(verbosity=connection_info.verbosity)
|
||||
if not hasattr(self, '_connected'):
|
||||
|
|
|
@ -34,12 +34,13 @@ import traceback
|
|||
import fcntl
|
||||
import re
|
||||
import sys
|
||||
|
||||
from termios import tcflush, TCIFLUSH
|
||||
from binascii import hexlify
|
||||
from ansible.callbacks import vvv
|
||||
from ansible import errors
|
||||
from ansible import utils
|
||||
|
||||
from ansible import constants as C
|
||||
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
|
||||
from ansible.plugins.connections import ConnectionBase
|
||||
|
||||
AUTHENTICITY_MSG="""
|
||||
paramiko: The authenticity of host '%s' can't be established.
|
||||
|
@ -67,33 +68,38 @@ class MyAddPolicy(object):
|
|||
local L{HostKeys} object, and saving it. This is used by L{SSHClient}.
|
||||
"""
|
||||
|
||||
def __init__(self, runner):
|
||||
self.runner = runner
|
||||
def __init__(self, new_stdin):
|
||||
self._new_stdin = new_stdin
|
||||
|
||||
def missing_host_key(self, client, hostname, key):
|
||||
|
||||
if C.HOST_KEY_CHECKING:
|
||||
|
||||
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_EX)
|
||||
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_EX)
|
||||
# FIXME: need to fix lock file stuff
|
||||
#fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_EX)
|
||||
#fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_EX)
|
||||
|
||||
old_stdin = sys.stdin
|
||||
sys.stdin = self.runner._new_stdin
|
||||
fingerprint = hexlify(key.get_fingerprint())
|
||||
ktype = key.get_name()
|
||||
sys.stdin = self._new_stdin
|
||||
|
||||
# clear out any premature input on sys.stdin
|
||||
tcflush(sys.stdin, TCIFLUSH)
|
||||
|
||||
fingerprint = hexlify(key.get_fingerprint())
|
||||
ktype = key.get_name()
|
||||
|
||||
inp = raw_input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint))
|
||||
sys.stdin = old_stdin
|
||||
if inp not in ['yes','y','']:
|
||||
fcntl.flock(self.runner.output_lockfile, fcntl.LOCK_UN)
|
||||
fcntl.flock(self.runner.process_lockfile, fcntl.LOCK_UN)
|
||||
raise errors.AnsibleError("host connection rejected by user")
|
||||
|
||||
fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_UN)
|
||||
fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
|
||||
if inp not in ['yes','y','']:
|
||||
# FIXME: lock file stuff
|
||||
#fcntl.flock(self.runner.output_lockfile, fcntl.LOCK_UN)
|
||||
#fcntl.flock(self.runner.process_lockfile, fcntl.LOCK_UN)
|
||||
raise AnsibleError("host connection rejected by user")
|
||||
|
||||
# FIXME: lock file stuff
|
||||
#fcntl.lockf(self.runner.output_lockfile, fcntl.LOCK_UN)
|
||||
#fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
|
||||
|
||||
|
||||
key._added_by_ansible_this_time = True
|
||||
|
@ -110,28 +116,18 @@ class MyAddPolicy(object):
|
|||
SSH_CONNECTION_CACHE = {}
|
||||
SFTP_CONNECTION_CACHE = {}
|
||||
|
||||
class Connection(object):
|
||||
class Connection(ConnectionBase):
|
||||
''' SSH based connections with Paramiko '''
|
||||
|
||||
def __init__(self, runner, host, port, user, password, private_key_file, *args, **kwargs):
|
||||
|
||||
self.ssh = None
|
||||
self.sftp = None
|
||||
self.runner = runner
|
||||
self.host = host
|
||||
self.port = port or 22
|
||||
self.user = user
|
||||
self.password = password
|
||||
self.private_key_file = private_key_file
|
||||
self.has_pipelining = False
|
||||
|
||||
# TODO: add pbrun, pfexec
|
||||
self.become_methods_supported=['sudo', 'su', 'pbrun']
|
||||
@property
|
||||
def transport(self):
|
||||
''' used to identify this connection object from other classes '''
|
||||
return 'paramiko'
|
||||
|
||||
def _cache_key(self):
|
||||
return "%s__%s__" % (self.host, self.user)
|
||||
return "%s__%s__" % (self._connection_info.remote_addr, self._connection_info.remote_user)
|
||||
|
||||
def connect(self):
|
||||
def _connect(self):
|
||||
cache_key = self._cache_key()
|
||||
if cache_key in SSH_CONNECTION_CACHE:
|
||||
self.ssh = SSH_CONNECTION_CACHE[cache_key]
|
||||
|
@ -143,9 +139,9 @@ class Connection(object):
|
|||
''' activates the connection object '''
|
||||
|
||||
if not HAVE_PARAMIKO:
|
||||
raise errors.AnsibleError("paramiko is not installed")
|
||||
raise AnsibleError("paramiko is not installed")
|
||||
|
||||
vvv("ESTABLISH CONNECTION FOR USER: %s on PORT %s TO %s" % (self.user, self.port, self.host), host=self.host)
|
||||
self._display.vvv("ESTABLISH CONNECTION FOR USER: %s on PORT %s TO %s" % (self._connection_info.remote_user, self._connection_info.port, self._connection_info.remote_addr), host=self._connection_info.remote_addr)
|
||||
|
||||
ssh = paramiko.SSHClient()
|
||||
|
||||
|
@ -154,122 +150,95 @@ class Connection(object):
|
|||
if C.HOST_KEY_CHECKING:
|
||||
ssh.load_system_host_keys()
|
||||
|
||||
ssh.set_missing_host_key_policy(MyAddPolicy(self.runner))
|
||||
ssh.set_missing_host_key_policy(MyAddPolicy(self._new_stdin))
|
||||
|
||||
allow_agent = True
|
||||
|
||||
if self.password is not None:
|
||||
if self._connection_info.password is not None:
|
||||
allow_agent = False
|
||||
|
||||
try:
|
||||
key_filename = None
|
||||
if self._connection_info.private_key_file:
|
||||
key_filename = os.path.expanduser(self._connection_info.private_key_file)
|
||||
|
||||
if self.private_key_file:
|
||||
key_filename = os.path.expanduser(self.private_key_file)
|
||||
elif self.runner.private_key_file:
|
||||
key_filename = os.path.expanduser(self.runner.private_key_file)
|
||||
else:
|
||||
key_filename = None
|
||||
ssh.connect(self.host, username=self.user, allow_agent=allow_agent, look_for_keys=True,
|
||||
key_filename=key_filename, password=self.password,
|
||||
timeout=self.runner.timeout, port=self.port)
|
||||
|
||||
ssh.connect(
|
||||
self._connection_info.remote_addr,
|
||||
username=self._connection_info.remote_user,
|
||||
allow_agent=allow_agent,
|
||||
look_for_keys=True,
|
||||
key_filename=key_filename,
|
||||
password=self._connection_info.password,
|
||||
timeout=self._connection_info.timeout,
|
||||
port=self._connection_info.port
|
||||
)
|
||||
except Exception, e:
|
||||
|
||||
msg = str(e)
|
||||
if "PID check failed" in msg:
|
||||
raise errors.AnsibleError("paramiko version issue, please upgrade paramiko on the machine running ansible")
|
||||
raise AnsibleError("paramiko version issue, please upgrade paramiko on the machine running ansible")
|
||||
elif "Private key file is encrypted" in msg:
|
||||
msg = 'ssh %s@%s:%s : %s\nTo connect as a different user, use -u <username>.' % (
|
||||
self.user, self.host, self.port, msg)
|
||||
raise errors.AnsibleConnectionFailed(msg)
|
||||
self._connection_info.remote_user, self._connection_info.remote_addr, self._connection_info.port, msg)
|
||||
raise AnsibleConnectionFailure(msg)
|
||||
else:
|
||||
raise errors.AnsibleConnectionFailed(msg)
|
||||
raise AnsibleConnectionFailure(msg)
|
||||
|
||||
return ssh
|
||||
|
||||
def exec_command(self, cmd, tmp_path, become_user=None, sudoable=False, executable='/bin/sh', in_data=None):
|
||||
def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None):
|
||||
''' run a command on the remote host '''
|
||||
|
||||
if self.runner.become and sudoable and self.runner.become_method not in self.become_methods_supported:
|
||||
raise errors.AnsibleError("Internal Error: this module does not support running commands via %s" % self.runner.become_method)
|
||||
|
||||
if in_data:
|
||||
raise errors.AnsibleError("Internal Error: this module does not support optimized module pipelining")
|
||||
raise AnsibleError("Internal Error: this module does not support optimized module pipelining")
|
||||
|
||||
bufsize = 4096
|
||||
|
||||
try:
|
||||
|
||||
self.ssh.get_transport().set_keepalive(5)
|
||||
chan = self.ssh.get_transport().open_session()
|
||||
|
||||
except Exception, e:
|
||||
|
||||
msg = "Failed to open session"
|
||||
if len(str(e)) > 0:
|
||||
msg += ": %s" % str(e)
|
||||
raise errors.AnsibleConnectionFailed(msg)
|
||||
raise AnsibleConnectionFailure(msg)
|
||||
|
||||
# sudo usually requires a PTY (cf. requiretty option), therefore
|
||||
# we give it one by default (pty=True in ansble.cfg), and we try
|
||||
# to initialise from the calling environment
|
||||
if C.PARAMIKO_PTY:
|
||||
chan.get_pty(term=os.getenv('TERM', 'vt100'), width=int(os.getenv('COLUMNS', 0)), height=int(os.getenv('LINES', 0)))
|
||||
|
||||
self._display.vvv("EXEC %s" % cmd, host=self._connection_info.remote_addr)
|
||||
|
||||
no_prompt_out = ''
|
||||
no_prompt_err = ''
|
||||
if not (self.runner.become and sudoable):
|
||||
become_output = ''
|
||||
|
||||
if executable:
|
||||
quoted_command = executable + ' -c ' + pipes.quote(cmd)
|
||||
else:
|
||||
quoted_command = cmd
|
||||
vvv("EXEC %s" % quoted_command, host=self.host)
|
||||
chan.exec_command(quoted_command)
|
||||
|
||||
else:
|
||||
|
||||
# sudo usually requires a PTY (cf. requiretty option), therefore
|
||||
# we give it one by default (pty=True in ansble.cfg), and we try
|
||||
# to initialise from the calling environment
|
||||
if C.PARAMIKO_PTY:
|
||||
chan.get_pty(term=os.getenv('TERM', 'vt100'),
|
||||
width=int(os.getenv('COLUMNS', 0)),
|
||||
height=int(os.getenv('LINES', 0)))
|
||||
if self.runner.become and sudoable:
|
||||
shcmd, prompt, success_key = utils.make_become_cmd(cmd, become_user, executable, self.runner.become_method, '', self.runner.become_exe)
|
||||
|
||||
vvv("EXEC %s" % shcmd, host=self.host)
|
||||
become_output = ''
|
||||
|
||||
try:
|
||||
|
||||
chan.exec_command(shcmd)
|
||||
|
||||
if self.runner.become_pass:
|
||||
|
||||
while True:
|
||||
|
||||
if success_key in become_output or \
|
||||
(prompt and become_output.endswith(prompt)) or \
|
||||
utils.su_prompts.check_su_prompt(become_output):
|
||||
break
|
||||
chunk = chan.recv(bufsize)
|
||||
|
||||
if not chunk:
|
||||
if 'unknown user' in become_output:
|
||||
raise errors.AnsibleError(
|
||||
'user %s does not exist' % become_user)
|
||||
else:
|
||||
raise errors.AnsibleError('ssh connection ' +
|
||||
'closed waiting for password prompt')
|
||||
become_output += chunk
|
||||
|
||||
if success_key not in become_output:
|
||||
|
||||
if sudoable:
|
||||
chan.sendall(self.runner.become_pass + '\n')
|
||||
else:
|
||||
no_prompt_out += become_output
|
||||
no_prompt_err += become_output
|
||||
|
||||
except socket.timeout:
|
||||
|
||||
raise errors.AnsibleError('ssh timed out waiting for privilege escalation.\n' + become_output)
|
||||
try:
|
||||
chan.exec_command(cmd)
|
||||
if self._connection_info.become_pass:
|
||||
while True:
|
||||
if success_key in become_output or \
|
||||
(prompt and become_output.endswith(prompt)) or \
|
||||
utils.su_prompts.check_su_prompt(become_output):
|
||||
break
|
||||
chunk = chan.recv(bufsize)
|
||||
if not chunk:
|
||||
if 'unknown user' in become_output:
|
||||
raise AnsibleError(
|
||||
'user %s does not exist' % become_user)
|
||||
else:
|
||||
raise AnsibleError('ssh connection ' +
|
||||
'closed waiting for password prompt')
|
||||
become_output += chunk
|
||||
if success_key not in become_output:
|
||||
if self._connection_info.become:
|
||||
chan.sendall(self._connection_info.become_pass + '\n')
|
||||
else:
|
||||
no_prompt_out += become_output
|
||||
no_prompt_err += become_output
|
||||
except socket.timeout:
|
||||
raise AnsibleError('ssh timed out waiting for privilege escalation.\n' + become_output)
|
||||
|
||||
stdout = ''.join(chan.makefile('rb', bufsize))
|
||||
stderr = ''.join(chan.makefile_stderr('rb', bufsize))
|
||||
|
@ -279,24 +248,24 @@ class Connection(object):
|
|||
def put_file(self, in_path, out_path):
|
||||
''' transfer a file from local to remote '''
|
||||
|
||||
vvv("PUT %s TO %s" % (in_path, out_path), host=self.host)
|
||||
self._display.vvv("PUT %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr)
|
||||
|
||||
if not os.path.exists(in_path):
|
||||
raise errors.AnsibleFileNotFound("file or module does not exist: %s" % in_path)
|
||||
raise AnsibleFileNotFound("file or module does not exist: %s" % in_path)
|
||||
|
||||
try:
|
||||
self.sftp = self.ssh.open_sftp()
|
||||
except Exception, e:
|
||||
raise errors.AnsibleError("failed to open a SFTP connection (%s)" % e)
|
||||
raise AnsibleError("failed to open a SFTP connection (%s)" % e)
|
||||
|
||||
try:
|
||||
self.sftp.put(in_path, out_path)
|
||||
except IOError:
|
||||
raise errors.AnsibleError("failed to transfer file to %s" % out_path)
|
||||
raise AnsibleError("failed to transfer file to %s" % out_path)
|
||||
|
||||
def _connect_sftp(self):
|
||||
|
||||
cache_key = "%s__%s__" % (self.host, self.user)
|
||||
cache_key = "%s__%s__" % (self._connection_info.remote_addr, self._connection_info.remote_user)
|
||||
if cache_key in SFTP_CONNECTION_CACHE:
|
||||
return SFTP_CONNECTION_CACHE[cache_key]
|
||||
else:
|
||||
|
@ -306,17 +275,17 @@ class Connection(object):
|
|||
def fetch_file(self, in_path, out_path):
|
||||
''' save a remote file to the specified path '''
|
||||
|
||||
vvv("FETCH %s TO %s" % (in_path, out_path), host=self.host)
|
||||
self._display.vvv("FETCH %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr)
|
||||
|
||||
try:
|
||||
self.sftp = self._connect_sftp()
|
||||
except Exception, e:
|
||||
raise errors.AnsibleError("failed to open a SFTP connection (%s)", e)
|
||||
raise AnsibleError("failed to open a SFTP connection (%s)", e)
|
||||
|
||||
try:
|
||||
self.sftp.get(in_path, out_path)
|
||||
except IOError:
|
||||
raise errors.AnsibleError("failed to transfer file from %s" % in_path)
|
||||
raise AnsibleError("failed to transfer file from %s" % in_path)
|
||||
|
||||
def _any_keys_added(self):
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ class Connection(ConnectionBase):
|
|||
self._cp_dir = '/tmp'
|
||||
#fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
|
||||
|
||||
super(Connection, self).__init__(connection_info)
|
||||
super(Connection, self).__init__(connection_info, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def transport(self):
|
||||
|
|
Loading…
Reference in a new issue